[英]Pretrained NN Finetuning with Keras. How to freeze Batch Normalization?
So I didnt write my code in tf.keras and according to this tutorial for finetuning with a pretrained NN: https://keras.io/guides/transfer_learning/#freezing-layers-understanding-the-trainable-attribute ,所以我没有在 tf.keras 中编写我的代码,并根据本教程使用预训练的 NN 进行微调: https://keras.io/guides/transfer_learning/#freezing-layers-understanding-the-trainable-attribute ,
I have to set the parameter training=False
when calling the pretrained model, so that when I later unfreeze for finetuning, Batch Normalization doesnt destroy my model .在调用预训练的 model 时,我必须设置参数training=False
,以便稍后解冻进行微调时,批量标准化不会破坏我的 model 。 But how do I do that in keras (Remember: I didnt write it in tf.keras).但是我如何在 keras 中做到这一点(记住:我没有在 tf.keras 中写它)。 Is it even necessary in keras to do that? keras 甚至有必要这样做吗?
The code:编码:
def baseline_model():
pretrained_model = Xception(include_top=False, weights="imagenet")
for layer in pretrained_model.layers:
layer.trainable = False
general_input = Input(shape=(256, 256, 3))
x = pretrained_model(general_input,training=False)
x = GlobalAveragePooling2D()(x)
...
Gives me the error, when calling model = baseline_model()
:在调用model = baseline_model()
时给我错误:
TypeError: call() got an unexpected keyword argument 'training'
How do I do that best?我怎样才能做到最好? I tried rewriting everything in tf.keras, but theres errors popping up everyhwere when I tried to do it...我尝试重写 tf.keras 中的所有内容,但是当我尝试这样做时,到处都会弹出错误...
EDIT: My keras version is 2.3.1 and tensorflow 2.2.0 .编辑:我的 keras 版本是 2.3.1 和 tensorflow 2.2.0 。
EDITED my previous answer after doing some additional research:在做了一些额外的研究后编辑了我以前的答案:
I did some reading and it seems like there is some trickery in how BatchNorm layer behaves when frozen.我做了一些阅读,似乎 BatchNorm 层在冻结时的行为有一些技巧。 This is a good thread talking about it: github.com/keras-team/keras/issues/7085 seems like training=false
parameter is necessary to correctly freeze BatchNorm layer and it was added in Keras 2.1.3, so my advice for you is to make sure your Keras/TF version is higher这是一个很好的话题:github.com/keras-team/keras/issues/7085 似乎training=false
参数是正确冻结 BatchNorm 层所必需的,它是在 Keras 2.1.3 中添加的,所以我给你的建议是为了确保你的 Keras/TF 版本更高
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.