I am new to TensorFlow and I am trying to print the shape of a vector inside a function which will be called from a TensorFlow session.
The problem is that this line (showed commented out) is only executed when this function template is initially defined (and not at every iteration during a TensorFlow session). How do I add a print statement such that it is called at every TensorFlow iteration?
def Q(X):
    # f_debug.write('Q(X) :: X.shape :: ' + str(X.shape) + '\n')
    h = tf.nn.relu(tf.matmul(X, Q_W1) + Q_b1)
    z = tf.matmul(h, Q_W2) + Q_b2
    return z
This is an important point and a common source of confusion in TF: that function will NOT be called by TensorFlow during a session. No Python function is, with the exception of tf.py_func, which could be a workaround for your problem.
TensorFlow calls your function Q only once, to collect the symbolic operations, and then adds those operations to the dependency graph. During a session, the dependency graph is all that is used to perform computations. This holds even if you are using tf.while_loop, tf.cond, or another control-flow operation: none of these call Python during a session; they just iterate over elements of the dependency graph as you defined it.
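To see this concretely, here is a minimal sketch (using the TF 1-style graph API via tf.compat.v1; the placeholder and counter are illustrative names of my own) showing that the Python body of Q runs exactly once, at graph-construction time, no matter how many session iterations you run:

```python
import tensorflow.compat.v1 as tf  # TF 1-style graph mode under TF 2
tf.disable_eager_execution()

calls = []

def Q(X):
    calls.append(1)  # plain Python: runs only while the graph is built
    return X + 1.0

X = tf.placeholder(tf.float32, shape=[])
z = Q(X)  # Q is invoked here, exactly once

with tf.Session() as sess:
    for i in range(3):  # three session "iterations"
        sess.run(z, feed_dict={X: float(i)})

print('Q was called', len(calls), 'time(s)')  # once, not three times
```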
In general there's no good way to stop TensorFlow in the middle of a graph execution, short of using the TensorFlow Debugger (which is not hard to configure at all). But as a workaround you might get away with defining a tf.py_func around a Python function. tf.py_func marshals tensors into Python objects and calls your function during session execution (it's not efficient or anything, but it's handy in certain cases).
You may need to use with tf.control_dependencies(...): to force your tf.py_func operation to run, since it would have no dependents if it contained nothing but a print statement.
Disclaimer: I haven't used tf.py_func this way, nor was it built with this intention.
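A minimal sketch of that workaround, assuming TF 1-style graph mode (via tf.compat.v1) and illustrative names and shapes of my own choosing:

```python
import numpy as np
import tensorflow.compat.v1 as tf  # TF 1-style graph mode under TF 2
tf.disable_eager_execution()

def debug_shape(x):
    # Ordinary Python, executed at session run time
    print('Q(X) :: X.shape ::', x.shape)
    return x

X = tf.placeholder(tf.float32, shape=[None, 3])
# Wrap the Python function as a graph op
X_dbg = tf.py_func(debug_shape, [X], tf.float32)

# Force the py_func op to run by making a downstream op depend on it
with tf.control_dependencies([X_dbg]):
    z = tf.identity(X) * 2.0

with tf.Session() as sess:
    out = sess.run(z, feed_dict={X: np.ones((4, 3), np.float32)})
```

Without the control_dependencies block, nothing in the graph would depend on X_dbg, and the session would simply never execute it.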
This is a TensorFlow 1 specific explanation. I'm sure Eager Execution in TF 2 changes some of these things; see the docs for tf.print.
I have not had any luck using regular print() statements inside py_func functions. You can do print statements "inside" the execution of your TF graph (py_func or otherwise) using tf.print. I'm only familiar with its use in TF 1, but there it works by creating a new print operation and adding it to the TF graph:
def Q(X):
    # Note that at least sometimes, X.shape can be resolved before the
    # graph is executed, so you may only need this for the *value* of X.
    # Create a new print op.
    printop = tf.print('Q:', X)
    # Force printop to be added to the graph by setting it as a
    # dependency for at least one operation that will be run.
    with tf.control_dependencies([printop]):
        h = tf.nn.relu(tf.matmul(X, Q_W1) + Q_b1)
        z = tf.matmul(h, Q_W2) + Q_b2
    return z
I have run into some situations where this didn't work as I'd hoped (I believe due to multithreading). It can also be tricky if you can't find a node of the graph to attach the printop operation to, for example if you are trying to print something on the last line of a function. But this seems to work much of the time for me.
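Since the TF 2 eager case was mentioned above: there, tf.print needs none of this graph plumbing, because it executes immediately every time the line runs. A quick sketch (assuming a TF 2.x install with eager execution enabled, which is the default):

```python
import tensorflow as tf  # TF 2.x, eager execution by default

x = tf.ones([4, 3])
# Executes right here, on every call, no control_dependencies needed
tf.print('x shape:', tf.shape(x))
```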