[英]how to do slice assignment while the slice itself is a tensor in tensorflow
我想在張量流中進行切片分配。 我知道我可以使用:
my_var = my_var[4:8].assign(tf.zeros(4))
基於此鏈接 。
如您在my_var[4:8]
所見,我們在此處有特定的索引4、8,用於切片和賦值。
我的情況不同,我想基於張量進行切片,然后進行分配。
out = tf.Variable(tf.zeros(shape=[8,4], dtype=tf.float32))
rows_tf = tf.constant (
[[1, 2, 5],
[1, 2, 5],
[1, 2, 5],
[1, 4, 6],
[1, 4, 6],
[2, 3, 6],
[2, 3, 6],
[2, 4, 7]])
columns_tf = tf.constant(
[[1],
[2],
[3],
[2],
[3],
[2],
[3],
[2]])
changed_tensor = [[8.3356, 0., 8.457685 ],
[0., 6.103182, 8.602337 ],
[8.8974, 7.330564, 0. ],
[0., 3.8914037, 5.826657 ],
[8.8974, 0., 8.283971 ],
[6.103182, 3.0614321, 5.826657 ],
[7.330564, 0., 8.283971 ],
[6.103182, 3.8914037, 0. ]]
另外,這是sparse_indices
張量,它是rows_tf
和columns_tf
使整個索引需要更新(以防萬一:)
sparse_indices = tf.constant(
[[1 1]
[2 1]
[5 1]
[1 2]
[2 2]
[5 2]
[1 3]
[2 3]
[5 3]
[1 2]
[4 2]
[6 2]
[1 3]
[4 3]
[6 3]
[2 2]
[3 2]
[6 2]
[2 3]
[3 3]
[6 3]
[2 2]
[4 2]
[4 2]])
我想做的是做這個簡單的任務:
out[rows_tf, columns_tf] = changed_tensor
為此,我正在這樣做:
out[rows_tf:column_tf].assign(changed_tensor)
但是,我收到此錯誤:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Expected begin, end, and strides to be 1D equal size tensors, but got shapes [1,8,3], [1,8,1], and [1] instead. [Op:StridedSlice] name: strided_slice/
這是預期的輸出:
[[0. 0. 0. 0. ]
[0. 8.3356 0. 8.8974 ]
[0. 0. 6.103182 7.330564 ]
[0. 0. 3.0614321 0. ]
[0. 0. 3.8914037 0. ]
[0. 8.457685 8.602337 0. ]
[0. 0. 5.826657 8.283971 ]
[0. 0. 0. 0. ]]
知道如何完成這項任務嗎?
先感謝您:)
這個示例(從tf文檔tf.scatter_nd_update
擴展到此處 )應該會有所幫助。
你要先把你row_indices和column_indices結合成2D指數列表,這是indices
來論證tf.scatter_nd_update
。 然后,您輸入了期望值列表,即updates
。
ref = tf.Variable(tf.zeros(shape=[8,4], dtype=tf.float32))
indices = tf.constant([[0, 2], [2, 2]])
updates = tf.constant([1.0, 2.0])
update = tf.scatter_nd_update(ref, indices, updates)
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
print sess.run(update)
Result:
[[ 0. 0. 1. 0.]
[ 0. 0. 0. 0.]
[ 0. 0. 2. 0.]
[ 0. 0. 0. 0.]
[ 0. 0. 0. 0.]
[ 0. 0. 0. 0.]
[ 0. 0. 0. 0.]
[ 0. 0. 0. 0.]]
專門針對您的數據,
ref = tf.Variable(tf.zeros(shape=[8,4], dtype=tf.float32))
changed_tensor = [[8.3356, 0., 8.457685 ],
[0., 6.103182, 8.602337 ],
[8.8974, 7.330564, 0. ],
[0., 3.8914037, 5.826657 ],
[8.8974, 0., 8.283971 ],
[6.103182, 3.0614321, 5.826657 ],
[7.330564, 0., 8.283971 ],
[6.103182, 3.8914037, 0. ]]
updates = tf.reshape(changed_tensor, shape=[-1])
sparse_indices = tf.constant(
[[1, 1],
[2, 1],
[5, 1],
[1, 2],
[2, 2],
[5, 2],
[1, 3],
[2, 3],
[5, 3],
[1, 2],
[4, 2],
[6, 2],
[1, 3],
[4, 3],
[6, 3],
[2, 2],
[3, 2],
[6, 2],
[2, 3],
[3, 3],
[6, 3],
[2, 2],
[4, 2],
[4, 2]])
update = tf.scatter_nd_update(ref, sparse_indices, updates)
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
print sess.run(update)
Result:
[[ 0. 0. 0. 0. ]
[ 0. 8.3355999 0. 8.8973999 ]
[ 0. 0. 6.10318184 7.33056402]
[ 0. 0. 3.06143212 0. ]
[ 0. 0. 0. 0. ]
[ 0. 8.45768547 8.60233688 0. ]
[ 0. 0. 5.82665682 8.28397083]
[ 0. 0. 0. 0. ]]
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.