Tensorflow: Wenn tf.expand_dims benutzt wird?

8

Tensorflow-Lernprogramme enthalten die Verwendung von tf.expand_dims , um einem Tensor eine "Batch-Dimension" hinzuzufügen. Ich habe die Dokumente für diese Funktion gelesen, aber es ist immer noch ziemlich mysteriös für mich. Weiß jemand genau unter welchen Umständen das verwendet werden muss?

Mein Code ist unten. Meine Absicht ist es, einen Verlust basierend auf der Entfernung zwischen den vorhergesagten und den tatsächlichen Behältern zu berechnen. (Z. B. wenn predictedBin = 10 und truthBin = 7 dann binDistanceLoss = 3 ).

%Vor%

Muss ich in diesem Fall tf.expand_dims auf predictedBin und binDistanceLoss anwenden? Vielen Dank im Voraus.

    
Ron Cohen 18.08.2016, 01:57
quelle

2 Antworten

19

expand_dims fügt keine Elemente in einem Tensor hinzu oder reduziert sie, sondern ändert nur die Form, indem 1 zu den Dimensionen hinzugefügt wird. Zum Beispiel könnte ein Vektor mit 10 Elementen als 10x1-Matrix behandelt werden.

Die Situation, in der ich expand_dims benutzt habe, ist, als ich versuchte, ein ConvNet zu erstellen, um Graustufenbilder zu klassifizieren. Die Graustufenbilder werden als Matrix der Größe [320, 320] geladen. % Co_de% erfordert jedoch, dass die Eingabe tf.nn.conv2d ist, wobei die Dimension [batch, in_height, in_width, in_channels] in meinen Daten fehlt, die in diesem Fall in_channels sein sollte. Also habe ich 1 verwendet, um eine weitere Dimension hinzuzufügen.

In Ihrem Fall glaube ich nicht, dass Sie expand_dims brauchen.

    
Da Tong 18.08.2016, 02:22
quelle
9

Um zu Da Tongs Antwort hinzuzufügen, möchten Sie vielleicht mehr als eine Dimension gleichzeitig erweitern. Wenn Sie z. B. die TensorFlow-Operation conv1d auf Vektoren mit Rang 1 ausführen, müssen Sie sie mit Rang drei versehen.

Die Ausführung von expand_dims mehrmals ist zwar lesbar, könnte jedoch einen gewissen Mehraufwand in den Berechnungsgraphen bringen. Sie können die gleiche Funktionalität in einem Einzeiler mit reshape erhalten:

%Vor%

HINWEIS: Wenn Sie den Fehler TypeError: Failed to convert object of type <type 'list'> to Tensor. erhalten, versuchen Sie, tf.shape(x)[0] anstelle von x.get_shape()[0] wie vorgeschlagen zu übergeben. hier .

Hoffe es hilft!
Prost, Andres

    
fr_andres 03.10.2017 23:45
quelle

Tags und Links