Failed to load keras resnet50 model offline using weight file

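I want to train the Keras pre-trained ResNet50 model offline, but I cannot load the model.

It works when I set weights='imagenet'. Keras then downloads the ImageNet weight file automatically.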

from keras.applications.resnet import ResNet50
base_model = ResNet50(include_top=False, weights='imagenet', input_shape=(w, h, 3), pooling='avg')

But when I download the same weight file manually and set weights=resnet_weights_path, it throws a ValueError.

(w,h) = 224,224
resnet_weights_path = '../input/resnet50/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
base_model = ResNet50(include_top=False, weights=resnet_weights_path, input_shape=(w,h,3),pooling='avg')

ValueError: Shapes (1, 1, 256, 512) and (512, 128, 1, 1) are incompatible.

Full traceback:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-18-7683562fa2b9> in <module>
      1 resnet_weights_path = '../input/resnet50/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
      2 base_model = ResNet50(include_top=False, weights=resnet_weights_path,
----> 3                       pooling='avg')
      4 base_model.summary()

/opt/conda/lib/python3.6/site-packages/keras/applications/__init__.py in wrapper(*args, **kwargs)
     18         kwargs['models'] = models
     19         kwargs['utils'] = utils
---> 20         return base_fun(*args, **kwargs)
     21 
     22     return wrapper

/opt/conda/lib/python3.6/site-packages/keras/applications/resnet.py in ResNet50(*args, **kwargs)
     12 @keras_modules_injection
     13 def ResNet50(*args, **kwargs):
---> 14     return resnet.ResNet50(*args, **kwargs)
     15 
     16 

/opt/conda/lib/python3.6/site-packages/keras_applications/resnet_common.py in ResNet50(include_top, weights, input_tensor, input_shape, pooling, classes, **kwargs)
    433                   input_tensor, input_shape,
    434                   pooling, classes,
--> 435                   **kwargs)
    436 
    437 

/opt/conda/lib/python3.6/site-packages/keras_applications/resnet_common.py in ResNet(stack_fn, preact, use_bias, model_name, include_top, weights, input_tensor, input_shape, pooling, classes, **kwargs)
    411         model.load_weights(weights_path)
    412     elif weights is not None:
--> 413         model.load_weights(weights)
    414 
    415     return model

/opt/conda/lib/python3.6/site-packages/keras/engine/saving.py in load_wrapper(*args, **kwargs)
    490                 os.remove(tmp_filepath)
    491             return res
--> 492         return load_function(*args, **kwargs)
    493 
    494     return load_wrapper

/opt/conda/lib/python3.6/site-packages/keras/engine/network.py in load_weights(self, filepath, by_name, skip_mismatch, reshape)
   1228             else:
   1229                 saving.load_weights_from_hdf5_group(
-> 1230                     f, self.layers, reshape=reshape)
   1231             if hasattr(f, 'close'):
   1232                 f.close()

/opt/conda/lib/python3.6/site-packages/keras/engine/saving.py in load_weights_from_hdf5_group(f, layers, reshape)
   1235                              ' elements.')
   1236         weight_value_tuples += zip(symbolic_weights, weight_values)
-> 1237     K.batch_set_value(weight_value_tuples)
   1238 
   1239 

/opt/conda/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py in batch_set_value(tuples)
   2958             `value` should be a Numpy array.
   2959     """
-> 2960     tf_keras_backend.batch_set_value(tuples)
   2961 
   2962 

/opt/conda/lib/python3.6/site-packages/tensorflow_core/python/keras/backend.py in batch_set_value(tuples)
   3321     with ops.init_scope():
   3322       for x, value in tuples:
-> 3323         x.assign(np.asarray(value, dtype=dtype(x)))
   3324   else:
   3325     with get_graph().as_default():

/opt/conda/lib/python3.6/site-packages/tensorflow_core/python/ops/resource_variable_ops.py in assign(self, value, use_locking, name, read_value)
    817     with _handle_graph(self.handle):
    818       value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
--> 819       self._shape.assert_is_compatible_with(value_tensor.shape)
    820       assign_op = gen_resource_variable_ops.assign_variable_op(
    821           self.handle, value_tensor, name=name)

/opt/conda/lib/python3.6/site-packages/tensorflow_core/python/framework/tensor_shape.py in assert_is_compatible_with(self, other)
   1108     """
   1109     if not self.is_compatible_with(other):
-> 1110       raise ValueError("Shapes %s and %s are incompatible" % (self, other))
   1111 
   1112   def most_specific_compatible_shape(self, other):

ValueError: Shapes (1, 1, 256, 512) and (512, 128, 1, 1) are incompatible

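The problem may be caused by the Keras version. I am currently using Keras 2.3.1.
Do the following to solve the problem:
1. Run the code with the option weights='imagenet'. It downloads the weight file automatically.
2. Provide the path to the downloaded weight file.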
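To find the file that step 1 downloaded, you can search the Keras cache directory. The sketch below assumes the default cache location `~/.keras/models`; the helper name `find_cached_weights` and the glob pattern are illustrative, and the location may differ if `KERAS_HOME` is set on your machine.

```python
from pathlib import Path


def find_cached_weights(cache_dir=None, pattern="resnet50*.h5"):
    """Return sorted paths of weight files matching `pattern` in the Keras cache.

    Assumption: Keras stores downloaded weights under ~/.keras/models
    unless a different cache directory has been configured.
    """
    base = Path(cache_dir) if cache_dir else Path.home() / ".keras" / "models"
    if not base.is_dir():
        return []
    return sorted(str(p) for p in base.glob(pattern))


if __name__ == "__main__":
    # Print every candidate file; pass one of these paths as `weights=...`.
    for path in find_cached_weights():
        print(path)
```

Copy the file it reports to the offline machine and pass that path as `weights`. Because the file was downloaded by the same Keras version that builds the model, it should match the architecture.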

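This is a shape mismatch. It cannot be fixed other than by changing the architecture to match the weights, because the mismatched tensor shapes are what cause the error.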
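To see where the file and the architecture diverge, it can help to list the shapes stored in the weight file and compare them with the shapes the model expects. This is a minimal sketch, assuming `h5py` is installed; `read_h5_shapes` and `first_incompatible` are hypothetical helper names. Without `by_name`, Keras assigns weights in order, so the comparison below is positional as well.

```python
def read_h5_shapes(path):
    """Map every dataset name in an HDF5 file to its shape (requires h5py)."""
    import h5py  # imported lazily so the comparison helper works without it

    shapes = {}

    def visit(name, obj):
        # Groups are skipped; only datasets hold weight arrays.
        if isinstance(obj, h5py.Dataset):
            shapes[name] = tuple(obj.shape)

    with h5py.File(path, "r") as f:
        f.visititems(visit)
    return shapes


def first_incompatible(model_shapes, file_shapes):
    """Return (index, expected, found) for the first positional mismatch, else None."""
    for index, (expected, found) in enumerate(zip(model_shapes, file_shapes)):
        if tuple(expected) != tuple(found):
            return index, tuple(expected), tuple(found)
    return None


if __name__ == "__main__":
    # The second pair reproduces the shapes from the error message above;
    # the first pair is an illustrative matching entry.
    model_side = [(7, 7, 3, 64), (1, 1, 256, 512)]
    file_side = [(7, 7, 3, 64), (512, 128, 1, 1)]
    print(first_incompatible(model_side, file_side))
```

On a real model, the expected shapes can be collected with `[w.shape for w in model.get_weights()]`.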

Download the weights from here and try again. These are the weights provided by Keras itself.

WEIGHTS_PATH = ('https://github.com/fchollet/deep-learning-models/'
                'releases/download/v0.2/'
                'resnet50_weights_tf_dim_ordering_tf_kernels.h5')
WEIGHTS_PATH_NO_TOP = ('https://github.com/fchollet/deep-learning-models/'
                       'releases/download/v0.2/'
                       'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5')

Add the argument by_name=True to model.load_weights(). It is the right way to solve the problem in both online and offline mode. I use offline mode because the weights are on my desktop.

# Build model.
model = Model(inputs, x, name='resnet50')

# load weights
if weights == 'imagenet':
    if include_top:
        weights_path = WEIGHTS_PATH
    else:
        weights_path = WEIGHTS_PATH_NO_TOP
    # Original call: model.load_weights(weights_path)
    model.load_weights(weights_path, by_name=True)
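One caveat with by_name=True: Keras only loads weights into layers whose names also appear in the file, and leaves every other layer at its initial values. A load that raises no error may therefore still have loaded very little. The sketch below is a pure-Python check of the name overlap; `split_by_name` is a hypothetical helper and the layer names in the example are illustrative only.

```python
def split_by_name(model_layer_names, file_layer_names):
    """Split model layers into those a by-name load would fill and those it would skip."""
    file_names = set(file_layer_names)
    matched = [name for name in model_layer_names if name in file_names]
    skipped = [name for name in model_layer_names if name not in file_names]
    return matched, skipped


if __name__ == "__main__":
    # Illustrative names: a model and a weight file that use different
    # naming schemes share almost no layer names.
    model_names = ["conv1_conv", "conv1_bn", "conv2_block1_1_conv"]
    file_names = ["conv1", "bn_conv1", "res2a_branch2a"]
    matched, skipped = split_by_name(model_names, file_names)
    print("loaded:", matched)
    print("left at initial values:", skipped)
```

For a real model, the model-side names can be taken from `[layer.name for layer in model.layers]`. If most layers end up in the skipped list, the weight file probably belongs to a different ResNet50 implementation.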

As a simple way to load ResNet50 for offline use, you can first let Keras load the weights automatically by setting weights='imagenet'.

from keras.applications.resnet import ResNet50

base_model = ResNet50(include_top=False, weights='imagenet', input_shape=(w,h,3), pooling='avg')

Save the model with

base_model.save("model_name.h5")

You can then load it offline as a complete model (architecture + weights).

from tensorflow.keras.models import load_model
resnet = load_model('model_name.h5')
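On the offline machine, a small wrapper can fail early with a clear message when the saved file is missing. This is a sketch under a few assumptions: `load_offline_model` is a hypothetical helper name, the model was saved as shown above, and `compile=False` is used because the base model carries no training configuration.

```python
from pathlib import Path


def load_offline_model(model_path):
    """Load a saved Keras model (architecture + weights) from a local file."""
    path = Path(model_path)
    if not path.is_file():
        # Check the path before importing TensorFlow, so a typo fails fast.
        raise FileNotFoundError("Saved model not found: %s" % path)

    from tensorflow.keras.models import load_model  # imported only when needed

    # compile=False: the model is used as a feature extractor and can be
    # compiled later, after new layers have been added on top.
    return load_model(str(path), compile=False)


if __name__ == "__main__":
    resnet = load_offline_model("model_name.h5")
    resnet.summary()
```

Saving and loading with the same Keras or TensorFlow version on both machines avoids format differences between releases.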
