簡體   English   中英

tensorflow python中張量的深拷貝

[英]Deep copy of tensor in tensorflow python

在我的一些代碼中,我使用 tensorflow 創建了一個神經網絡,並且可以訪問表示該網絡輸出的張量。 我想復制這個張量,這樣即使我對神經網絡進行更多訓練,我也可以訪問張量的原始值。

按照其他答案和 tensorflow 文檔,我嘗試了 tf.identity() 函數,但它似乎沒有做我需要的。 其他一些鏈接建議使用 tf.tile(),但這也無濟於事。 我不想使用 sess.run(),評估張量並將其存儲在其他地方。

這是一個描述我需要做的事情的玩具示例:

import tensorflow as tf
import numpy as np

t1 = tf.placeholder(tf.float32, [None, 1])
t2 = tf.layers.dense(t1, 1, activation=tf.nn.relu)
expected_out = tf.placeholder(tf.float32, [None, 1])

loss = tf.reduce_mean(tf.square(expected_out - t2))
train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)

sess = tf.Session()

sess.run(tf.global_variables_initializer())

print(sess.run(t2, feed_dict={t1: np.array([1]).reshape(-1,1)}))
t3 = tf.identity(t2) # Need to make copy here
print(sess.run(t3, feed_dict={t1: np.array([1]).reshape(-1,1)}))

print("\nTraining \n")

for i in range(1000):
    sess.run(train_op, feed_dict={t1: np.array([1]).reshape(-1,1), expected_out: np.array([1]).reshape(-1,1)})

print(sess.run(t2, feed_dict={t1: np.array([1]).reshape(-1,1)}))
print(sess.run(t3, feed_dict={t1: np.array([1]).reshape(-1,1)}))

上面代碼的結果是t2t3具有相同的值。

[[1.5078927]]
[[1.5078927]]

Training

[[1.3262703]]
[[1.3262703]]

我想要的是讓t3保持其值不被復制。

[[1.5078927]]
[[1.5078927]]

Training

[[1.3262703]]
[[1.5078927]]

在此先感謝您的幫助。

您可以使用命名的tf.assign操作,然后通過Graph.get_operation_by_name僅運行該操作。 這不會獲取張量值,只是在圖上運行賦值操作。 請考慮以下示例:

import tensorflow as tf

a = tf.placeholder(tf.int32, shape=(2,))
w = tf.Variable([1, 2])  # Updated in the training loop.
b = tf.Variable([0, 0])  # Backup; stores intermediate result.
t = tf.assign(w, tf.math.multiply(a, w))  # Update during training.
tf.assign(b, w, name='backup')

init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    x = [2, 2]
    # Emulate training loop:
    for i in range(3):
        print('w = ', sess.run(t, feed_dict={a: x}))
    # Backup without retrieving the value (returns None).
    print('Backup now: ', end='')
    print(sess.run(tf.get_default_graph().get_operation_by_name('backup')))
    # Train a bit more:
    for i in range(3):
        print('w = ', sess.run(t, feed_dict={a: x}))
    # Check the backed-up value:
    print('Backup: ', sess.run(b))  # Is [8, 16].

所以對於你的例子,你可以這樣做:

t3 = tf.Variable([], validate_shape=False)
tf.assign(t3, t2, validate_shape=False, name='backup')

我認為也許 copy.deepcopy() 可以工作......例如:

import copy 
tensor_2 = copy.deepcopy(tensor_1)

關於 deepcopy 的 Python 文檔: https ://docs.python.org/3/library/copy.html

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM