tensorflow: Speichern und Wiederherstellen der Sitzung

8

Ich versuche, einen Vorschlag aus den Antworten zu implementieren: Tensorflow: Wie wird ein Modell gespeichert / wiederhergestellt?

Ich habe ein Objekt, das ein tensorflow -Modell in einem sklearn -Stil umschließt.

%Vor%

Wenn ich renne:

%Vor%

Ich bekomme ausgegeben:

%Vor%

Wenn ich jedoch versuche, die Parameter wiederherzustellen (auch ohne das Objekt zu töten): tfl.fit( train_X, train_Y , load = True)

Ich bekomme seltsame Ergebnisse. Zuallererst entspricht der geladene Wert nicht dem gespeicherten Wert.

%Vor%

Was ist der richtige Weg zu laden, und wahrscheinlich zuerst die gespeicherten Variablen zu überprüfen?

    
Dima Lituiev 28.12.2015, 20:11
quelle

1 Antwort

9

TL; DR: Sie sollten versuchen, diese Klasse so zu überarbeiten, dass self.create_network() (i) nur einmal aufgerufen wird und (ii) bevor die tf.train.Saver() erstellt wird.

Hier gibt es zwei heikle Probleme, die auf die Codestruktur und das Standardverhalten von tf.train.Saver Konstruktor . Wenn Sie einen Sparer ohne Argumente konstruieren (wie in Ihrem Code), sammelt er den aktuellen Satz von Variablen in Ihrem Programm und fügt dem Graphen Operationen zum Speichern und Wiederherstellen hinzu. Wenn Sie in Ihrem Code tflasso() aufrufen, wird ein Sparer erstellt, und es wird keine Variablen geben (weil create_network() noch nicht aufgerufen wurde). Daher sollte der Prüfpunkt leer sein.

Das zweite Problem ist, dass - standardmäßig - das Format eines gespeicherten Checkpoints eine Map aus dem name Eigenschaft einer Variablen auf ihren aktuellen Wert. Wenn Sie zwei Variablen mit demselben Namen erstellen, werden sie automatisch von TensorFlow: "unifiziert":

%Vor%

Wenn Sie self.create_network() im zweiten Aufruf von tfl.fit() aufrufen, haben die Variablen alle einen anderen Namen als die Namen, die im Prüfpunkt gespeichert sind - oder wären sie vorhanden gewesen, wenn der Sparer dies getan hätte wurde nach dem Netzwerk erstellt. (Sie können dieses Verhalten vermeiden, indem Sie dem Sparerkonstruktor ein% c% n% s-Verzeichnis übergeben, aber das ist normalerweise ziemlich peinlich.)

Es gibt zwei Hauptumgehungen:

  1. Erstellen Sie bei jedem Aufruf von Variable das gesamte Modell neu, indem Sie ein neues tflasso.fit() definieren und dann in diesem Diagramm das Netzwerk erstellen und ein tf.Graph erstellen.

  2. EMPFOHLEN Erstellen Sie das Netzwerk, dann tf.train.Saver im Konstruktor tf.train.Saver , und verwenden Sie dieses Diagramm bei jedem Aufruf von tflasso erneut. Beachten Sie, dass Sie möglicherweise etwas mehr tun müssen, um die Dinge zu reorganisieren (insbesondere bin ich nicht sicher, was Sie mit tflasso.fit() und self.X machen), aber es sollte möglich sein, dies mit Platzhalter und füttern.

mrry 28.12.2015, 20:59
quelle

Tags und Links