I want to run an LSTM over several different sequences in each batch and then join the last outputs. Here is what I've been trying:
from keras.layers import Dense, Input, LSTM, Embedding, TimeDistributed
num_sentences = 4
num_features = 3
num_time_steps = 5
inputs = Input([num_sentences, num_time_steps])
emb_layer = Embedding(10, num_features)
embedded = emb_layer(inputs)
lstm_layer = LSTM(4)
shape = [num_sentences, num_time_steps, num_features]
lstm_outputs = TimeDistributed(lstm_layer, input_shape=shape)(embedded)
This is giving me the following error:
Traceback (most recent call last):
  File "test.py", line 12, in <module>
    lstm_outputs = TimeDistributed(lstm_layer, input_shape=shape)(embedded)
  File "/Users/erick/anaconda2/lib/python2.7/site-packages/keras/engine/topology.py", line 546, in __call__
    self.build(input_shapes[0])
  File "/Users/erick/anaconda2/lib/python2.7/site-packages/keras/layers/wrappers.py", line 94, in build
    self.layer.build(child_input_shape)
  File "/Users/erick/anaconda2/lib/python2.7/site-packages/keras/layers/recurrent.py", line 702, in build
    self.input_dim = input_shape[2]
IndexError: tuple index out of range
I tried omitting the input_shape argument in TimeDistributed, but it didn't change anything.
The input_shape needs to be an argument of the LSTM layer, not TimeDistributed (which is a wrapper). By omitting it, everything works fine for me:
from keras.layers import Dense, Input, LSTM, Embedding, TimeDistributed
num_sentences = 4
num_features = 3
num_time_steps = 5
inputs = Input([num_sentences, num_time_steps])
emb_layer = Embedding(10, num_features)
embedded = emb_layer(inputs)
lstm_layer = LSTM(4)
# no input_shape needed; TimeDistributed infers it from the Embedding output
lstm_outputs = TimeDistributed(lstm_layer)(embedded)
#OUTPUT:
Using TensorFlow backend.
[Finished in 1.5s]
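To actually join the last outputs of the per-sentence LSTMs, as asked in the question, you can follow the TimeDistributed layer with a merge step. Here is a minimal sketch using the Keras 2 functional API; the Flatten + Dense head is just one illustrative way to join the outputs, not something prescribed by the answer:

from keras.layers import Dense, Input, LSTM, Embedding, TimeDistributed, Flatten
from keras.models import Model

num_sentences = 4
num_features = 3
num_time_steps = 5

inputs = Input([num_sentences, num_time_steps])
embedded = Embedding(10, num_features)(inputs)        # (batch, 4, 5, 3)
lstm_outputs = TimeDistributed(LSTM(4))(embedded)     # same LSTM applied per sentence -> (batch, 4, 4)
joined = Flatten()(lstm_outputs)                      # concatenate the 4 last outputs -> (batch, 16)
predictions = Dense(1, activation='sigmoid')(joined)  # hypothetical downstream head

model = Model(inputs=inputs, outputs=predictions)
model.summary()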
After trying michetonu's answer and getting the same error, I realized my version of Keras might be outdated. Indeed, I was running Keras 1.2, and the code ran fine on 2.0.
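For anyone hitting the same IndexError even without input_shape, it's worth checking the installed version first; as noted above, this snippet fails on Keras 1.2 but runs on 2.0:

import keras
print(keras.__version__)  # 1.2 reproduces the error above; 2.x runs fine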