[英]Saving keras configuration with custom metric function to JSON
I am trying to save my configuration for a Keras model. 我正在尝试为Keras模型保存配置。 I would like to be able to read the configuration from the file to be able to reproduce the training.
我希望能够从文件中读取配置,以便能够重现培训。
Before implementing a custom metric in a function I could just do it the way shown below without the mean_pred
. 在函数中实现自定义指标之前,我可以按照下面显示的方式执行,而不使用
mean_pred
。 Now I am running into the problem TypeError: Object of type 'function' is not JSON serializable
. 现在我
TypeError: Object of type 'function' is not JSON serializable
了问题TypeError: Object of type 'function' is not JSON serializable
。
Here I read that it is possible to get the function name as string by custom_metric_name = mean_pred.__name__
. 在这里,我读到可以通过
custom_metric_name = mean_pred.__name__
将函数名称作为字符串。 I would like to not only be able to save the name, but to be able to save a reference to the function if possible. 我希望不仅能够保存名称,而且能够在可能的情况下保存对该功能的引用。
Perhaps I should as mentioned here also think about not just storing my configuration in the .py file but using ConfigObj
. 也许我应该像这里提到的那样考虑不只是将我的配置存储在.py文件中而是使用
ConfigObj
。 Unless this would solve my current problem I would implement this later. 除非这可以解决我当前的问题,否则我将在稍后实现。
Minimum working example of problem: 问题的最小工作示例:
import keras.backend as K
import json
def mean_pred(y_true, y_pred):
return K.mean(y_pred)
config = {'epochs':500,
'loss':{'class':'categorical_crossentropy'},
'optimizer':'Adam',
'metrics':{'class':['accuracy', mean_pred]}
}
# Do the training etc...
config_filename = 'config.txt'
with open(config_filename, 'w') as f:
f.write(json.dumps(config))
Greatly appreciate help with this problem as well as other approaches to saving my configuration in the best way possible. 非常感谢这个问题的帮助以及以尽可能最好的方式保存配置的其他方法。
To solve my problem I saved the name of the function as a string in the config file and then extracted the function from a dictionary to use it as metrics in the model. 为了解决我的问题,我将函数的名称保存为配置文件中的字符串,然后从字典中提取函数以将其用作模型中的度量。 One could additionally use:
'class':['accuracy', mean_pred.__name__]
to save the name of the function as a string in the config. 还可以使用:
'class':['accuracy', mean_pred.__name__]
将函数的名称保存为配置中的字符串。 This does also work for multiple custom functions and for more keys to metrics (eg. define metrics for 'reg' like 'class' when doing regression and classification). 这也适用于多个自定义函数和更多指标的关键字(例如,在进行回归和分类时定义'reg'的指标,如'class'。
import keras.backend as K
import json
from collections import defaultdict
def mean_pred(y_true, y_pred):
return K.mean(y_pred)
config = {'epochs':500,
'loss':{'class':'categorical_crossentropy'},
'optimizer':'Adam',
'metrics':{'class':['accuracy', 'mean_pred']}
}
custom_metrics= {'mean_pred':mean_pred}
metrics = defaultdict(list)
for metric_type, metric_functions in config['metrics'].items():
for function in metric_functions:
if function in custom_metrics.keys():
metrics[metric_type].append(custom_metrics[function])
else:
metrics[metric_type].append(function)
# Do the training, use metrics
config_filename = 'config.txt'
with open(config_filename, 'w') as f:
f.write(json.dumps(config))
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.