简体   繁体   English

Tensorflow Keras 中的 model.fit() 方法中的 validation_data 是否必须是一个元组?

[英]does validation_data in model.fit() method in Tensorflow Keras have to be a tuple?

I'm implementing a complicated loss function so I use a custom layer to pass the loss.我正在实施复杂的损失 function 所以我使用自定义层来传递损失。 Something like:就像是:

class SIAMESE_LOSS(Layer):
    def __init__(self, **kwargs):
        super(SIAMESE_LOSS, self).__init__(**kwargs)

    @staticmethod
    def mmd_loss(source_samples, target_samples):
        return mmd(source_samples, target_samples)

    @staticmethod
    def regression_loss(pred, labels):
        return K.mean(mae(pred, labels))

    def call(self, inputs, **kwargs):
        source_labels = inputs[0]
        target_labels = inputs[1]
        source_pred = inputs[2]
        target_pred = inputs[3]
        source_samples = inputs[4]
        target_samples = inputs[5]
        source_loss = self.regression_loss(source_pred, source_labels)
        target_loss = self.regression_loss(target_pred, target_labels)
        mmd_loss = self.mmd_loss(source_samples, target_samples)
        self.add_loss(source_loss)
        self.add_loss(target_loss)
        self.add_loss(mmd)
        self.add_metric(source_loss, aggregation='mean', name='source_mae')
        self.add_metric(target_loss, aggregation='mean', name='target_mae')
        self.add_metric(mmd_loss, aggregation='mean', name='MMD')
        return mmd_loss+target_loss+source_loss

So the labels are sent to the model as inputs.因此,标签将作为输入发送到 model。
Therefore fitting the model will be like:因此拟合 model 将是这样的:

        history = model.fit(
            x=[train_data_s, train_data_t, self.train_labels, self.train_data_t],
            y=None,
            batch_size=self.batch_size,
            epochs=base_epochs,
            verbose=2,
            callbacks=cp_callback,
            validation_data=[val_data_s, val_data_t, self.val_labels, self.val_labels_t],
            shuffle=True
        )

However, according to the official document in Tensorflow, validation_data should be:但是根据Tensorflow中的官方文档,validation_data应该是:

Data on which to evaluate the loss and any model metrics at the end of each epoch.用于评估损失的数据以及每个纪元结束时的任何 model 指标。 The model will not be trained on this data. model 不会接受此数据的培训。 validation_data will override validation_split. validation_data 将覆盖 validation_split。 validation_data could be: tuple (x_val, y_val) of Numpy arrays or tensors tuple (x_val, y_val, val_sample_weights) of Numpy arrays dataset For the first two cases, batch_size must be provided. validation_data 可以是:Numpy arrays 的元组 (x_val, y_val) 或 Numpy arrays 数据集的张量元组 (x_val, y_val, val_sample_weights) 对于前两种情况,必须提供 batch_size。 For the last case, validation_steps could be provided.对于最后一种情况,可以提供 validation_steps。 Note that validation_data does not support all the data types that are supported in x, eg, dict, generator or keras.utils.Sequence.请注意,validation_data 不支持 x 中支持的所有数据类型,例如 dict、generator 或 keras.utils.Sequence。

There's no 'label' that should be passed since they're already handled by the model as inputs.没有应该传递的“标签”,因为它们已经由 model 作为输入处理。 How can I solve the problem if I still wanna use validation data?如果我仍然想使用验证数据,我该如何解决问题?

to write your own loss you need to inherit from class Loss and then implement your loss calculation in the init and call methods.要编写您自己的损失,您需要继承 class Loss,然后在 init 和 call 方法中实现您的损失计算。 https://www.tensorflow.org/api_docs/python/tf/keras/losses/Loss https://www.tensorflow.org/api_docs/python/tf/keras/losses/损失

so you dont need to train without passing y in model.fit()所以你不需要在不通过 model.fit() 的情况下进行训练

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

相关问题 “validation_data 将覆盖validation_split”是什么意思。 在 keras model.fit 文档中 - What is the meaning of "validation_data will override validation_split." in keras model.fit documentation Keras model.fit()与tf.dataset API + validation_data - Keras model.fit() with tf.dataset API + validation_data 如何将validation_data 传递给Model.fit + Dataset? - How to pass validation_data to Model.fit + Dataset? Tensorflow (Keras API) `model.fit` 方法返回“Failed to convert object of type<class 'tuple'> 张量”错误</class> - Tensorflow (Keras API) `model.fit` method returns “Failed to convert object of type <class 'tuple'> to Tensor” error Keras:不能使用 ImageDataGenerator 作为 model.fit 中的验证数据 - Keras: cannot use ImageDataGenerator as validation data in model.fit sklearn fit 方法中的 Validation_data 参数 - Validation_data argument in sklearn fit method Tensorflow Keras:在 model.fit() 上进行训练 - Tensorflow Keras: Training Holting on model.fit() keras model.fit函数需要哪种数据格式? - Which data format does keras model.fit function need? 在 Keras model.fit 中将训练数据指定为元组 (x, y) 的正确方法,具有多个输入和输出 - Correct way to specify training data as tuple (x, y) in Keras model.fit with multiple inputs and outputs 无法在Keras中将validation_data传递给具有多个输入的模型 - Can't pass validation_data to a model with several inputs in Keras
 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM