繁体   English   中英

在递归循环期间分配给 TensorFlow 变量

[英]Assigning to a TensorFlow variable during a recursive loop

在 Tensorflow 1.9 中,我想创建一个网络,然后递归地将网络的输出(预测)反馈到网络的输入中。 在这个循环中,我想将网络所做的预测存储在一个列表中。

这是我的尝试:

    # Define the number of steps over which to loop the network
    num_steps = 5

    # Define the network weights
    weights_1 = np.random.uniform(0, 1, [1, 10]).astype(np.float32)
    weights_2 = np.random.uniform(0, 1, [10, 1]).astype(np.float32)

    # Create a variable to store the predictions, one for each loop
    predictions = tf.Variable(np.zeros([num_steps, 1]), dtype=np.float32)

    # Define the initial prediction to feed into the loop
    initial_prediction = np.array([[0.1]], dtype=np.float32)
    x = initial_prediction

    # Loop through the predictions
    for step_num in range(num_steps):
        x = tf.matmul(x, weights_1)
        x = tf.matmul(x, weights_2)
        predictions[step_num-1].assign(x)

    # Define the final prediction
    final_prediction = x

    # Start a session
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    # Make the predictions
    last_pred, all_preds = sess.run([final_prediction, predictions])
    print(last_pred)
    print(all_preds)

这打印出来:

[[48.8769]]

[[0.]
 [0.]
 [0.]
 [0.]
 [0.]]

因此,虽然final_prediction的值看起来是正确的,但predictions的值并不是我所期望的。 似乎predictions从来没有真正分配到,尽管路线predictions[step_num-1].assign(x)

请有人向我解释为什么这不起作用,我应该做什么? 谢谢!

发生这种情况是因为assign和其他任何操作一样只是一个 TF 操作,因此仅在需要时才执行。 由于final_prediction的路径上没有任何东西依赖于赋值操作,而predictions只是一个变量,赋值永远不会被执行。

我认为最直接的解决方案是更换线路

predictions[step_num-1].assign(x)

经过

x = predictions[step_num-1].assign(x)

这是有效的,因为assign还返回它正在分配的值。 现在,要计算final_prediction TF 实际上需要“通过” assign操作,因此应该执行分配。

另一种选择是使用tf.control_dependencies ,这是一种在计算其他操作时“强制”TF 计算特定操作的方法。 但是,在这种情况下,它可能有点棘手,因为我们要强制执行的操作 ( assign ) 取决于在循环中计算的值,而且我不确定 TF 在这种情况下执行操作的顺序。 以下应该工作:

for step_num in range(num_steps):
    x = tf.matmul(x, weights_1)
    x = tf.matmul(x, weights_2)
    with tf.control_dependencies([predictions[step_num-1].assign(x)]):
        x = tf.identity(x)

我们使用tf.identity作为 noop 只是为了有一些东西可以用control_dependencies包装。 我认为这是两者之间更灵活的选择。 但是,它附带了文档中讨论的一些注意事项。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM