Reshape a rank-3 tensor to rank 2 in TensorFlow to use it in a fully connected layer

I have searched and experimented a lot, but I have not found a solution to the following problem:

I am working on neural network models to test sentence classification. The sentences are represented in [rows, words_encoded_by_word2vec] format. The first network, a fully connected one, is done. In the second model I am trying to add a conv1d and a max_pool1d layer just before the dense layer. These layers expect a tensor in [batch_size, length, channel] format. OK, that is not a problem and it is done. (I have only one channel.)

However, connecting the max-pooling layer to the fully connected layer is extremely difficult because the batch_size is unknown.
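Only the batch dimension is truly unknown here. With padding="SAME", TensorFlow's convolution and pooling ops return ceil(length / stride) positions regardless of kernel size, so the per-example feature count follows from n_inputs and the strides alone. A plain-Python sketch of that arithmetic, using the strides from the post's code (50 and 1) and a hypothetical sentence length of 300; the helper names are illustrative only:

```python
def same_padding_length(length: int, stride: int) -> int:
    """Output length of a 1-D convolution or pooling op with padding="SAME".

    SAME padding yields ceil(length / stride) positions, whatever the kernel size.
    """
    return (length + stride - 1) // stride  # integer ceiling division


def flattened_size(n_inputs: int, conv_stride: int, pool_stride: int, channels: int = 1) -> int:
    """Features per example after conv1d -> max pooling -> flatten (hypothetical helper)."""
    after_conv = same_padding_length(n_inputs, conv_stride)
    after_pool = same_padding_length(after_conv, pool_stride)
    return after_pool * channels


# Hypothetical sentence length of 300 values; strides 50 and 1 as in the post.
print(flattened_size(300, conv_stride=50, pool_stride=1))  # 6
```

The result does not involve the batch size at all, which is why the flattened tensor can have a defined last dimension even while its first dimension is None.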

import tensorflow as tf  # TF 1.x API, as in the original post

self.X = tf.placeholder(tf.float32, shape=(None, self.n_inputs, 1), name="input")

self.convolution = tf.nn.conv1d(self.X, self.filter, stride=50, padding="SAME")
self.max_pool = tf.layers.max_pooling1d(self.convolution, pool_size=2, strides=1, padding="SAME")
# tf.shape() is evaluated at run time, so the static size of the last axis here is unknown (None)
self.tensor_vector = tf.reshape(tensor=self.max_pool, shape=(-1, tf.shape(self.max_pool)[1]*tf.shape(self.max_pool)[2]))

This runs, but the dense layer does not accept the result and raises an error:

ValueError: The last dimension of the inputs to `Dense` should be defined. Found `None`.
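Because the feature count is computed with tf.shape(), which returns the dynamic (run-time) shape, TensorFlow may be unable to infer the last dimension statically, and a dense layer needs that dimension when it creates its [features, units] weight matrix. A minimal TF 2.x sketch of one way around it: compute the feature count from the static shape (.shape) instead, so only the batch axis is left as -1. The sentence length of 300 and the 4-output-channel filter are hypothetical; the extra channels just show that this approach does not rely on having a single channel:

```python
import tensorflow as tf

N_INPUTS = 300  # hypothetical sentence length (n_inputs in the post)
N_FILTERS = 4   # hypothetical number of output channels; the post uses 1
static_shapes = {}  # filled in while the function is traced


@tf.function(input_signature=[tf.TensorSpec(shape=(None, N_INPUTS, 1), dtype=tf.float32)])
def conv_pool_flatten(x):
    # Hypothetical filter: width 5, 1 input channel, N_FILTERS output channels.
    filters = tf.ones((5, 1, N_FILTERS), dtype=tf.float32)
    conv = tf.nn.conv1d(x, filters, stride=50, padding="SAME")
    pooled = tf.nn.max_pool1d(conv, ksize=2, strides=1, padding="SAME")

    # pooled.shape is the static shape (None, 6, N_FILTERS). Only the batch axis
    # is unknown, so the feature count is a plain Python int.
    features = pooled.shape[1] * pooled.shape[2]
    flat = tf.reshape(pooled, (-1, features))

    static_shapes["pooled"] = pooled.shape.as_list()
    static_shapes["flat"] = flat.shape.as_list()
    return flat


conv_pool_flatten.get_concrete_function()  # trace once with batch size None
print(static_shapes)  # {'pooled': [None, 6, 4], 'flat': [None, 24]}
```

In TF 1.x graph code like the post's, the same idea is to read the static shape with self.max_pool.get_shape().as_list() rather than tf.shape(self.max_pool) when computing the feature count.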

The dense layer in question:

hiddens.append(
    dropout(
        fully_connected(
            self.tensor_vector, layers[0], scope="hidden0",
            weights_initializer=tf.contrib.layers.xavier_initializer(),
            biases_initializer=tf.random_uniform_initializer(-0.1, 0.1)),
        self.keep_prob, is_training=self.is_training))

Is there any way to reconcile the two formats? Any help would be appreciated :)

Thanks.

I found an answer that applies to my specific problem: tf.squeeze() removes dimensions of size 1. That is exactly the channel dimension I wanted to remove, and this transformation did the trick for me.
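A minimal TF 2.x sketch of the squeeze fix, assuming a hypothetical sentence length of 300, a hypothetical single-channel filter of width 5, and tf.keras.layers.Dense standing in for the post's fully_connected(). It traces the conv/pool pipeline with an unknown batch size, checks that the squeezed tensor keeps a defined last dimension, and builds a dense layer from that static shape:

```python
import tensorflow as tf

N_INPUTS = 300  # hypothetical sentence length (n_inputs in the post)
N_HIDDEN = 64   # hypothetical size of the first hidden layer
static_shapes = {}  # filled in while the function is traced


@tf.function(input_signature=[tf.TensorSpec(shape=(None, N_INPUTS, 1), dtype=tf.float32)])
def conv_pool_squeeze(x):
    # Hypothetical filter: width 5, 1 input channel, 1 output channel.
    filters = tf.ones((5, 1, 1), dtype=tf.float32)
    conv = tf.nn.conv1d(x, filters, stride=50, padding="SAME")
    pooled = tf.nn.max_pool1d(conv, ksize=2, strides=1, padding="SAME")  # (None, 6, 1)
    # Remove only the size-1 channel axis. Naming axis=2 also stops a batch
    # of size 1 from being squeezed away at run time.
    vector = tf.squeeze(pooled, axis=2)
    static_shapes["vector"] = vector.shape.as_list()
    return vector


conv_pool_squeeze.get_concrete_function()  # trace with batch size None

# The traced static shape has a defined last dimension, so the dense layer
# can create its (6, N_HIDDEN) kernel from it.
dense = tf.keras.layers.Dense(N_HIDDEN)
dense.build(tuple(static_shapes["vector"]))

hidden = dense(conv_pool_squeeze(tf.random.normal((4, N_INPUTS, 1))))
print(static_shapes["vector"], hidden.shape.as_list())  # [None, 6] [4, 64]
```

Squeezing works here only because there is exactly one channel; with several output channels the channel axis has size greater than 1, and a reshape based on the static shape is needed instead.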

However, I think shape problems like this are common in TensorFlow development, so it would be nice if there were adapters for them.
