
How to fine-tune ResNet50 with float16 in Keras?

I'm trying to fine-tune ResNet50 in half-precision (float16) mode, without success. It seems that parts of the model are not compatible with float16. Here is my code:

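from keras import backend as K
from keras.models import Sequential
from keras.applications.resnet50 import ResNet50

dtype = 'float16'
K.set_floatx(dtype)   # make float16 the default dtype for new layers and weights
K.set_epsilon(1e-4)   # raise the K.epsilon() fuzz factor for float16's lower precision

model = Sequential()
model.add(ResNet50(weights='imagenet', include_top=False, pooling='avg'))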

and I get this error:

Traceback (most recent call last):
  File "train_resnet.py", line 40, in <module>
    model.add(ResNet50(weights='imagenet', include_top=False, pooling='avg'))
  File "/usr/local/lib/python3.6/dist-packages/keras/applications/__init__.py", line 28, in wrapper
    return base_fun(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/keras/applications/resnet50.py", line 11, in ResNet50
    return resnet50.ResNet50(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/keras_applications/resnet50.py", line 231, in ResNet50
    x = layers.BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
  File "/usr/local/lib/python3.6/dist-packages/keras/engine/base_layer.py", line 457, in __call__
    output = self.call(inputs, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/keras/layers/normalization.py", line 185, in call
    epsilon=self.epsilon)
  File "/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py", line 1864, in normalize_batch_in_training
    epsilon=epsilon)
  File "/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py", line 1839, in _fused_normalize_batch_in_training
    data_format=tf_data_format)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/nn_impl.py", line 1329, in fused_batch_norm
    name=name)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_nn_ops.py", line 4488, in fused_batch_norm_v2
    name=name)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py", line 626, in _apply_op_helper
    param_name=input_name)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py", line 60, in _SatisfiesTypeConstraint
    ", ".join(dtypes.as_dtype(x).name for x in allowed_list)))
TypeError: Value passed to parameter 'scale' has DataType float16 not in list of allowed values: float32

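The last line of the traceback points to the cause: TensorFlow's fused batch-normalization op accepts float16 inputs but requires its scale parameter to be float32, while K.set_floatx('float16') makes Keras create the BatchNormalization weights in float16. This is a reported Keras bug, and upgrading to Keras==2.2.5 fixed the problem.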
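The TypeError above comes from a dtype constraint rather than from ResNet50 itself: the fused kernel takes a float16 input but insists on a float32 scale. The NumPy sketch below is a hypothetical illustration of the mixed-precision pattern that constraint implies, not Keras's or TensorFlow's actual implementation, and the function name batch_norm_fp16 is made up for this example. It keeps the per-channel parameters (gamma, beta) in float32, upcasts the float16 activations to compute the batch statistics, and casts the result back to float16.

```python
import numpy as np


def batch_norm_fp16(x, gamma, beta, eps=1e-3):
    """Hypothetical helper: batch-normalize float16 activations with float32 parameters.

    x:     float16 array of shape (N, H, W, C), channels last
    gamma: float32 per-channel scale, shape (C,)
    beta:  float32 per-channel offset, shape (C,)
    eps:   1e-3 matches the default epsilon of Keras's BatchNormalization layer
    """
    if x.dtype != np.float16:
        raise TypeError("expected float16 activations")
    if gamma.dtype != np.float32 or beta.dtype != np.float32:
        # Mirrors the traceback: the scale (and here also the offset) must be float32.
        raise TypeError("scale and offset must be float32")

    x32 = x.astype(np.float32)        # upcast so the statistics are computed in float32
    mean = x32.mean(axis=(0, 1, 2))   # per-channel batch mean, shape (C,)
    var = x32.var(axis=(0, 1, 2))     # per-channel (biased) batch variance, shape (C,)
    y = gamma * (x32 - mean) / np.sqrt(var + eps) + beta
    return y.astype(np.float16)       # hand float16 activations back to the next layer


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal((2, 4, 4, 3)).astype(np.float16)
    gamma = np.ones(3, dtype=np.float32)
    beta = np.zeros(3, dtype=np.float32)
    y = batch_norm_fp16(x, gamma, beta)
    print(y.dtype, y.shape)
```

Passing a float16 gamma to this sketch raises a TypeError, much like the error in the question; upgrading Keras as described is still the actual fix for the ResNet50 model.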


 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM