
How to include the value of an input feature in the prediction output in TensorFlow?

I want the input request to be included in the prediction output in TensorFlow.

I have 10 feature columns, and I am able to get a prediction for each request in test.csv.

How do I get a prediction response that also includes the request in the same object? Below is my serving input function:

def csv_serving_input_fn():
    # Build the serving input
    csv_row = tf.placeholder(shape=[None], dtype=tf.string)
    features, label = _decode_csv(csv_row)
    features.pop(metadata.LABEL_COLUMN)
    return tf.estimator.export.ServingInputReceiver(features, {'csv_row': csv_row})


I found a solution by updating my custom estimator, without changing the serving_input_fn.

Here is a snippet of the end of my estimator:

def ...<custom_estimator>(features, labels, mode, params):
    # generate prediction
    ....
    # Calculate the RMSE/LOSS
    ....
    # Set the return
    predictions_dict = {"predicted": predictions}

    # 4. return EstimatorSpec
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions_dict,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        export_outputs={
            'predictions': tf.estimator.export.PredictOutput(predictions_dict)}
    )

In fact, you have to update what is passed to tf.estimator.export.PredictOutput. I did that in the "Set the return" part: I added the inputs to predictions_dict, and it works:

def ...<custom_estimator>(features, labels, mode, params):
    # generate prediction
    ....
    # Calculate the RMSE/LOSS
    ....
    # Set the return
    predictions_dict = {"inputs":features[TIMESERIES_COL],"predicted": predictions}

    # 4. return EstimatorSpec
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions_dict,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        export_outputs={
            'predictions': tf.estimator.export.PredictOutput(predictions_dict)}
    )
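Since the question mentions 10 feature columns, the same idea could be extended to echo several features at once. Below is a minimal sketch, not part of the original answer: the helper name with_passthrough, the "input_" prefix, and the example feature names are all hypothetical, and plain Python lists stand in for tensors so it runs without TensorFlow. It also assumes each echoed feature is a dense tensor with the same batch dimension as the predictions.

```python
def with_passthrough(features, predictions, keys, prefix="input_"):
    """Return a predictions dict that also carries selected input features.

    `features` is the dict the model function receives, `predictions` is the
    model output, and `keys` names the feature columns to echo back.
    This helper is hypothetical and is not part of TensorFlow.
    """
    out = {"predicted": predictions}
    for key in keys:
        if key not in features:
            raise KeyError("feature %r is not in the features dict" % key)
        # Prefix the key so an echoed input cannot overwrite "predicted".
        out[prefix + key] = features[key]
    return out


# Plain lists stand in for tensors so the sketch runs without TensorFlow.
example_features = {"age": [31, 47], "city": ["Paris", "Lima"]}
example_predictions = [0.12, 0.87]
predictions_dict = with_passthrough(
    example_features, example_predictions, ["age", "city"]
)
print(predictions_dict)
```

Inside the model function, the returned dict would take the place of predictions_dict and be passed both to predictions= and to tf.estimator.export.PredictOutput, as in the snippet above.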

Good prediction!!
