Übersetzen eines TensorFlow LSTM in synapticjs

8

Ich arbeite an der Implementierung einer Schnittstelle zwischen einem bereits trainierten TensorFlow basic LSTM und einer Javascript-Version, die im Browser ausgeführt werden kann. Das Problem ist, dass in der gesamten Literatur, die ich gelesen habe, LSTMs als Mini-Netzwerke modelliert sind (nur Verbindungen, Knoten und Tore verwendend) und TensorFlow scheint viel mehr zu haben.

Die zwei Fragen, die ich habe, sind:

  1. Kann das TensorFlow-Modell leicht in eine konventionellere neurale Netzwerkstruktur übersetzt werden?

  2. Gibt es eine praktische Möglichkeit, die trainierbaren Variablen zuzuordnen, die TensorFlow dieser Struktur gibt?

Ich kann die "trainierbaren Variablen" aus TensorFlow herausholen, das Problem ist, dass sie anscheinend nur einen Wert für Bias pro LSTM-Knoten haben, wobei die meisten der Modelle, die ich gesehen habe, mehrere Verzerrungen für die Speicherzelle enthalten würden. die Eingänge und die Ausgabe.

    
A. Vickory 17.12.2015, 00:04
quelle

1 Antwort

8

Intern speichert die LSTMCell -Klasse die LSTM-Gewichtungen aus Gründen der Effizienz als eine große Matrix anstelle von 8 kleineren. Es ist ziemlich einfach, es horizontal und vertikal zu teilen, um zu der konventionelleren Darstellung zu gelangen. Es könnte jedoch einfacher und effizienter sein, wenn Ihre Bibliothek ähnliche Optimierungen durchführt.

Hier ist der relevante Code von BasicLSTMCell :

%Vor%

Die Funktion linear führt die Matrixmultiplikation durch, um die verkettete Eingabe und den vorherigen h -Zustand in 4 Matrizen von [batch_size, self._num_units] shape zu transformieren. Die lineare Transformation verwendet eine einzelne Matrix und Bias-Variablen, auf die Sie in der Frage Bezug nehmen. Das Ergebnis wird dann in verschiedene Gatter aufgeteilt, die von der LSTM-Transformation verwendet werden.

Wenn Sie die Transformationen für jedes Gatter explizit abrufen möchten, können Sie diese Matrix und Vorspannung in vier Blöcke aufteilen. Es ist auch ziemlich einfach, es von Grund auf mit 4 oder 8 linearen Transformationen zu implementieren.

    
Rafał Józefowicz 17.12.2015, 17:26
quelle