
Loading keras model with custom metrics implemented in tf.keras

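I implemented, trained, and saved a model in tf.keras. I need to visualize some of its layers with LRP (layer-wise relevance propagation) and other visualization techniques that are only supported in the original (standalone) Keras, so I need to load the model in Keras. Loading it fails with the following error: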

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-65-08488a893466> in <module>
     10 
     11 with CustomObjectScope({'GlorotUniform': glorot_uniform(), "BinaryAccuracy":binary_accuracy}):
---> 12     model = load_model(os.getcwd() + "/models/saved_models_for_fusion/0_FusionVGGMnistToPS.h5")
     13 img_path = str(Path(os.getcwd() + "/models/scripts/datasets/parkinson_spiral_s/test/0/00027_w.cz.fnusa.1_1.svc.jpg"))
     14 img = load_img(img_path)

d:\users\*\anaconda3\envs\tl\lib\site-packages\keras\engine\saving.py in load_wrapper(*args, **kwargs)
    490                 os.remove(tmp_filepath)
    491             return res
--> 492         return load_function(*args, **kwargs)
    493 
    494     return load_wrapper

d:\users\*\anaconda3\envs\tl\lib\site-packages\keras\engine\saving.py in load_model(filepath, custom_objects, compile)
    582     if H5Dict.is_supported_type(filepath):
    583         with H5Dict(filepath, mode='r') as h5dict:
--> 584             model = _deserialize_model(h5dict, custom_objects, compile)
    585     elif hasattr(filepath, 'write') and callable(filepath.write):
    586         def load_function(h5file):

d:\users\*\anaconda3\envs\tl\lib\site-packages\keras\engine\saving.py in _deserialize_model(h5dict, custom_objects, compile)
    367                       weighted_metrics=weighted_metrics,
    368                       loss_weights=loss_weights,
--> 369                       sample_weight_mode=sample_weight_mode)
    370 
    371         # Set optimizer weights.

d:\users\*\anaconda3\envs\tl\lib\site-packages\keras\engine\training.py in compile(self, optimizer, loss, metrics, loss_weights, sample_weight_mode, weighted_metrics, target_tensors, **kwargs)
    209 
    210         # Save all metric attributes per output of the model.
--> 211         self._cache_output_metric_attributes(metrics, weighted_metrics)
    212 
    213         # Set metric attributes on model.

d:\users\*\anaconda3\envs\tl\lib\site-packages\keras\engine\training.py in _cache_output_metric_attributes(self, metrics, weighted_metrics)
    736                 output_shapes.append(list(output.shape))
    737         self._per_output_metrics = training_utils.collect_per_output_metric_info(
--> 738             metrics, self.output_names, output_shapes, self.loss_functions)
    739         self._per_output_weighted_metrics = (
    740             training_utils.collect_per_output_metric_info(

d:\users\*\anaconda3\envs\tl\lib\site-packages\keras\engine\training_utils.py in collect_per_output_metric_info(metrics, output_names, output_shapes, loss_fns, is_weighted)
    939         metrics_dict = OrderedDict()
    940         for metric in metrics:
--> 941             metric_name = get_metric_name(metric, is_weighted)
    942             metric_fn = get_metric_function(
    943                 metric, output_shape=output_shapes[i], loss_fn=loss_fns[i])

d:\users\*\anaconda3\envs\tl\lib\site-packages\keras\engine\training_utils.py in get_metric_name(metric, weighted)
    967         return metric
    968 
--> 969     metric = metrics_module.get(metric)
    970     return metric.name if hasattr(metric, 'name') else metric.__name__
    971 

d:\users\*\anaconda3\envs\tl\lib\site-packages\keras\metrics.py in get(identifier)
   1974     if isinstance(identifier, dict):
   1975         config = {'class_name': str(identifier), 'config': {}}
-> 1976         return deserialize(config)
   1977     elif isinstance(identifier, six.string_types):
   1978         return deserialize(str(identifier))

d:\users\*\anaconda3\envs\tl\lib\site-packages\keras\metrics.py in deserialize(config, custom_objects)
   1968                                     module_objects=globals(),
   1969                                     custom_objects=custom_objects,
-> 1970                                     printable_module_name='metric function')
   1971 
   1972 

d:\users\*\anaconda3\envs\tl\lib\site-packages\keras\utils\generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
    138             if cls is None:
    139                 raise ValueError('Unknown ' + printable_module_name +
--> 140                                  ': ' + class_name)
    141         if hasattr(cls, 'from_config'):
    142             custom_objects = custom_objects or {}

ValueError: Unknown metric function: {'class_name': 'BinaryAccuracy', 'config': {'name': 'binary_accuracy', 'dtype': 'float32', 'threshold': 0.5}}

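I had some trouble with GlorotUniform before, and this answer https://stackoverflow.com/a/53689541/5722894 solved it. I need to create a custom object somehow, but I don't know how. I tried importing BinaryAccuracy from tf.keras and passing it as a custom object, but that did not work either.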
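Judging from lines 1974 to 1976 of keras/metrics.py in the traceback, standalone Keras' `metrics.get` wraps a dict identifier as `{'class_name': str(identifier), 'config': {}}`, so the name it looks up is the whole stringified dict rather than `'BinaryAccuracy'`. That would explain why registering `BinaryAccuracy` as a custom object had no effect. One possible workaround is to rewrite the metrics stored in the saved file as plain names that standalone Keras resolves on its own. The sketch below is hypothetical: it assumes the `.h5` file keeps a JSON `training_config` attribute at its root (as Keras/tf.keras HDF5 saving does), that each serialized metric's `config['name']` (or, failing that, the snake_case form of its `class_name`) matches a standalone Keras metric such as `binary_accuracy`, and that h5py is installed for the file-patching step. The helper names are made up for illustration.

```python
import json
import re


def camel_to_snake(name):
    """'BinaryAccuracy' -> 'binary_accuracy' (assumes plain CamelCase)."""
    return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()


def simplify_metrics(entry):
    """Replace tf.keras-serialized metric dicts with plain metric-name strings.

    Handles a flat list, nested per-output lists and per-output dicts.
    Strings (e.g. "accuracy") and None are returned unchanged.
    """
    if isinstance(entry, dict) and "class_name" in entry:
        # Assumption: the metric's configured name (default 'binary_accuracy'
        # for BinaryAccuracy) is also a metric function in standalone Keras.
        name = (entry.get("config") or {}).get("name")
        return name or camel_to_snake(entry["class_name"])
    if isinstance(entry, list):
        return [simplify_metrics(e) for e in entry]
    if isinstance(entry, dict):
        return {key: simplify_metrics(value) for key, value in entry.items()}
    return entry


def patch_training_config(config):
    """Return a shallow copy of a training_config dict with simplified metrics."""
    patched = dict(config)
    for key in ("metrics", "weighted_metrics"):
        if key in patched:
            patched[key] = simplify_metrics(patched[key])
    return patched


def patch_h5_file(path):
    """Rewrite the 'training_config' attribute of a Keras HDF5 file in place.

    Hypothetical helper: back up the file first. Requires h5py.
    """
    import h5py  # imported lazily so the pure helpers above work without it

    with h5py.File(path, "r+") as f:
        raw = f.attrs["training_config"]
        if isinstance(raw, bytes):  # h5py may return bytes or str
            raw = raw.decode("utf-8")
        config = patch_training_config(json.loads(raw))
        f.attrs["training_config"] = json.dumps(config).encode("utf-8")
```

After patching a copy of the file, `load_model` (still with the `GlorotUniform` custom object) should get past the metric step. If some other part of the saved training configuration turns out to be incompatible, or the model is only needed for visualization, passing `compile=False` to `load_model` (the `compile` argument appears in the `load_model` signature in the traceback) skips restoring the training configuration altogether; `model.compile(...)` can then be called with standalone Keras metrics if training is needed.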

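Well, I didn't get an answer, but I can say that maintaining two versions of Keras is a bad thing. Several things only work with tf.keras, and even more things only work with non-tf Keras (for example, the iNNvestigate framework).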

I ended up retraining several models just to use LRP / Deep Taylor visualization.


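Disclaimer: The technical posts on this site are licensed under CC BY-SA 4.0. If you repost them, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.

Guangdong ICP No. 18138465  © 2020-2024 STACKOOM.COM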