How to use SessionRunHook to print tensor with tf.data.Dataset API?

I am using the tf.data.Dataset API and assigning names to operations within the closure passed to Dataset.map, as shown below:

import tensorflow as tf


def model_fn(features, mode):
    loss = tf.constant(1)
    train_op = tf.no_op()
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)


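def input_fn():
    # 'tokens_inside' is named inside the function passed to Dataset.map
    dataset = tf.data.Dataset \
        .from_generator(lambda: (x*x for x in range(10)), tf.int32) \
        .map(lambda x: tf.identity(x, name='tokens_inside'))

    ret = dataset.make_one_shot_iterator().get_next()
    # 'tokens_outside' is named on the iterator output, outside the map function
    tf.identity(ret, 'tokens_outside')

    return ret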


tf.logging.set_verbosity(tf.logging.INFO)

hooks = [
    tf.train.LoggingTensorHook(['tokens_outside'], every_n_iter=1),
    tf.train.LoggingTensorHook(['tokens_inside'], every_n_iter=1),
]

est = tf.estimator.Estimator(model_fn=model_fn, model_dir='mout')
est.train(input_fn=input_fn, hooks=hooks, max_steps=1)

When I use tf.train.LoggingTensorHook to dump some of these values, the second hook throws an exception:

KeyError: "The name 'tokens_inside:0' refers to a Tensor which does not exist. The operation, 'tokens_inside', does not exist in the graph."

I guess the Dataset operations create a new graph for each function? Is there a way to customize tf.train.LoggingTensorHook so it knows which graph to search for the named tensor?
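The guess above can be illustrated with a toy, TensorFlow-free model. It assumes, as the KeyError suggests, that the function passed to Dataset.map is traced into its own graph with its own table of op names, while the hook resolves names against the outer (default) graph. `Graph`, `dataset_map`, and `resolve_hook_tensors` are hypothetical stand-ins written for this sketch, not TensorFlow APIs.

```python
# Toy model (not TensorFlow): each graph keeps its own table of op names.
# Assumption: like Dataset.map in TF 1.x, dataset_map() traces its function
# into a separate, nested graph.


class Graph:
    def __init__(self):
        self._ops = {}  # op name -> value produced by that op

    def identity(self, value, name):
        # Register a named op in *this* graph only.
        self._ops[name] = value
        return value

    def get_tensor_by_name(self, name):
        op_name = name.split(":")[0]  # "tokens_inside:0" -> "tokens_inside"
        if op_name not in self._ops:
            raise KeyError(
                f"The name '{name}' refers to a Tensor which does not exist. "
                f"The operation, '{op_name}', does not exist in the graph."
            )
        return self._ops[op_name]


def dataset_map(fn, element):
    # Hypothetical stand-in for Dataset.map: fn runs against a fresh inner graph.
    inner = Graph()
    return fn(inner, element), inner


def resolve_hook_tensors(graph, names):
    # Hypothetical stand-in for a name-based hook resolving names in one graph.
    found, missing = {}, []
    for name in names:
        try:
            found[name] = graph.get_tensor_by_name(name)
        except KeyError:
            missing.append(name)
    return found, missing


main = Graph()  # plays the role of the Estimator's default graph
value, inner = dataset_map(lambda g, x: g.identity(x * x, "tokens_inside"), 3)
main.identity(value, "tokens_outside")

found, missing = resolve_hook_tensors(main, ["tokens_outside:0", "tokens_inside:0"])
print(found)    # {'tokens_outside:0': 9}
print(missing)  # ['tokens_inside:0']
print(inner.get_tensor_by_name("tokens_inside:0"))  # 9
```

In this model the name registered inside the map function is only resolvable through the inner graph, which matches the observation that the hook finds `tokens_outside` (named on the iterator output) but not `tokens_inside`.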

In the callback function, use tf.add_to_collection(name, your_input) and get_collection(name).

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please credit the site URL or the original address. For any questions, contact: yoyou2525@163.com.

© 2020-2024 STACKOOM.COM