How to fine-tune ResNet50 with float16 in Keras?

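I am trying to fine-tune ResNet50 in half-precision mode, without success. It seems that parts of the model are not compatible with float16. Here is my code: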

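from keras import backend as K
from keras.applications import ResNet50
from keras.models import Sequential

# Make float16 the default dtype for all Keras layers and weights
dtype = 'float16'
K.set_floatx(dtype)
# The default epsilon (1e-7) is too small for float16, so raise it
K.set_epsilon(1e-4)

model = Sequential()
model.add(ResNet50(weights='imagenet', include_top=False, pooling='avg'))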

I get this error:

Traceback (most recent call last):
  File "train_resnet.py", line 40, in <module>
    model.add(ResNet50(weights='imagenet', include_top=False, pooling='avg'))
  File "/usr/local/lib/python3.6/dist-packages/keras/applications/__init__.py", line 28, in wrapper
    return base_fun(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/keras/applications/resnet50.py", line 11, in ResNet50
    return resnet50.ResNet50(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/keras_applications/resnet50.py", line 231, in ResNet50
    x = layers.BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
  File "/usr/local/lib/python3.6/dist-packages/keras/engine/base_layer.py", line 457, in __call__
    output = self.call(inputs, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/keras/layers/normalization.py", line 185, in call
    epsilon=self.epsilon)
  File "/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py", line 1864, in normalize_batch_in_training
    epsilon=epsilon)
  File "/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py", line 1839, in _fused_normalize_batch_in_training
    data_format=tf_data_format)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/nn_impl.py", line 1329, in fused_batch_norm
    name=name)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_nn_ops.py", line 4488, in fused_batch_norm_v2
    name=name)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py", line 626, in _apply_op_helper
    param_name=input_name)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py", line 60, in _SatisfiesTypeConstraint
    ", ".join(dtypes.as_dtype(x).name for x in allowed_list)))
TypeError: Value passed to parameter 'scale' has DataType float16 not in list of allowed values: float32

This was a reported bug; upgrading to Keras==2.2.5 resolved the issue.
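Before switching to float16, it may help to confirm that the installed Keras is at least the version mentioned here. The following is only a minimal sketch: `MIN_KERAS`, `parse_version`, `has_float16_fix` and `check_installed_keras` are hypothetical helper names, not part of Keras, and the only Keras attribute it relies on is `keras.__version__`.

```python
import re

# Version reported in this answer as resolving the float16 error (taken from the answer itself)
MIN_KERAS = (2, 2, 5)


def parse_version(text):
    """Return the leading numeric components of a version string as a tuple of ints."""
    parts = []
    for piece in text.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at the first non-numeric component
        parts.append(int(match.group()))
    return tuple(parts)


def has_float16_fix(version_text, minimum=MIN_KERAS):
    """True if version_text is at least the minimum version (tuple comparison)."""
    return parse_version(version_text) >= minimum


def check_installed_keras():
    """Print whether the installed Keras meets the minimum version, if it is installed."""
    try:
        import keras
    except ImportError:
        print("Keras is not installed")
        return
    if has_float16_fix(keras.__version__):
        print("Keras", keras.__version__, "meets the minimum version")
    else:
        print("Keras", keras.__version__, "is older than 2.2.5; consider upgrading")


if __name__ == "__main__":
    check_installed_keras()
```

If the check reports an older version, upgrading with `pip install --upgrade "keras>=2.2.5"` and rerunning the original script is the step this answer suggests.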

