
Multi-layer RNN in TensorFlow

I have a two-layer LSTM network (config.n_input is 3, config.n_steps is 5).

I suspect the error below is related to the shape of my inputs, but I'm not sure how to fix it. I tried changing the projection of the first LSTM so that its output would match the second layer's input size, but that didn't work.

    # Assumed TF 0.x imports: import tensorflow as tf; from tensorflow.models.rnn import rnn, rnn_cell
    self.input_data = tf.placeholder(tf.float32, [None, config.n_steps, config.n_input], name='input')

    # Tensorflow LSTM cell requires 2x n_hidden length (state & cell)
    self.initial_state = tf.placeholder(tf.float32, [None, 2*config.n_hidden], name='state')
    self.targets = tf.placeholder(tf.float32, [None, config.n_classes], name='target')

    _X = tf.transpose(self.input_data, [1, 0, 2])  # permute n_steps and batch_size
    _X = tf.reshape(_X, [-1, config.n_input])  # (n_steps*batch_size, n_input)

    # First layer: 3 input features, output projected down to 300 units
    input_cell = rnn_cell.LSTMCell(num_units=config.n_hidden, input_size=3, num_proj=300, forget_bias=1.0)
    print(input_cell.output_size)
    # Second layer consumes the 300-dimensional projected output of the first
    inner_cell = rnn_cell.LSTMCell(num_units=config.n_hidden, input_size=300)
    cells = [input_cell, inner_cell]
    cell = rnn_cell.MultiRNNCell(cells)
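The transpose and reshape at the top of the snippet only reorder the batch: afterwards, rows are grouped by time step instead of by example. The NumPy sketch below mirrors those two TensorFlow calls, with np.transpose and np.reshape standing in for tf.transpose and tf.reshape (both use the same axis permutation and row-major flattening). batch_size = 2 is a hypothetical value chosen for illustration; n_steps and n_input are the question's.

```python
import numpy as np

# n_steps and n_input from the question; batch_size = 2 is hypothetical
batch_size, n_steps, n_input = 2, 5, 3

# Fill each element with a recognizable value: 100*example + 10*step + feature
input_data = np.array(
    [[[100 * b + 10 * t + f for f in range(n_input)] for t in range(n_steps)]
     for b in range(batch_size)],
    dtype=np.float32,
)  # shape (batch_size, n_steps, n_input)

# Same permutation as tf.transpose(self.input_data, [1, 0, 2])
x = np.transpose(input_data, (1, 0, 2))  # (n_steps, batch_size, n_input)
# Same flattening as tf.reshape(_X, [-1, config.n_input])
x = np.reshape(x, (-1, n_input))  # (n_steps*batch_size, n_input)

print(x.shape)  # (10, 3)
# Row t*batch_size + b holds example b at time step t
print(x[1])  # example 1, step 0 -> [100. 101. 102.]
print(x[2])  # example 0, step 1 -> [10. 11. 12.]
```

Each row still carries n_input = 3 features, which matches input_size=3 on the first cell, so the input reshaping itself is consistent with the layer definitions.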

Running it produces the following error:

tensorflow.python.pywrap_tensorflow.StatusNotOK: Invalid argument: Expected size[1] in [0, 0], but got 600
 [[Node: RNN/MultiRNNCell/Cell1/Slice = Slice[Index=DT_INT32, T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/gpu:0"](_recv_state_0/_3, RNN/MultiRNNCell/Cell1/Slice/begin, RNN/MultiRNNCell/Cell1/Slice/size)]]

Is there a clearer explanation of this error message, or an easy way to fix it?

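Include the projection, and the second layer, in the width of your initial state. In the TensorFlow version used here, an LSTMCell stores its cell state and its output side by side in a single state tensor, so its state_size is 2*num_units without a projection but num_units + num_proj with one. MultiRNNCell then gives each cell a consecutive slice of that tensor: the first cell takes n_hidden + 300 columns and the second takes 2*n_hidden, for a total of 3*n_hidden + 300.

That is also what the message means. Cell1/Slice is MultiRNNCell cutting out the second cell's state. It asks for 600 columns (2*n_hidden, so n_hidden is 300), but the first cell has already used all 600 columns of the placeholder, which leaves a valid size range of only [0, 0].

# Projected first cell: n_hidden + 300 columns; inner cell: 2*n_hidden columns
self.initial_state = tf.placeholder(tf.float32, [None, 3*config.n_hidden + 300], name='state')

If you create the placeholder after building cell, [None, cell.state_size] gives the same width and stays correct when you change the layers.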
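A pure-Python mock of the state bookkeeping reproduces the numbers in the error. It assumes that each LSTMCell keeps its cell state and output concatenated in one tensor and that MultiRNNCell hands each cell the next consecutive slice of the state, which is consistent with the Cell1/Slice node in the message. lstm_state_size and slice_state are hypothetical helpers, not TensorFlow functions, and n_hidden = 300 is inferred from the 600 in the error.

```python
def lstm_state_size(num_units, num_proj=None):
    """Width of one concatenated [c, h] LSTM state (hypothetical helper).

    c always has num_units columns; h has num_proj columns when a
    projection is used, otherwise num_units.
    """
    return num_units + (num_proj if num_proj else num_units)


def slice_state(state_width, cell_sizes):
    """Mimic MultiRNNCell: give each cell the next consecutive state slice.

    Returns the (begin, end) column range of every cell, or raises
    ValueError shaped like the Slice error when the state is too narrow.
    """
    ranges = []
    begin = 0
    for i, size in enumerate(cell_sizes):
        available = max(state_width - begin, 0)
        if size > available:
            raise ValueError(
                f"Cell{i}: expected size in [0, {available}], but got {size}")
        ranges.append((begin, begin + size))
        begin += size
    return ranges


n_hidden, num_proj = 300, 300  # n_hidden inferred from the error message
cell_sizes = [
    lstm_state_size(n_hidden, num_proj),  # input_cell: 300 + 300 = 600
    lstm_state_size(n_hidden),            # inner_cell: 2 * 300 = 600
]
required = sum(cell_sizes)  # total width the stacked cell slices
print(required)  # 1200

for width in (2 * n_hidden, 2 * n_hidden + num_proj, required):
    try:
        print(width, slice_state(width, cell_sizes))
    except ValueError as err:
        print(width, err)
# Prints:
# 600 Cell1: expected size in [0, 0], but got 600
# 900 Cell1: expected size in [0, 300], but got 600
# 1200 [(0, 600), (600, 1200)]
```

The first case reproduces the original message; the middle one shows that adding num_proj alone is not enough once a second, unprojected layer is stacked on top; the last one gives every cell its full slice.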

This is quite an opaque error message, so it might be worth raising it on the TensorFlow GitHub issues page!
