Print statement inside TensorFlow function

I am new to TensorFlow and I am trying to print the shape of a vector inside a function which will be called from a TensorFlow session.

The problem is that this line (shown commented out) is only executed once, when the function is first called to build the graph, and not at every iteration of the TensorFlow session. How do I add a print statement that runs at every TensorFlow iteration?

import tensorflow as tf

# Q_W1, Q_b1, Q_W2 and Q_b2 are weight/bias variables defined elsewhere
def Q(X):
    # f_debug.write('Q(X) :: X.shape :: ' + str(X.shape) + '\n')

    h = tf.nn.relu(tf.matmul(X, Q_W1) + Q_b1)
    z = tf.matmul(h, Q_W2) + Q_b2
    return z

This is an important point and a common source of confusion in TF: that function will NOT be called by TensorFlow during a session. In fact, no Python function will be, with the exception of tf.py_func, which could be a workaround for your problem.

Your function Q is called only while the graph is being built, to create the symbolic operations, which are then added to the dependency graph. During a session, the dependency graph is all that is used to perform computations. This holds even if you use tf.while_loop, tf.cond, or other control flow operations: none of them call Python during a session; they just execute the parts of the dependency graph you defined.
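To see this in action, here is a minimal sketch (an illustration, not the asker's code) that counts how often a Python function is called. It assumes a TF 2 install running TF 1-style graph code through tf.compat.v1; the counter trace_count and the simplified Q are hypothetical. The Python-side counter and print run once, while the graph is built, even though sess.run executes the graph three times.

```python
import tensorflow as tf

# Assumption: TF 2 install, so switch to TF 1-style graph mode explicitly
tf.compat.v1.disable_eager_execution()

trace_count = 0  # hypothetical counter for Python-side calls to Q

def Q(X):
    global trace_count
    trace_count += 1  # plain Python: runs only while the graph is being built
    print('building Q, static shape:', X.shape)
    return X * 2.0  # symbolic op: this is what runs on every sess.run

x = tf.compat.v1.placeholder(tf.float32, shape=[None, 3])
z = Q(x)  # the only time Q itself is called

with tf.compat.v1.Session() as sess:
    for _ in range(3):
        sess.run(z, feed_dict={x: [[1.0, 2.0, 3.0]]})

print('Q was called', trace_count, 'time(s)')
```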

In general, there's no good way to stop TensorFlow in the middle of graph execution short of using the TensorFlow Debugger (tfdbg), which is not hard to configure at all. As a workaround, though, you might get away with wrapping a Python function in tf.py_func. tf.py_func converts its input tensors to NumPy arrays and calls your Python function during session execution (it's not efficient, but it's handy in certain cases).
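A minimal sketch of the tf.py_func workaround, again assuming TF 1-style graph mode via tf.compat.v1 on a TF 2 install. The helper record_shape and the list seen_shapes are hypothetical names; instead of printing, the function records the concrete shape of each NumPy array it receives, so you can check that it runs on every sess.run call.

```python
import numpy as np
import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # TF 1-style graph mode

seen_shapes = []  # hypothetical log, filled in by Python during sess.run

def record_shape(x):
    # x arrives as a NumPy array, so x.shape is the concrete runtime shape
    seen_shapes.append(x.shape)
    return np.array(x.shape, dtype=np.int64)

x = tf.compat.v1.placeholder(tf.float32, shape=[None, 3])
shape_op = tf.compat.v1.py_func(record_shape, [x], tf.int64)

with tf.compat.v1.Session() as sess:
    for batch in (1, 4):
        print(sess.run(shape_op, feed_dict={x: np.zeros((batch, 3), dtype=np.float32)}))
```

Here the py_func output is fetched directly, so the session has a reason to run it. Feeding batches of 1 and 4 rows leaves seen_shapes as [(1, 3), (4, 3)].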

You may need to use with tf.control_dependencies(...): to force your tf.py_func operation to run, since nothing would depend on it if it contained nothing but a print statement.
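When the Python function's result is not needed, a control dependency is what makes the session run it. The sketch below assumes the same tf.compat.v1 setup; log_shape, logged, and the single weight matrix W are hypothetical simplifications of the question's Q, not part of the original answer.

```python
import numpy as np
import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # TF 1-style graph mode

logged = []  # hypothetical log, appended to on every session step

def log_shape(x):
    logged.append(x.shape)
    return np.int64(0)  # dummy output; only the side effect matters

def Q(X, W):
    log_op = tf.compat.v1.py_func(log_shape, [X], tf.int64)
    # Nothing consumes log_op's output, so make the matmul depend on it;
    # otherwise the session would never run it.
    with tf.control_dependencies([log_op]):
        return tf.matmul(X, W)

x = tf.compat.v1.placeholder(tf.float32, shape=[None, 3])
W = tf.ones([3, 2])  # hypothetical stand-in for the question's Q_W1
z = Q(x, W)

with tf.compat.v1.Session() as sess:
    for batch in (2, 5):
        sess.run(z, feed_dict={x: np.ones((batch, 3), dtype=np.float32)})
```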

Disclaimer: I haven't used tf.py_func this way, and it wasn't built with this intention.

This explanation is specific to TensorFlow 1. I'm sure eager execution in TF 2 changes some of these things. See the docs for tf.print.
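As a rough illustration of the TF 2 behavior, here is a sketch assuming TF 2.x with tf.function (python_prints is a hypothetical counter). Inside a tf.function, plain Python code such as print or a counter update still runs only while the function is being traced, which is the same problem as in the question, whereas tf.print becomes part of the graph and runs on every call. tf.function adds automatic control dependencies for stateful ops like tf.print, so no explicit tf.control_dependencies is needed here.

```python
import tensorflow as tf

python_prints = 0  # hypothetical counter for Python-side executions

@tf.function
def Q(X):
    global python_prints
    python_prints += 1  # plain Python: only runs while tracing
    tf.print('Q: X.shape =', tf.shape(X))  # graph op: runs on every call
    return X * 2.0

x = tf.ones([2, 3])
for _ in range(3):
    Q(x)

print('Python body ran', python_prints, 'time(s)')
```

Calling Q three times with same-shaped input traces it once, so python_prints ends at 1, while tf.print writes the shape (to stderr by default) three times.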

I have not had any luck using regular print() statements inside py_func functions. You can do print statements "inside" the execution of your TF graph (py_func or otherwise) using tf.print . I'm only familiar with its use in TF 1, but there it works by creating a new print operation and adding it to the TF graph:

import tensorflow as tf

def Q(X):
    # Note that at least sometimes, X.shape can be resolved before the
    # graph is executed, so you may only need this for the *value* of X
    # (use tf.shape(X) to print the shape as computed at run time)

    # Create a new print op
    printop = tf.print('Q:', X)

    # Force printop to run by making it a control dependency of at
    # least one operation that will be run
    with tf.control_dependencies([printop]):
        h = tf.nn.relu(tf.matmul(X, Q_W1) + Q_b1)
        z = tf.matmul(h, Q_W2) + Q_b2
        return z

I have run into some situations where this didn't work as I'd hoped (I believe due to multithreading). It can also be tricky if you can't find a node of the graph to attach the printop operation to, for example if you are trying to print something on the last line of a function.

But this seems to work much of the time for me.
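One workaround for the last-line case, sketched under the same tf.compat.v1 assumptions (W is a hypothetical weight matrix), is to create a node to attach to: wrap the final result in tf.identity inside the control_dependencies block, so the returned tensor itself depends on the print op. This is a common idiom rather than something from the original answer, and it does not address the multithreading issue mentioned above.

```python
import numpy as np
import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # TF 1-style graph mode

def Q(X, W):
    z = tf.matmul(X, W)
    # There is no later op inside Q to attach the print to, so create one:
    # tf.identity(z) is a new node that depends on printop.
    printop = tf.print('Q: z.shape =', tf.shape(z))
    with tf.control_dependencies([printop]):
        z = tf.identity(z)
    return z

x = tf.compat.v1.placeholder(tf.float32, shape=[None, 3])
W = tf.constant(np.ones((3, 2), dtype=np.float32))  # hypothetical weights
z = Q(x, W)

with tf.compat.v1.Session() as sess:
    out = sess.run(z, feed_dict={x: np.ones((4, 3), dtype=np.float32)})
```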

Docs on tf.print().
