Keras/Python - If a RNN is stateful, a complete input_shape must be provided (including batch size)

I'm trying to implement a stateful RNN, but it keeps asking me for a "complete input_shape (including batch size)". I tried different values for the input_shape and batch_input_shape arguments, but neither seems to work.

Code:

model=Sequential()

model.add(SimpleRNN(init='uniform',
  output_dim=80,
  input_dim=len(pred_frame.columns),
  stateful=True,
  batch_input_shape=(len(pred_frame.index),len(pred_frame.columns)),
  input_shape=(len(pred_frame.index),len(pred_frame.columns))))

model.add(Dense(output_dim=200,input_dim=len(pred_frame.columns),init="glorot_uniform"))

model.add(Dense(output_dim=1))

model.compile(loss="mse", class_mode='scalar', optimizer="sgd")

model.fit(X=predictor_train, y=target_train,
  batch_size=len(pred_frame.index),show_accuracy=True)

Traceback:

File "/Users/file.py", line 1483, in Pred
model.add(SimpleRNN(init='uniform',output_dim=80,input_dim=len(pred_frame.columns),stateful=True,batch_input_shape=(len(pred_frame.index),len(pred_frame.columns)),input_shape=(len(pred_frame.index),len(pred_frame.columns))))
File "/Library/Python/2.7/site-packages/keras/layers/recurrent.py", line 194, in __init__
super(SimpleRNN, self).__init__(**kwargs)
File "/Library/Python/2.7/site-packages/keras/layers/recurrent.py", line 97, in __init__
super(Recurrent, self).__init__(**kwargs)
File "/Library/Python/2.7/site-packages/keras/layers/core.py", line 43, in __init__
self.set_input_shape((None,) + tuple(kwargs['input_shape']))
File "/Library/Python/2.7/site-packages/keras/layers/core.py", line 141, in set_input_shape
self.build()
File "/Library/Python/2.7/site-packages/keras/layers/recurrent.py", line 199, in build
self.reset_states()
File "/Library/Python/2.7/site-packages/keras/layers/recurrent.py", line 221, in reset_states
'(including batch size).')
Exception: If a RNN is stateful, a complete input_shape must be provided (including batch size).

You need to provide only the batch_input_shape= parameter, not the input_shape parameter. As the traceback shows (core.py, line 43), when input_shape is passed the layer sets its input shape to (None,) + input_shape, so the batch size is lost and the stateful check fails. Also, to avoid input shape errors, make sure the training data size is a multiple of batch_size, because a stateful layer expects every batch to contain exactly batch_size samples. Finally, if you are using validation splits, make sure both splits are also multiples of batch_size.

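# ensure data size is a multiple of batch_size
data_size = data_size - data_size % batch_size
# ensure validation splits are multiples of batch_size:
# increment is one batch expressed as a fraction of the data
increment = float(batch_size) / data_size
# round val_split down to a whole number of batches
val_split = float(int(val_split / increment)) * increment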
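The same idea can also be sketched with integer arithmetic, counting whole batches instead of rescaling a float fraction. This is a minimal sketch: fit_to_batch_size is a hypothetical helper name, and it assumes you then slice the arrays yourself (for example X[:n_train] for training and X[n_train:usable] for validation) and pass the validation part through fit's validation_data argument instead of validation_split.

```python
def fit_to_batch_size(data_size, batch_size, val_split):
    """Return (usable, n_train, n_val) sample counts such that the
    training and validation parts are both whole multiples of batch_size."""
    # Drop leftover samples that would only form a partial batch.
    usable = data_size - data_size % batch_size
    n_batches = usable // batch_size
    # Give validation a whole number of batches, rounding down.
    n_val = int(n_batches * val_split) * batch_size
    n_train = usable - n_val
    return usable, n_train, n_val


usable, n_train, n_val = fit_to_batch_size(1000, 32, 0.2)
print(usable, n_train, n_val)  # 992 800 192
```

Counting batches keeps every count an exact integer, so there is no float fraction left to round when the split point is applied.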

In your definition of SimpleRNN, remove input_dim and input_shape, and set:

batch_input_shape = (Number_Of_sequences, Size_Of_Each_Sequence,
                     Shape_Of_Element_In_Each_Sequence) 

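batch_input_shape should be a tuple of length at least 3. The tuple in the question, (len(pred_frame.index), len(pred_frame.columns)), has only two elements, so the timestep axis is missing.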
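To get a three-element shape from a 2D table like the question's pred_frame (rows by columns), the rows first have to be grouped into sequences. A minimal sketch in plain Python lists, assuming consecutive rows are consecutive time steps and that non-overlapping chunks of a chosen length timesteps are an acceptable framing; to_3d and shape_of are illustrative helpers, not Keras functions, and the data is made up:

```python
def to_3d(rows, timesteps):
    """Split a 2D table (n_rows x n_features, as nested lists) into
    non-overlapping chunks of `timesteps` consecutive rows. Trailing rows
    that do not fill a whole chunk are dropped."""
    n_chunks = len(rows) // timesteps
    return [rows[i * timesteps:(i + 1) * timesteps] for i in range(n_chunks)]


def shape_of(nested):
    """Shape of a rectangular nested list, e.g. (n_chunks, timesteps, n_features)."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0] if nested else None
    return tuple(shape)


# Hypothetical data: 10 time steps with 3 features each.
rows = [[t, t * 0.5, t * 2] for t in range(10)]
data = to_3d(rows, timesteps=4)     # 2 chunks; the last 2 rows are dropped
print(shape_of(rows))               # (10, 3)
print(shape_of(data))               # (2, 4, 3)

batch_size = 1
batch_input_shape = (batch_size,) + shape_of(data)[1:]
print(batch_input_shape)            # (1, 4, 3)
```

With a NumPy array the same chunking is arr[:n_chunks * timesteps].reshape(n_chunks, timesteps, n_features).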

If you pass your sequences one at a time, set:

Number_Of_sequences = 1
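The batch dimension matters because, in stateful mode, Keras uses the final state of sample i in one batch as the initial state of sample i in the next batch. A minimal sketch of arranging batches to respect that, assuming every sequence has the same length and that length is a multiple of a chosen chunk_len; stateful_batches is an illustrative helper, not part of Keras:

```python
def stateful_batches(sequences, chunk_len):
    """Yield batches in which row i of batch k is chunk k of sequence i,
    so each row picks up where the same row of the previous batch ended.
    Assumes all sequences share one length that is a multiple of chunk_len."""
    n_chunks = len(sequences[0]) // chunk_len
    for k in range(n_chunks):
        start, stop = k * chunk_len, (k + 1) * chunk_len
        yield [seq[start:stop] for seq in sequences]


# One sequence of 6 time steps with 1 feature: Number_Of_sequences is 1.
seq = [[0], [1], [2], [3], [4], [5]]
for batch in stateful_batches([seq], chunk_len=2):
    print(batch)
# [[[0], [1]]]
# [[[2], [3]]]
# [[[4], [5]]]
```

Each yielded batch here has shape (1, chunk_len, n_features). Before feeding an unrelated sequence you would reset the carried state (the traceback above shows the layer's reset_states method).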

If the length of your sequences is not fixed, set:

Size_Of_Each_Sequence = None
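Declaring None for the timestep axis lets the sequence length change from batch to batch while the batch size stays fixed. A batch is still one rectangular array, so all sequences inside a single batch share one length; with Number_Of_sequences = 1 every sequence can be its own batch. The hypothetical helper below only illustrates that compatibility rule, with None as a wildcard; it is not Keras's own shape check:

```python
def matches(batch_input_shape, actual_shape):
    """True if actual_shape fits batch_input_shape, treating None as 'any size'."""
    return len(batch_input_shape) == len(actual_shape) and all(
        declared is None or declared == actual
        for declared, actual in zip(batch_input_shape, actual_shape)
    )


declared = (1, None, 3)               # batch of 1, any length, 3 features
print(matches(declared, (1, 5, 3)))   # True
print(matches(declared, (1, 12, 3)))  # True: a different length is fine
print(matches(declared, (2, 5, 3)))   # False: batch size is fixed
print(matches(declared, (1, 5)))      # False: timestep axis missing
```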
