简体   繁体   English

使用model.fit时如何将'training'参数传递给tf,keras.Model

[英]how to pass in 'training' argument to tf,keras.Model when using model.fit

So I have this model written by the subclassing API, the call signature looks like call(x, training), where training argument is needed to differentiate between training and non-training when doing batchnorm and dropout. 因此,我有一个由子类API编写的模型,调用签名看起来像call(x,training),在执行batchnorm和dropout时,需要使用训练参数来区分训练和非训练。 How do I make the model forward pass know I am in training mode or eval mode when I use model.fit? 当我使用model.fit时,如何使模型前传知道我处于训练模式还是评估模式?

Thanks! 谢谢!

As far as i know, there is no argument for this. 据我所知,对此没有论据。 Model.fit simply trains the model on whatever training data provided, and at the end of each epoch evaluates the training on either provided validation data, OR by the use of validation_split. Model.fit只是根据提供的任何训练数据对模型进行训练,并且在每个纪元结束时,对提供的验证数据或通过使用validation_split来评估训练。

Actually, in the documentation https://www.tensorflow.org/beta/guide/keras/custom_layers_and_models , it says "Some layers, in particular the BatchNormalization layer and the Dropout layer, have different behaviors during training and inference. For such layers, it is standard practice to expose a training (boolean) argument in the call method. 实际上,在文档https://www.tensorflow.org/beta/guide/keras/custom_layers_and_models中 ,它说:“某些层,尤其是BatchNormalization层和Dropout层,在训练和推理期间具有不同的行为。对于这些层,这是标准做法,在call方法中公开训练(布尔)参数。

By exposing this argument in call, you enable the built-in training and evaluation loops (eg fit) to correctly use the layer in training and inference." So I think the training argument is passed in automatically by keras. I tried to remove the default value for training argument and no errors were thrown, so it is very likely keras built-in loop did the thing. 通过在调用中公开此参数,您可以启用内置的训练和评估循环(例如,拟合)以在训练和推理中正确使用该图层。”因此,我认为keras会自动传递训练参数。训练参数的默认值,并且没有引发任何错误,因此很可能keras内置循环可以完成该操作。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

相关问题 当训练数据是图像时,Keras model.fit() 中的“批次”是什么 - what is a “batch” in Keras model.fit() when training data are images Tensorflow Keras:在 model.fit() 上进行训练 - Tensorflow Keras: Training Holting on model.fit() 运行model.fit时Keras-TF Conv2D模型无响应 - Keras-TF Conv2D model unresponsive when running model.fit tf.keras.to_categorical TypeError 在 model.fit 期间 - tf.keras.to_categorical TypeError during model.fit 使用tf.train时,带有tf.dataset的Keras model.fit()失败 - Keras model.fit() with tf.dataset fails while using tf.train works fine 在 tf.keras 中的 model.fit 中,有没有办法将每个样本分批传递 n 次? - In model.fit in tf.keras, is there a way to pass each sample in a batch n times? tf.keras (RNN) 层在运行 model.fit() 时出现问题 - tf.keras (RNN) Layer issues when running model.fit() 在 Keras 和 ZCB20B802A3F021ECZE20 中使用 model.fit() 时如何使用 select 指标子集登录命令行 - How to select subset of metrics to log on commandline when using model.fit() in Keras & Tensorflow 梯度累积与自定义 model.fit TF.Keras? - Gradient Accumulation with Custom model.fit in TF.Keras? 如何在不指定目标的情况下在 Keras model.fit 中使用 tf.Dataset? - How to use tf.Dataset in Keras model.fit without specifying targets?
 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM