[英]Loading a Keras model with custom metrics implemented in tf.keras
我在 tf.keras 中實作、訓練並保存了一個模型。我需要透過 LRP 以及其他僅在原生 Keras 中支援的視覺化技術來視覺化其中一些層,因此必須在原生 Keras 中載入這個模型。但載入時出現以下錯誤:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-65-08488a893466> in <module>
10
11 with CustomObjectScope({'GlorotUniform': glorot_uniform(), "BinaryAccuracy":binary_accuracy}):
---> 12 model = load_model(os.getcwd() + "/models/saved_models_for_fusion/0_FusionVGGMnistToPS.h5")
13 img_path = str(Path(os.getcwd() + "/models/scripts/datasets/parkinson_spiral_s/test/0/00027_w.cz.fnusa.1_1.svc.jpg"))
14 img = load_img(img_path)
d:\users\*\anaconda3\envs\tl\lib\site-packages\keras\engine\saving.py in load_wrapper(*args, **kwargs)
490 os.remove(tmp_filepath)
491 return res
--> 492 return load_function(*args, **kwargs)
493
494 return load_wrapper
d:\users\*\anaconda3\envs\tl\lib\site-packages\keras\engine\saving.py in load_model(filepath, custom_objects, compile)
582 if H5Dict.is_supported_type(filepath):
583 with H5Dict(filepath, mode='r') as h5dict:
--> 584 model = _deserialize_model(h5dict, custom_objects, compile)
585 elif hasattr(filepath, 'write') and callable(filepath.write):
586 def load_function(h5file):
d:\users\*\anaconda3\envs\tl\lib\site-packages\keras\engine\saving.py in _deserialize_model(h5dict, custom_objects, compile)
367 weighted_metrics=weighted_metrics,
368 loss_weights=loss_weights,
--> 369 sample_weight_mode=sample_weight_mode)
370
371 # Set optimizer weights.
d:\users\*\anaconda3\envs\tl\lib\site-packages\keras\engine\training.py in compile(self, optimizer, loss, metrics, loss_weights, sample_weight_mode, weighted_metrics, target_tensors, **kwargs)
209
210 # Save all metric attributes per output of the model.
--> 211 self._cache_output_metric_attributes(metrics, weighted_metrics)
212
213 # Set metric attributes on model.
d:\users\*\anaconda3\envs\tl\lib\site-packages\keras\engine\training.py in _cache_output_metric_attributes(self, metrics, weighted_metrics)
736 output_shapes.append(list(output.shape))
737 self._per_output_metrics = training_utils.collect_per_output_metric_info(
--> 738 metrics, self.output_names, output_shapes, self.loss_functions)
739 self._per_output_weighted_metrics = (
740 training_utils.collect_per_output_metric_info(
d:\users\*\anaconda3\envs\tl\lib\site-packages\keras\engine\training_utils.py in collect_per_output_metric_info(metrics, output_names, output_shapes, loss_fns, is_weighted)
939 metrics_dict = OrderedDict()
940 for metric in metrics:
--> 941 metric_name = get_metric_name(metric, is_weighted)
942 metric_fn = get_metric_function(
943 metric, output_shape=output_shapes[i], loss_fn=loss_fns[i])
d:\users\*\anaconda3\envs\tl\lib\site-packages\keras\engine\training_utils.py in get_metric_name(metric, weighted)
967 return metric
968
--> 969 metric = metrics_module.get(metric)
970 return metric.name if hasattr(metric, 'name') else metric.__name__
971
d:\users\*\anaconda3\envs\tl\lib\site-packages\keras\metrics.py in get(identifier)
1974 if isinstance(identifier, dict):
1975 config = {'class_name': str(identifier), 'config': {}}
-> 1976 return deserialize(config)
1977 elif isinstance(identifier, six.string_types):
1978 return deserialize(str(identifier))
d:\users\*\anaconda3\envs\tl\lib\site-packages\keras\metrics.py in deserialize(config, custom_objects)
1968 module_objects=globals(),
1969 custom_objects=custom_objects,
-> 1970 printable_module_name='metric function')
1971
1972
d:\users\*\anaconda3\envs\tl\lib\site-packages\keras\utils\generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
138 if cls is None:
139 raise ValueError('Unknown ' + printable_module_name +
--> 140 ': ' + class_name)
141 if hasattr(cls, 'from_config'):
142 custom_objects = custom_objects or {}
ValueError: Unknown metric function: {'class_name': 'BinaryAccuracy', 'config': {'name': 'binary_accuracy', 'dtype': 'float32', 'threshold': 0.5}}
我之前在處理 GlorotUniform 時遇到過類似問題,這則回答(https://stackoverflow.com/a/53689541/5722894)解決了那個問題。這次我需要以某種方式建立自訂物件(custom object),但不知道該怎麼做。我嘗試從 tf.keras 匯入 BinaryAccuracy,並把它當作自訂物件傳入,但同樣無效。
雖然最後沒有得到解答,但我可以說,同時並存兩套 Keras 實在不是好事:有不少功能只能在 tf.keras 中使用,而更多功能只支援非 tf 版的 Keras(例如 iNNvestigate 框架)。
最後我只好重新訓練了幾個模型,只為了能使用 LRP / Deep Taylor 視覺化。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.