How to fine-tune ResNet50 with float16 in Keras?
I'm trying to fine-tune ResNet50 in half-precision mode without success. It seems that parts of the model are not compatible with float16. Here is my code:
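from keras import backend as K
from keras.applications import ResNet50
from keras.models import Sequential

dtype = 'float16'
K.set_floatx(dtype)
K.set_epsilon(1e-4)  # the default 1e-7 is below float16's smallest normal value (~6.1e-5)
model = Sequential()
model.add(ResNet50(weights='imagenet', include_top=False, pooling='avg'))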
and I get this error:
Traceback (most recent call last):
File "train_resnet.py", line 40, in <module>
model.add(ResNet50(weights='imagenet', include_top=False, pooling='avg'))
File "/usr/local/lib/python3.6/dist-packages/keras/applications/__init__.py", line 28, in wrapper
return base_fun(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/keras/applications/resnet50.py", line 11, in ResNet50
return resnet50.ResNet50(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/keras_applications/resnet50.py", line 231, in ResNet50
x = layers.BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
File "/usr/local/lib/python3.6/dist-packages/keras/engine/base_layer.py", line 457, in __call__
output = self.call(inputs, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/keras/layers/normalization.py", line 185, in call
epsilon=self.epsilon)
File "/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py", line 1864, in normalize_batch_in_training
epsilon=epsilon)
File "/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py", line 1839, in _fused_normalize_batch_in_training
data_format=tf_data_format)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/nn_impl.py", line 1329, in fused_batch_norm
name=name)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_nn_ops.py", line 4488, in fused_batch_norm_v2
name=name)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py", line 626, in _apply_op_helper
param_name=input_name)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py", line 60, in _SatisfiesTypeConstraint
", ".join(dtypes.as_dtype(x).name for x in allowed_list)))
TypeError: Value passed to parameter 'scale' has DataType float16 not in list of allowed values: float32
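The last line of the traceback is the key: the fused batch-norm kernel that Keras calls accepts a float16 input, but its `scale` (gamma) parameter must be float32, and setting `floatx` to float16 made Keras create gamma as float16. The NumPy sketch below only illustrates that mixed-dtype convention. It is not TensorFlow's kernel, and the function name `fused_batch_norm_fp16` is hypothetical. It assumes NHWC input with channels last and does the statistics and arithmetic in float32 before casting the result back to float16.

```python
import numpy as np


def fused_batch_norm_fp16(x, scale, offset, epsilon=1e-3):
    """Training-mode batch norm over axes N, H, W for NHWC input.

    Hypothetical illustration of the dtype contract in the error message:
    x may be float16, but scale (gamma) and offset (beta) must be float32.
    """
    if scale.dtype != np.float32 or offset.dtype != np.float32:
        # Same kind of failure as the traceback: float16 gamma is rejected.
        raise TypeError(
            f"scale and offset must be float32, got {scale.dtype} and {offset.dtype}"
        )
    # Upcast activations so the batch statistics are computed in float32.
    x32 = x.astype(np.float32)
    mean = x32.mean(axis=(0, 1, 2))
    var = x32.var(axis=(0, 1, 2))
    y = (x32 - mean) / np.sqrt(var + epsilon) * scale + offset
    # Cast the normalized output back to the input dtype (float16 here).
    return y.astype(x.dtype), mean, var
```

Calling it with float16 activations and float32 `gamma`/`beta` returns a float16 output, while passing a float16 `gamma` raises a `TypeError`, which is the situation the traceback reports.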
This is a reported bug; upgrading to Keras==2.2.5 solved the problem.
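Since the fix depends on the Keras version, a quick guard before building the model can save a confusing traceback. This is a minimal sketch: the helper names `parse_version` and `has_fp16_batchnorm_fix` are hypothetical, and the only assumption is that the version string starts with dotted integers (suffixes such as `-tf` are ignored).

```python
import re

# Version reported above as fixing the float16 BatchNormalization error.
MIN_KERAS = (2, 2, 5)


def parse_version(version):
    """Turn '2.2.5' or '2.2.4-tf' into a tuple of ints such as (2, 2, 4)."""
    parts = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)  # keep only the leading digits
        parts.append(int(match.group()) if match else 0)
    while len(parts) < 3:
        parts.append(0)  # pad '2.3' to (2, 3, 0)
    return tuple(parts)


def has_fp16_batchnorm_fix(version):
    """True if this Keras version is at least 2.2.5."""
    return parse_version(version) >= MIN_KERAS
```

Usage: after `import keras`, check `has_fp16_batchnorm_fix(keras.__version__)` and, if it is `False`, upgrade with `pip install keras==2.2.5` before calling `ResNet50(...)`.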