简体   繁体   English

如何保存/加载一个tensorflow contrib.learn回归器?

[英]How do I save/load a tensorflow contrib.learn regressor?

I have a tensorflow contrib.learn.DNNRegressor that I have trained as part of the following code snippet: 我有一个tensorflow contrib.learn.DNNRegressor,我已将其训练为以下代码段的一部分:

regressor = tf.contrib.learn.DNNRegressor(feature_columns=fc, 
                                          hidden_units=hu_array, 
                                          optimizer=tf.train.AdamOptimizer(
                                                       learning_rate=0.001,
                                                    ),
                                          enable_centered_bias=False,
                                          activation_fn=tf.tanh,
                                          model_dir="./models/my_model/",
                                          )

regressor.fit(x=training_features, y=training_labels, steps=10000)

The trained network performs quite well, and I'd like to use it as a part of some other code, on another machine. 经过训练的网络性能很好,我想在另一台机器上将其用作其他代码的一部分。 I have tried copying over the models/my_model directory, and constructing a new DNNRegressor pointing just at the model_dir, but it requires that I supply feature_columns and hidden_units definitions. 我尝试复制过models / my_model目录,并构造一个仅指向model_dir的新DNNRegressor,但它要求我提供feature_columns和hidden_​​units定义。 Shouldn't that information be available via the snapshots stored in model_dir? 是否应该通过存储在model_dir中的快照获得该信息? Is there a better way to save/recover a trained model which is performing well, to be used as a predictor, without having to separately save the feature_columns and hidden_units? 有没有更好的方法来保存/恢复性能良好的训练模型,以用作预测变量,而不必分别保存feature_columns和hidden_​​units?

I came up with something workable- not ideal, but it gets the job done. 我想出了一些可行的方法-不理想,但是可以完成工作。 If anyone has a better idea, I am all ears. 如果有人有更好的主意,我会很高兴。

I converted my kwargs for DNNRegressor into a dict, and used the ** operator. 我将DNNRegressor的kwargs转换为dict,并使用**运算符。 Then I was able to pickle the kwargs dict, and reconstruct the DNNRegressor from that. 然后,我可以腌制kwargs字典,并从中重建DNNRegressor。 Eg: 例如:

reg_args = {'feature_columns': fc, 'hidden_units': hu_array, ...}
regressor = tf.contrib.learn.DNNRegressor(**reg_args)
pickle.dump(reg_args, open('reg_args.pkl', 'wb'))

Later on, I reconstruct via: 稍后,我通过以下方式进行重构:

reg_args = pickle.load(open('reg_args.pkl', 'rb'))
# On another machine and so my model dir path changed:
reg_args['model_dir'] = NEW_MODEL_DIR
regressor = tf.contrib.learn.DNNRegressor(**reg_args)

It worked well. 运行良好。 I'm sure there must be a better way but for now if someone is trying to figure out a workaround for tf.contrib.learn, this is a solution. 我确定肯定会有更好的方法,但是现在如果有人试图找出tf.contrib.learn的解决方法,这是一个解决方案。

When training 训练时

You call DNNRegressor(..., model_dir) and then call the fit() and evaluate() method. 您调用DNNRegressor(..., model_dir) ,然后调用fit()DNNRegressor(..., model_dir) evaluate()方法。

When testing 测试时

You call DNNRegressor(..., model_dir) and then can call predict() methods. 您调用DNNRegressor(..., model_dir) ,然后可以调用DNNRegressor(..., model_dir) predict()方法。 Your model will find a trained model in the model_dir and will load the trained model params. 您的模型将在model_dir找到经过训练的模型,并将加载经过训练的模型参数。

Reference 参考

Issue #3340 of TF TF第3340期

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM