
Using TimeDistributed with recurrent layer in Keras

I want to run an LSTM over a few different sequences on every batch and then join the last outputs. Here is what I've been trying:

from keras.layers import Dense, Input, LSTM, Embedding, TimeDistributed

num_sentences = 4
num_features = 3
num_time_steps = 5

inputs = Input([num_sentences, num_time_steps])
emb_layer = Embedding(10, num_features)
embedded = emb_layer(inputs)
lstm_layer = LSTM(4)

shape = [num_sentences, num_time_steps, num_features]
lstm_outputs = TimeDistributed(lstm_layer, input_shape=shape)(embedded)

This is giving me the following error:

Traceback (most recent call last):
  File "test.py", line 12, in <module>
    lstm_outputs = TimeDistributed(lstm_layer, input_shape=shape)(embedded)
  File "/Users/erick/anaconda2/lib/python2.7/site-packages/keras/engine/topology.py", line 546, in __call__
    self.build(input_shapes[0])
  File "/Users/erick/anaconda2/lib/python2.7/site-packages/keras/layers/wrappers.py", line 94, in build
    self.layer.build(child_input_shape)
  File "/Users/erick/anaconda2/lib/python2.7/site-packages/keras/layers/recurrent.py", line 702, in build
    self.input_dim = input_shape[2]
IndexError: tuple index out of range

I tried omitting the input_shape argument in TimeDistributed, but it didn't change anything.

The input_shape argument belongs on the LSTM layer, not on TimeDistributed (which is only a wrapper). With it omitted, everything works for me:

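from keras.layers import Input, LSTM, Embedding, TimeDistributed

num_sentences = 4
num_features = 3
num_time_steps = 5

# Each sample holds num_sentences sequences of num_time_steps token ids
inputs = Input([num_sentences, num_time_steps])
# (batch, 4, 5) token ids -> (batch, 4, 5, 3) embedding vectors
emb_layer = Embedding(10, num_features)
embedded = emb_layer(inputs)
lstm_layer = LSTM(4)

# No input_shape on the wrapper: the same LSTM runs over each of the
# num_sentences slices and returns its last output -> (batch, 4, 4)
lstm_outputs = TimeDistributed(lstm_layer)(embedded)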


#OUTPUT:
Using TensorFlow backend.
[Finished in 1.5s]
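The original goal was to join the last outputs of the per-sentence LSTM runs. Below is a minimal sketch of one way to finish the model, assuming TensorFlow 2.x with its bundled `tf.keras` rather than the standalone Keras 2.0 used in this thread. `Flatten` is used here as the joining step (it lays the four 4-dimensional vectors end to end, sentence by sentence); `vocab_size` and the random input batch are illustrative assumptions.

```python
import numpy as np
from tensorflow.keras import layers, Model

num_sentences = 4
num_features = 3
num_time_steps = 5
vocab_size = 10  # assumption: matches Embedding(10, ...) above

# Each sample: num_sentences sequences of num_time_steps token ids
inputs = layers.Input(shape=(num_sentences, num_time_steps), dtype="int32")
# (batch, 4, 5) -> (batch, 4, 5, 3)
embedded = layers.Embedding(vocab_size, num_features)(inputs)
# Same LSTM applied to every sentence; only its last output is kept -> (batch, 4, 4)
lstm_outputs = layers.TimeDistributed(layers.LSTM(4))(embedded)
# Join the per-sentence last outputs into one vector per sample -> (batch, 16)
joined = layers.Flatten()(lstm_outputs)
model = Model(inputs, joined)

# Random token ids for a batch of 2 samples, just to check shapes
x = np.random.randint(0, vocab_size, size=(2, num_sentences, num_time_steps))
out = model.predict(x, verbose=0)
print(out.shape)  # (2, 16)
```

A `Concatenate` layer would produce the same layout but needs the sentences as separate tensors, so flattening the `(batch, 4, 4)` output is the simpler choice when all sentences share one LSTM.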

After trying michetonu's answer and getting the same error, I realized my version of Keras might be outdated. Indeed, I was running Keras 1.2, and the code ran fine on 2.0.
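The traceback ends in `LSTM.build` reading `input_shape[2]`, so the inner LSTM expects a `(batch, time, features)` shape. The sketch below is a simplified, hypothetical model of that shape bookkeeping, not Keras's actual code; `child_input_shape` and `lstm_input_dim` are made-up helper names. The wrapper drops axis 1 and hands the remaining shape to the inner layer. With a 4-D embedding output this yields a valid 3-D child shape; if the tensor reaching the wrapper were only 3-D, the child shape would be 2-D and indexing `[2]` raises the same `IndexError`. It only illustrates one way the error can arise and does not establish what Keras 1.2 did differently.

```python
def child_input_shape(input_shape):
    """Hypothetical: mimic a TimeDistributed-style wrapper by dropping the
    axis it iterates over (axis 1) and keeping everything else."""
    return (input_shape[0],) + tuple(input_shape[2:])


def lstm_input_dim(input_shape):
    """Hypothetical: an LSTM expects (batch, time, features) and reads the
    feature size from index 2."""
    return input_shape[2]


# 4-D embedding output: (batch, sentences, time steps, features)
ok_shape = (None, 4, 5, 3)
print(lstm_input_dim(child_input_shape(ok_shape)))  # 3

# A 3-D tensor reaching the wrapper gives a 2-D child shape
bad_shape = (None, 4, 3)
try:
    lstm_input_dim(child_input_shape(bad_shape))
except IndexError as exc:
    print("IndexError:", exc)  # IndexError: tuple index out of range
```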

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 