繁体   English   中英

如何使用tf.estimator保存张量流模型

[英]how to save tensorflow model with tf.estimator

我有以下示例代码使用tensorflow的estimator API训练和评估cnn mnist模型:

 def model_fn(features, labels, mode):
        images = tf.reshape(features, [-1, 28, 28, 1])
        model = Model()
        logits = model(images)

        predicted_logit = tf.argmax(input=logits, axis=1, output_type=tf.int32)

        if mode == tf.estimator.ModeKeys.PREDICT:
            probabilities = tf.nn.softmax(logits)

            predictions = {
                'predicted_logit': predicted_logit,
                'probabilities': probabilities
            }
            return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

        else:
            ...

    def mnist_train_and_eval(_):
        train_data, train_labels, eval_data, eval_labels, val_data, val_labels = get_mnist_data()

        # Create a input function to train
        train_input_fn = tf.estimator.inputs.numpy_input_fn(
            x= train_data,
            y=train_labels,
            batch_size=_BATCH_SIZE,
            num_epochs=1,
            shuffle=True)

        # Create a input function to eval
        eval_input_fn = tf.estimator.inputs.numpy_input_fn(
            x= eval_data,
            y=eval_labels,
            batch_size=_BATCH_SIZE,
            num_epochs=1,
            shuffle=False)

        # Create a estimator with model_fn
        image_classifier = tf.estimator.Estimator(model_fn=model_fn, model_dir=_MODEL_DIR)

        # Finally, train and evaluate the model after each epoch
        for _ in range(_NUM_EPOCHS):
            image_classifier.train(input_fn=train_input_fn)
            metrics = image_classifier.evaluate(input_fn=eval_input_fn)

如何使用estimator.export_savedmodel保存经过训练的模型以供以后推断? 我应该如何编写serving_input_receiver_fn?

非常感谢您的帮助!

使用输入功能字典创建函数。 占位符应与图片的形状匹配,第一尺寸为batch_size。

def serving_input_receiver_fn():
  x = tf.placeholder(tf.float32, [None, Shape])
  inputs = {'x': x}
  return tf.estimator.export.ServingInputReceiver(features=inputs, receiver_tensors=inputs)

或者您可以使用不需要dict映射的TensorServingInputReceiver

inputs = tf.placeholder(tf.float32, [None, 32*32*3])
tf.estimator.export.TensorServingInputReceiver(inputs, inputs)

此函数返回ServingInputReceiver新实例,该实例传递到export_savedmodeltf.estimator.FinalExporter

...
image_classifier.export_savedmodel(saved_dir, serving_input_receiver_fn)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM