[英]Print statement inside TensorFlow function
我是 TensorFlow 的新手,我正在尝试在一个函数中打印向量的形状,该函数将从 TensorFlow 会话中调用。
问题是这一行(显示为注释掉)仅在最初定义此函数模板时执行(而不是在 TensorFlow 会话期间的每次迭代时)。 如何添加打印语句,以便在每次 TensorFlow 迭代时调用它?
def Q(X):
# f_debug.write('Q(X) :: X.shape :: ' + str(X.shape) + '\n')
h = tf.nn.relu(tf.matmul(X, Q_W1) + Q_b1)
z = tf.matmul(h, Q_W2) + Q_b2
return z
这是一个需要注意的重要点,也是 TF 中常见的混淆。 该函数不会在会话中被 tensorflow 调用,除了 tf.py_func 之外,没有 python 函数会tf.py_func
,这可能是您问题的解决方法。
Tensorflow 调用您的函数Q
只是为了获取符号操作,然后将这些操作添加到依赖关系图中。 在会话期间,依赖图是执行计算所依赖的全部。 即使您正在使用tf.while
、 tf.cond
或其他控制流操作。 这些都不会在会话期间调用 python,它们只是按照您定义的方式循环依赖图中的元素。
一般来说,除了使用 Tensorflow Debugger(根本不难配置)之外,没有什么好的方法可以停止 tensorflow 中间图执行的执行。 但是作为一种解决方法,您可能会定义一个tf.py_func
python 函数。 这个函数将一个张量编组到一个 python 对象中,并在会话执行期间调用 python(它效率不高,但在某些情况下很方便)。
您可能需要使用with tf.control_dependencies(...):
来强制您的tf.py_func
操作运行(因为如果它里面只有一个打印语句,它就没有任何依赖项)。
免责声明:我没有以这种方式使用tf.py_func
也不是为了这个目的而构建的。
这是 TensorFlow 1 的具体解释。 我确信 TF 2 中的 Eager Execution 改变了其中的一些事情。 请参阅tf.print 的文档。
我在 py_func 函数中使用常规的print()
语句没有任何运气。 您可以使用tf.print
在 TF 图(py_func 或其他)的执行“内部”执行打印语句。 我只熟悉它在 TF 1 中的使用,但它的工作原理是创建一个新的打印操作并将其添加到 TF 图中:
def Q(X):
# Note that at least sometimes, X.shape can be resolved before the
# graph is executed, so you may only need this for the *value* of X
# Create a new print op
printop = tf.print('Q:', X)
# Force printop to be added to the graph by setting it as a
# dependency for at least one operation that will be run
with tf.control_dependencies([printop]):
h = tf.nn.relu(tf.matmul(X, Q_W1) + Q_b1)
z = tf.matmul(h, Q_W2) + Q_b2
return z
我遇到了一些情况,如我所希望这并不工作(我相信,由于多线程)。 如果您找不到图表的节点来附加printop
操作,这也可能会很棘手——例如,如果您试图在函数的最后一行打印某些内容。
但这似乎对我来说大部分时间都有效。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.