Loading a Keras model with custom metrics implemented in tf.keras

I implemented, trained, and saved a model in tf.keras. I need to visualize some layers with LRP and other visualization techniques that are only supported in the original (standalone) Keras, so I need to load the model in Keras. Loading it fails with the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-65-08488a893466> in <module>
     10 
     11 with CustomObjectScope({'GlorotUniform': glorot_uniform(), "BinaryAccuracy":binary_accuracy}):
---> 12     model = load_model(os.getcwd() + "/models/saved_models_for_fusion/0_FusionVGGMnistToPS.h5")
     13 img_path = str(Path(os.getcwd() + "/models/scripts/datasets/parkinson_spiral_s/test/0/00027_w.cz.fnusa.1_1.svc.jpg"))
     14 img = load_img(img_path)

d:\users\*\anaconda3\envs\tl\lib\site-packages\keras\engine\saving.py in load_wrapper(*args, **kwargs)
    490                 os.remove(tmp_filepath)
    491             return res
--> 492         return load_function(*args, **kwargs)
    493 
    494     return load_wrapper

d:\users\*\anaconda3\envs\tl\lib\site-packages\keras\engine\saving.py in load_model(filepath, custom_objects, compile)
    582     if H5Dict.is_supported_type(filepath):
    583         with H5Dict(filepath, mode='r') as h5dict:
--> 584             model = _deserialize_model(h5dict, custom_objects, compile)
    585     elif hasattr(filepath, 'write') and callable(filepath.write):
    586         def load_function(h5file):

d:\users\*\anaconda3\envs\tl\lib\site-packages\keras\engine\saving.py in _deserialize_model(h5dict, custom_objects, compile)
    367                       weighted_metrics=weighted_metrics,
    368                       loss_weights=loss_weights,
--> 369                       sample_weight_mode=sample_weight_mode)
    370 
    371         # Set optimizer weights.

d:\users\*\anaconda3\envs\tl\lib\site-packages\keras\engine\training.py in compile(self, optimizer, loss, metrics, loss_weights, sample_weight_mode, weighted_metrics, target_tensors, **kwargs)
    209 
    210         # Save all metric attributes per output of the model.
--> 211         self._cache_output_metric_attributes(metrics, weighted_metrics)
    212 
    213         # Set metric attributes on model.

d:\users\*\anaconda3\envs\tl\lib\site-packages\keras\engine\training.py in _cache_output_metric_attributes(self, metrics, weighted_metrics)
    736                 output_shapes.append(list(output.shape))
    737         self._per_output_metrics = training_utils.collect_per_output_metric_info(
--> 738             metrics, self.output_names, output_shapes, self.loss_functions)
    739         self._per_output_weighted_metrics = (
    740             training_utils.collect_per_output_metric_info(

d:\users\*\anaconda3\envs\tl\lib\site-packages\keras\engine\training_utils.py in collect_per_output_metric_info(metrics, output_names, output_shapes, loss_fns, is_weighted)
    939         metrics_dict = OrderedDict()
    940         for metric in metrics:
--> 941             metric_name = get_metric_name(metric, is_weighted)
    942             metric_fn = get_metric_function(
    943                 metric, output_shape=output_shapes[i], loss_fn=loss_fns[i])

d:\users\*\anaconda3\envs\tl\lib\site-packages\keras\engine\training_utils.py in get_metric_name(metric, weighted)
    967         return metric
    968 
--> 969     metric = metrics_module.get(metric)
    970     return metric.name if hasattr(metric, 'name') else metric.__name__
    971 

d:\users\*\anaconda3\envs\tl\lib\site-packages\keras\metrics.py in get(identifier)
   1974     if isinstance(identifier, dict):
   1975         config = {'class_name': str(identifier), 'config': {}}
-> 1976         return deserialize(config)
   1977     elif isinstance(identifier, six.string_types):
   1978         return deserialize(str(identifier))

d:\users\*\anaconda3\envs\tl\lib\site-packages\keras\metrics.py in deserialize(config, custom_objects)
   1968                                     module_objects=globals(),
   1969                                     custom_objects=custom_objects,
-> 1970                                     printable_module_name='metric function')
   1971 
   1972 

d:\users\*\anaconda3\envs\tl\lib\site-packages\keras\utils\generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
    138             if cls is None:
    139                 raise ValueError('Unknown ' + printable_module_name +
--> 140                                  ': ' + class_name)
    141         if hasattr(cls, 'from_config'):
    142             custom_objects = custom_objects or {}

ValueError: Unknown metric function: {'class_name': 'BinaryAccuracy', 'config': {'name': 'binary_accuracy', 'dtype': 'float32', 'threshold': 0.5}}

I had some trouble before with GlorotUniform, and this https://stackoverflow.com/a/53689541/5722894 answer fixed that. I need to somehow create that custom object, but I have no clue how. I tried importing BinaryAccuracy from tf.keras and passing it as a custom object, but that doesn't work either.
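The last frames of the traceback suggest why passing a custom object cannot help here: at line 1975, `keras.metrics.get` turns a dict identifier into `{'class_name': str(identifier), 'config': {}}`, so the name that gets looked up is the whole stringified dict, not `'BinaryAccuracy'`. Below is a simplified pure-Python model of those two frames. It is not Keras itself; the function names, `binary_accuracy_stub`, and the reduced lookup logic are assumptions made for illustration. It reproduces the exact error message from the question:

```python
# Simplified model of two frames from the traceback above:
# keras/metrics.py get() and keras/utils/generic_utils.py deserialize_keras_object().
# This is an illustration, NOT the real Keras code.


def deserialize_keras_object(config, module_objects, custom_objects):
    """Simplified lookup: custom objects first, then module-level objects."""
    class_name = config['class_name']
    if class_name in custom_objects:
        return custom_objects[class_name]
    cls = module_objects.get(class_name)
    if cls is None:
        # Same message format as generic_utils.py lines 139-140 in the traceback.
        raise ValueError('Unknown metric function: ' + class_name)
    return cls


def get_metric(identifier, module_objects, custom_objects):
    """Simplified version of keras.metrics.get() as shown in the traceback."""
    if isinstance(identifier, dict):
        # Line 1975: the *entire* dict is stringified into class_name,
        # so a key such as 'BinaryAccuracy' can never match it.
        config = {'class_name': str(identifier), 'config': {}}
    else:
        # Simplification: treat a string identifier as the name to look up.
        config = {'class_name': identifier, 'config': {}}
    return deserialize_keras_object(config, module_objects, custom_objects)


# The metric as tf.keras serialized it (copied from the error message).
SERIALIZED = {'class_name': 'BinaryAccuracy',
              'config': {'name': 'binary_accuracy', 'dtype': 'float32',
                         'threshold': 0.5}}


def binary_accuracy_stub(y_true, y_pred):
    """Hypothetical placeholder standing in for a real binary accuracy metric."""
    return None


# Roughly what the CustomObjectScope in the question registers.
CUSTOM = {'BinaryAccuracy': binary_accuracy_stub}


def error_message(identifier):
    """Return the ValueError text raised for identifier, or None if lookup succeeds."""
    try:
        get_metric(identifier, module_objects={}, custom_objects=CUSTOM)
    except ValueError as err:
        return str(err)
    return None


if __name__ == '__main__':
    print(error_message(SERIALIZED))   # same text as the ValueError above
    print(error_message('BinaryAccuracy'))  # None: a plain name would resolve
```

In this model, a plain `'BinaryAccuracy'` string resolves fine, but the serialized dict never does, whatever object the `BinaryAccuracy` key points to. That is consistent with importing `BinaryAccuracy` from tf.keras not changing the outcome.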

Well, I didn't get an answer, but I can say that maintaining two Keras versions side by side is a bad idea. Several things work only with tf.keras, and even more work only with standalone (non-tf) Keras (for instance, the iNNvestigate framework).

I ended up retraining a couple of models just so I could use LRP / Deep Taylor visualization.
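If the model is only needed for inference or visualization, a possible alternative to retraining (not something the original poster tested) is to skip compilation when loading. The traceback shows the failure happens inside `model.compile(...)`, which `load_model` runs because its `compile` argument defaults to `True`. Passing `compile=False` should mean the saved metrics are never deserialized. This is a minimal sketch, assuming standalone Keras 2.x is installed and that the layer configs themselves load cleanly. Other tf.keras-only settings could still fail. `load_for_visualization` is a hypothetical helper name:

```python
def load_for_visualization(h5_path):
    """Load a tf.keras-saved HDF5 model into standalone Keras with compile=False.

    compile=False skips the model.compile(...) call where the traceback fails,
    so the serialized BinaryAccuracy metric is never looked up. The returned
    model has no optimizer, loss, or metrics attached, which is enough for
    predict() and for tools that only need the graph and weights.
    """
    # Imports are local so this sketch can be defined without Keras installed.
    from keras.initializers import glorot_uniform
    from keras.models import load_model
    from keras.utils import CustomObjectScope

    # The GlorotUniform mapping is the fix from the linked Stack Overflow answer.
    with CustomObjectScope({'GlorotUniform': glorot_uniform()}):
        return load_model(h5_path, compile=False)
```

Usage would be along the lines of `model = load_for_visualization(os.getcwd() + "/models/saved_models_for_fusion/0_FusionVGGMnistToPS.h5")`. If training metrics are needed again, the model can be recompiled in standalone Keras with a built-in string metric such as `'binary_accuracy'`.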
