I'm trying to fine-tune ResNet50 in half-precision mode, without success. It seems that parts of the model are not compatible with float16. Here is my code:
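from keras import backend as K
from keras.applications import ResNet50
from keras.models import Sequential

dtype = 'float16'
K.set_floatx(dtype)
K.set_epsilon(1e-4)
model = Sequential()
model.add(ResNet50(weights='imagenet', include_top=False, pooling='avg'))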
And I get this error:
Traceback (most recent call last):
  File "train_resnet.py", line 40, in <module>
    model.add(ResNet50(weights='imagenet', include_top=False, pooling='avg'))
  File "/usr/local/lib/python3.6/dist-packages/keras/applications/__init__.py", line 28, in wrapper
    return base_fun(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/keras/applications/resnet50.py", line 11, in ResNet50
    return resnet50.ResNet50(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/keras_applications/resnet50.py", line 231, in ResNet50
    x = layers.BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
  File "/usr/local/lib/python3.6/dist-packages/keras/engine/base_layer.py", line 457, in __call__
    output = self.call(inputs, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/keras/layers/normalization.py", line 185, in call
    epsilon=self.epsilon)
  File "/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py", line 1864, in normalize_batch_in_training
    epsilon=epsilon)
  File "/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py", line 1839, in _fused_normalize_batch_in_training
    data_format=tf_data_format)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/nn_impl.py", line 1329, in fused_batch_norm
    name=name)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_nn_ops.py", line 4488, in fused_batch_norm_v2
    name=name)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py", line 626, in _apply_op_helper
    param_name=input_name)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py", line 60, in _SatisfiesTypeConstraint
    ", ".join(dtypes.as_dtype(x).name for x in allowed_list)))
TypeError: Value passed to parameter 'scale' has DataType float16 not in list of allowed values: float32
This is a reported bug. Upgrading to Keras==2.2.5 resolved the problem.
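The upgrade advice can be turned into a guard that fails early on an older installation. The following is only a sketch: the helper names (version_tuple, has_float16_fix, build_float16_resnet) are hypothetical and not part of Keras, and the 2.2.5 threshold is taken from this answer rather than checked against the Keras changelog.

```python
import re

# Threshold taken from the answer: the fix is reported for Keras 2.2.5.
MIN_FIXED_VERSION = (2, 2, 5)


def version_tuple(version):
    """Turn a version string such as '2.2.5' or '2.3.0rc1' into a tuple of three ints."""
    parts = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    # Pad short versions such as '2.3' to three components.
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def has_float16_fix(version):
    """Hypothetical helper: True if the version is at least 2.2.5."""
    return version_tuple(version) >= MIN_FIXED_VERSION


def build_float16_resnet():
    """Build the model from the question, refusing to run on an older Keras."""
    # Imported here so the version helpers above work without Keras installed.
    import keras
    from keras import backend as K
    from keras.applications import ResNet50
    from keras.models import Sequential

    if not has_float16_fix(keras.__version__):
        raise RuntimeError(
            "Keras %s predates 2.2.5; try: pip install Keras==2.2.5"
            % keras.__version__
        )
    K.set_floatx("float16")
    K.set_epsilon(1e-4)
    model = Sequential()
    model.add(ResNet50(weights="imagenet", include_top=False, pooling="avg"))
    return model
```

Calling build_float16_resnet() on an installation older than 2.2.5 raises a RuntimeError with the upgrade hint instead of the TypeError from fused_batch_norm.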