TensorFlow print in function

I have a function in a file neural_network.py that defines a loss function:

def loss(a, b):
    ...
    debug = tf.Print(a, [a], message = 'debug: ')
    debug.eval(session = ???)
    return tf.add(a, b)

To explain: somewhere in this function I want to print a tensor. However, no session is declared in this function; my sessions are declared in another file, forecaster.py. So when I try to use tf.Print() in loss(), I can't, because I don't know which session to call eval() with. Is there a way to solve this, either with tf.Print() or with some other debugging method? Thanks!

tf.Print works as an identity op: it returns the same tensor you pass as its first argument and, as a side effect, prints (to standard error) the list of tensors given as its second argument.

So use it like this:

import tensorflow as tf

def loss(a, b):
    ...
    # Reassign a so the print op feeds into tf.add and therefore actually runs.
    a = tf.Print(a, [a], message='debug: ')
    return tf.add(a, b)

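a will then be printed every time the tensor returned by tf.add(a, b) is evaluated, for example when forecaster.py runs the loss in its own session. No session is needed inside loss() because the output of tf.Print is now an input to tf.add, so evaluating the loss forces the print op to run. In the question's version, the debug tensor is never used by anything else, so its print op would only run if it were evaluated explicitly.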
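A self-contained sketch of the same pattern, assuming TensorFlow 2.x with the tf.compat.v1 API used to recreate the TF1 graph-and-session setup from the question (in TensorFlow 1.x, tf.Print is the same op). The constant inputs and the session block are hypothetical stand-ins for the real inputs and for the code in forecaster.py; loss() itself never touches a session.

```python
import tensorflow as tf

# Assumption: TensorFlow 2.x. Disabling eager execution restores TF1-style
# graph building, so nothing runs until a session evaluates it.
tf.compat.v1.disable_eager_execution()


def loss(a, b):
    # Identity op: returns `a` unchanged and, when executed, prints the
    # listed tensors to stderr with the given message prefix.
    a = tf.compat.v1.Print(a, [a], message='debug: ')
    # Because the printed tensor feeds tf.add, the print op runs whenever
    # the returned loss tensor is evaluated.
    return tf.add(a, b)


# Hypothetical stand-in for forecaster.py: the session lives here, not in loss().
a = tf.constant([1.0, 2.0])
b = tf.constant([3.0, 4.0])
total = loss(a, b)

with tf.compat.v1.Session() as sess:
    # Evaluating `total` also triggers the debug print on stderr.
    result = sess.run(total)

print(result)
```

In TensorFlow 2.x, tf.Print is deprecated in favor of tf.print. In eager mode tf.print prints immediately, and inside a tf.function it is kept in the graph by automatic control dependencies, so the reassignment trick above is not required there.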
