[英]does validation_data in model.fit() method in Tensorflow Keras have to be a tuple?
I'm implementing a complicated loss function so I use a custom layer to pass the loss.我正在实施复杂的损失 function 所以我使用自定义层来传递损失。 Something like:就像是:
class SIAMESE_LOSS(Layer):
def __init__(self, **kwargs):
super(SIAMESE_LOSS, self).__init__(**kwargs)
@staticmethod
def mmd_loss(source_samples, target_samples):
return mmd(source_samples, target_samples)
@staticmethod
def regression_loss(pred, labels):
return K.mean(mae(pred, labels))
def call(self, inputs, **kwargs):
source_labels = inputs[0]
target_labels = inputs[1]
source_pred = inputs[2]
target_pred = inputs[3]
source_samples = inputs[4]
target_samples = inputs[5]
source_loss = self.regression_loss(source_pred, source_labels)
target_loss = self.regression_loss(target_pred, target_labels)
mmd_loss = self.mmd_loss(source_samples, target_samples)
self.add_loss(source_loss)
self.add_loss(target_loss)
self.add_loss(mmd)
self.add_metric(source_loss, aggregation='mean', name='source_mae')
self.add_metric(target_loss, aggregation='mean', name='target_mae')
self.add_metric(mmd_loss, aggregation='mean', name='MMD')
return mmd_loss+target_loss+source_loss
So the labels are sent to the model as inputs.因此,标签将作为输入发送到 model。
Therefore fitting the model will be like:因此拟合 model 将是这样的:
history = model.fit(
x=[train_data_s, train_data_t, self.train_labels, self.train_data_t],
y=None,
batch_size=self.batch_size,
epochs=base_epochs,
verbose=2,
callbacks=cp_callback,
validation_data=[val_data_s, val_data_t, self.val_labels, self.val_labels_t],
shuffle=True
)
However, according to the official document in Tensorflow, validation_data should be:但是根据Tensorflow中的官方文档,validation_data应该是:
Data on which to evaluate the loss and any model metrics at the end of each epoch.用于评估损失的数据以及每个纪元结束时的任何 model 指标。 The model will not be trained on this data. model 不会接受此数据的培训。 validation_data will override validation_split. validation_data 将覆盖 validation_split。 validation_data could be: tuple (x_val, y_val) of Numpy arrays or tensors tuple (x_val, y_val, val_sample_weights) of Numpy arrays dataset For the first two cases, batch_size must be provided. validation_data 可以是:Numpy arrays 的元组 (x_val, y_val) 或 Numpy arrays 数据集的张量元组 (x_val, y_val, val_sample_weights) 对于前两种情况,必须提供 batch_size。 For the last case, validation_steps could be provided.对于最后一种情况,可以提供 validation_steps。 Note that validation_data does not support all the data types that are supported in x, eg, dict, generator or keras.utils.Sequence.请注意,validation_data 不支持 x 中支持的所有数据类型,例如 dict、generator 或 keras.utils.Sequence。
There's no 'label' that should be passed since they're already handled by the model as inputs.没有应该传递的“标签”,因为它们已经由 model 作为输入处理。 How can I solve the problem if I still wanna use validation data?如果我仍然想使用验证数据,我该如何解决问题?
to write your own loss you need to inherit from class Loss and then implement your loss calculation in the init and call methods.要编写您自己的损失,您需要继承 class Loss,然后在 init 和 call 方法中实现您的损失计算。 https://www.tensorflow.org/api_docs/python/tf/keras/losses/Loss https://www.tensorflow.org/api_docs/python/tf/keras/losses/损失
so you dont need to train without passing y in model.fit()所以你不需要在不通过 model.fit() 的情况下进行训练
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.