简体   繁体   English

Keras LSTM 在线学习

[英]Keras LSTM online learning

I have a sequence that is very predictable.我有一个非常可预测的序列。 Below, you can see a slice of it:下面,你可以看到它的一部分:

deque([4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4])

Basically, it's a varying but still predictable number of 4s, followed by an 8, and by 28 every three 8s.基本上,它是一个不断变化但仍可预测的 4 次,然后是 8 次,然后每三个 8 次出现 28 次。

I want to build a very simple LSTM model for online prediction: each time a new number arrives, it is appended on the right of the deque.我想为在线预测构建一个非常简单的 LSTM 模型:每次有新数字到达时,它都会附加到双端队列的右侧。 Thus, the LSTM is trained on the old sequence comprised of [0:seq_length] elements of the deque, with the training target being the [seq_length] element.因此,LSTM 是在由双端队列的[0:seq_length]元素组成的旧序列上训练的,训练目标是[seq_length]元素。 Then, the window shifts and prediction is performed on the [1:seq_length+1] elements.然后,对 [1:seq_length+1] 元素执行窗口移位和预测。 At the end, the leftmost element of the deque is discarded.最后,deque 最左边的元素被丢弃。 My intuition tells me that this should make the network memorize the sequence.我的直觉告诉我,这应该让网络记住序列。

However, my network has been answering only 4. After a (long) while, surprisingly, it starts answering only 8, missing almost all of the time.但是,我的网络只回答了 4。(很长时间)后,令人惊讶的是,它开始只回答 8,几乎所有时间都没有回答。 And then, a (long) while later, it goes back to answering only 4.然后,过了(很长时间),它又回到只回答 4。

My model is structured as shown.我的模型结构如图所示。 Naturally, I've already experimented with different values for seq_length and lstm_cells , none of which gave me success.当然,我已经尝试了seq_lengthlstm_cells 的不同值,但没有一个让我成功。 These were from the latest run:这些来自最新的运行:

seq_length = 64  #Length of the sequence to be inserted into the LSTM
vocab_size = 4  #Size of the final dense layer of the model
lstm_cells = 16  #Size of the LSTM layer

model = Sequential()
model.add(LSTM(lstm_cells, input_shape=(seq_length, 1)))
model.add(Dense(vocab_size))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['categorical_accuracy'])

The following is how data is prepared, trained, and predicted on the model.以下是如何在模型上准备、训练和预测数据。 The variable sequence is the deque shown at the start of this post.变量序列是本文开头显示的双端队列。 I maintain a list vocab = [4,8,28] that is built on execution time as new numbers are seen, so vocab[i] translates class i into its corresponding number of the sequence.我维护了一个列表vocab = [4,8,28] ,它建立在看到新数字的执行时间上,因此vocab[i]将类i转换为其相应的序列编号。 I then create a dictionary legend to do the opposite.然后我创建一个字典图例来做相反的事情。 This is more or less the ongoing online loop:这或多或少是正在进行的在线循环:

while True:

    # Receives new number and puts it into the deque:
    sequence.append(generateNextNumber())

    # At this point, please note that the length of the deque is seq_length + 1.

    # Dictionary to convert numbers to classes:
    legend = dict([(v, k) for k, v in enumerate(vocab)])
    # Converts the deque into a list:
    seq_list = list(sequence)
    # Each iteration is comprised of 1 training and 1 prediction. These are the training sequence and target:
    train_seq = [ [legend[i]] for i in seq_list[:seq_length] ]
    train_target = legend[ seq_list[seq_length] ]
    # And the prediction sequence just shifts the window by 1:
    pred_seq = [ [legend[i]] for i in seq_list[1:] ]

    # Batches data into a batch of size 1:
    x = np.zeros((1, seq_length, 1))
    y = np.zeros((1, vocab_size))
    x[0,:] = train_seq
    y[0,:] = to_categorical( train_target, num_classes=vocab_size )
    # Online training:
    model.fit(x=x, y=y, batch_size=1, epochs=1, verbose=0)

    # Now that one training step is done, make a prediction:
    x_pred = np.zeros((1, seq_length, 1))
    x[0,:] = pred_seq
    predicted_onehot = model.predict(x_pred)
    # Avoids "index out of range" erros when the LSTM vocab is still being built:
    predicted_index = min(np.argmax(predicted_onehot), len(vocab)-1)
    predicted_number = vocab[ predicted_index ]
    # Reverts deque length to seq_length:
    sequence.popleft()

And, finally, this is an example output:最后,这是一个示例输出:

HIT! Current hit rate: 34.753665869071725 (predicted: 4, sequence was: deque([4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4]))

HIT! Current hit rate: 34.75566735175926 (predicted: 4, sequence was: deque([4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4]))

Predicted 4 but it was 8

HIT! Current hit rate: 34.75660255820374 (predicted: 4, sequence was: deque([4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4]))

HIT! Current hit rate: 34.758603766640086 (predicted: 4, sequence was: deque([4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4]))

HIT! Current hit rate: 34.7606048523142 (predicted: 4, sequence was: deque([4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4]))

HIT! Current hit rate: 34.76260581523739 (predicted: 4, sequence was: deque([4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4]))

HIT! Current hit rate: 34.76460665542095 (predicted: 4, sequence was: deque([4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4]))

HIT! Current hit rate: 34.76660737287616 (predicted: 4, sequence was: deque([4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4]))

Predicted 4 but it was 28

HIT! Current hit rate: 34.767541707556425 (predicted: 4, sequence was: deque([4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4]))

What is going so wrong?怎么了?

Thank you very much in advance.非常感谢您提前。

I had an issue in line:我有一个问题:

x[0,:] = pred_seq

which was supposed to be:这应该是:

x_pred[0,:] = pred_seq

Now everything is working more or less correctly.现在一切都或多或少地正常工作。 I'll still leave this question here since it offers some nice insights into LSTM online learning.我仍然会把这个问题留在这里,因为它为 LSTM 在线学习提供了一些很好的见解。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM