繁体   English   中英

Keras 自定义损失 function - 如何访问实际真值和预测

[英]Keras custom loss function - how to access actual truth values and predictions

我正在使用 Keras LSTM 进行时间序列预测。 我采用该系列的最后 n_input_steps 次出现并尝试预测向前迈出的一步。 例如,如果我的时间序列是 [1, 2, 3, 4] 并且 n_input_steps = 2,那么监督学习数据集将是:

[1,2]--> 3

[2,3]--> 4

因此,要预测的序列 (y_true) 将是 [3,4]。

现在我有一个 Keras model 来预测这种类型的系列:

        model = Sequential()
        model.add(LSTM(neurons, activation='relu', input_shape=(n_steps_in, 1)))
        model.add(RepeatVector(1))
        model.add(LSTM(neurons, activation='relu', return_sequences=True))
        model.add(TimeDistributed(Dense(1)))
        model.compile(optimizer='adam', loss=my_loss,run_eagerly=True)
        hist=model.fit(trainX, trainY, epochs=epochs, verbose=2,validation_data=(testX,testY))

而我的损失 function 是:

def my_loss(y_true,y_pred):
    print(kbe.shape(y_true))
    y_true_c = kbe.cast(y_true,'float32')
    y_pred_c = kbe.cast(y_pred,'float32')
    ytn = y_true_c.numpy()
    print(ytn.shape)
    # Do some complex calculation requiring the elements of y_true_c and y_pred_c. 
    # ...
    return result

据我所知,如果我调用model.fit(trainX, trainY,...)与 trainX 对应 [[1, 2], [2, 3]] (适当形状的数组)和 trainY 对应于 [ 3, 4],my_loss里面的y_true应该是对应[3, 4]的张量。 然而,这不是我发现的。 我的损失 function(张量和数组的形状)的打印 output 是:

tf.Tensor([32  1  1], shape=(3,), dtype=int32)
(32, 1, 1)

不管输入数组的大小。 如果我打印数组的值,它们就不会记住原始值。 即使我删除了 model 的所有层,保持裸露的顺序,我得到相同的形状。 因此,我完全迷失了。

根据上面的评论,我做了进一步的搜索,发现响应有一个默认的批量大小规则,正如 Jorge Avila 所指出的那样。 Keras 使用的默认长度为 32。 真实数据和预测数据以这种大小批量出现,所以我应该在调用model.fit()时使用batch_size=len(trainX) ) 。 此外,最重要的是,数据是打乱的,这就是为什么它变得更加混乱。 所以,我也必须在model.fit()中使用shuffle=False

However, as pointed out by Jakub, even with these modifications, my intended loss function will not work because Keras requires symbolic derivatives of the function, which cannot be achieved by having logic which requires the numpy values. 所以,我必须从头开始,另一个损失 function 可以被 Keras 接受。

根据您的数据,将 batch_sizes 保持在 32-1024 的倍数,因为 2** 总是有效,因为它很常见,但您不必使用 shuffle 适合,因为 TimeseriesGenrator 是需要进行更改的地方不适合.

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM