
How to call reset_states() during training in Keras?

I have sequential data where:

  • the input dimension of every x is the same
  • the sequence length differs between the x samples

I am using an LSTM, so I want to call reset_states() for each x sample ( x1 and x2 ). x1 and x2 are independent, so I have to reset the LSTM's history when I test x2 after x1.

My code is below. Should I use the stateful option?

# input dimension is two
# but sequence length is different between x1 and x2
x1 = [[1,2],[3,3],[2,1],[2,4]] # x1 length == 4
y1 = [2,3,2,1]

x2 = [[3,2], [2,1]] # x2 length == 2
y2 = [2,4]

input_dim = 2
max_len = 4 # max(len(x1), len(x2))
max_y = 4 # y -> (1,2,3,4)

trainX = [x1, x2]
trainY = [y1, y2]

m = Sequential()
m.add(LSTM(128, 
      input_shape=(max_len, input_dim),
      activation='tanh', 
      return_sequences=True))
m.add(TimeDistributed(Dense(max_y, activation='softmax')))
m.compile(...)
m.fit(trainX, trainY, epochs=10)
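Note that trainX above is a ragged Python list; the shorter sequence has to be padded to max_len before fit() will accept a single (samples, max_len, input_dim) array. A minimal sketch of that padding step in plain NumPy (mirroring what Keras's pad_sequences utility does; pad_to is a hypothetical helper name, not part of the question's code):

```python
import numpy as np

def pad_to(seq, max_len, width):
    # left-pad with zero rows so every sequence has shape (max_len, width)
    out = np.zeros((max_len, width), dtype='float32')
    out[max_len - len(seq):] = np.asarray(seq, dtype='float32')
    return out

x1 = [[1, 2], [3, 3], [2, 1], [2, 4]]  # length 4, already full
x2 = [[3, 2], [2, 1]]                  # length 2, gets 2 zero rows in front

trainX = np.stack([pad_to(x, 4, 2) for x in (x1, x2)])
print(trainX.shape)  # (2, 4, 2)
```

Left-padding (rather than right-padding) keeps the most recent timesteps at the end of the window, which is the convention pad_sequences defaults to.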

Edit

I found a stateful LSTM example, but it calls reset_states() at the end of every epoch. What I want is to call it after every x: https://github.com/fchollet/keras/blob/aff40d800891799dc9ed765617fcbfa665349d0d/examples/stateful_lstm.py

The link you referred to uses the fit function with epochs = 1. I think you can either use train_on_batch() or use fit() with an on_batch_end() callback. This way, you can reset states after each x (by choosing an appropriate batch size).
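To make the "reset after each x" idea concrete without depending on a particular Keras version, here is a toy hand-rolled recurrent cell in plain NumPy (step and run_sequence are made-up names, and this is not the LSTM update rule, just a minimal recurrence): the hidden state carries over within one sequence and is zeroed between independent sequences, which is exactly what reset_states() does for a stateful LSTM.

```python
import numpy as np

rng = np.random.default_rng(0)
W_x = rng.standard_normal((2, 8)) * 0.1   # input -> hidden weights
W_h = rng.standard_normal((8, 8)) * 0.1   # hidden -> hidden weights

def step(state, x_t):
    # one recurrent update: the new state mixes the old state and the input
    return np.tanh(x_t @ W_x + state @ W_h)

def run_sequence(seq, state):
    # feed a whole sequence through the cell, carrying the state along
    for x_t in seq:
        state = step(state, np.asarray(x_t, dtype=float))
    return state

x1 = [[1, 2], [3, 3], [2, 1], [2, 4]]
x2 = [[3, 2], [2, 1]]

state = np.zeros(8)
state = run_sequence(x1, state)   # state now encodes x1's history
state = np.zeros(8)               # the "reset_states()": x2 must not see x1
state = run_sequence(x2, state)
```

The line that zeroes the state is the moral equivalent of reset_states() on a stateful Keras LSTM. Driving fit() with batch_size=1 plus an on_batch_end callback, or calling train_on_batch() in your own loop and resetting between calls, achieves the same per-sequence reset.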
