簡體   English   中英

當切片本身是張量流中的張量時如何進行切片分配

[英]how to do slice assignment while the slice itself is a tensor in tensorflow

我想在張量流中進行切片分配。 我知道我可以使用:

my_var = my_var[4:8].assign(tf.zeros(4))

基於此鏈接

如您在my_var[4:8]所見,我們在此處有特定的索引4、8,用於切片和賦值。

我的情況不同,我想基於張量進行切片,然后進行分配。

out = tf.Variable(tf.zeros(shape=[8,4], dtype=tf.float32))

 rows_tf = tf.constant (
[[1, 2, 5],
 [1, 2, 5],
 [1, 2, 5],
 [1, 4, 6],
 [1, 4, 6],
 [2, 3, 6],
 [2, 3, 6],
 [2, 4, 7]])

columns_tf = tf.constant(
[[1],
 [2],
 [3],
 [2],
 [3],
 [2],
 [3],
 [2]])

changed_tensor = [[8.3356,    0.,        8.457685 ],
                  [0.,        6.103182,  8.602337 ],
                  [8.8974,    7.330564,  0.       ],
                  [0.,        3.8914037, 5.826657 ],
                  [8.8974,    0.,        8.283971 ],
                  [6.103182,  3.0614321, 5.826657 ],
                  [7.330564,  0.,        8.283971 ],
                  [6.103182,  3.8914037, 0.       ]]

另外,這是sparse_indices張量,它是rows_tfcolumns_tf使整個索引需要更新(以防萬一:)

sparse_indices = tf.constant(
[[1 1]
 [2 1]
 [5 1]
 [1 2]
 [2 2]
 [5 2]
 [1 3]
 [2 3]
 [5 3]
 [1 2]
 [4 2]
 [6 2]
 [1 3]
 [4 3]
 [6 3]
 [2 2]
 [3 2]
 [6 2]
 [2 3]
 [3 3]
 [6 3]
 [2 2]
 [4 2]
 [4 2]])

我想做的是做這個簡單的任務:

out[rows_tf, columns_tf] = changed_tensor

為此,我正在這樣做:

out[rows_tf:column_tf].assign(changed_tensor)

但是,我收到此錯誤:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Expected begin, end, and strides to be 1D equal size tensors, but got shapes [1,8,3], [1,8,1], and [1] instead. [Op:StridedSlice] name: strided_slice/

這是預期的輸出:

[[0.        0.        0.        0.       ]
 [0.        8.3356    0.        8.8974   ]
 [0.        0.        6.103182  7.330564 ]
 [0.        0.        3.0614321 0.       ]
 [0.        0.        3.8914037 0.       ]
 [0.        8.457685  8.602337  0.       ]
 [0.        0.        5.826657  8.283971 ]
 [0.        0.        0.        0.       ]]

知道如何完成這項任務嗎?

先感謝您:)

這個示例(從tf文檔tf.scatter_nd_update擴展到此處 )應該會有所幫助。

你要先把你row_indices和column_indices結合成2D指數列表,這是indices來論證tf.scatter_nd_update 然后,您輸入了期望值列表,即updates

ref = tf.Variable(tf.zeros(shape=[8,4], dtype=tf.float32))
indices = tf.constant([[0, 2], [2, 2]])
updates = tf.constant([1.0, 2.0])

update = tf.scatter_nd_update(ref, indices, updates)
with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())
  print sess.run(update)
Result:

[[ 0.  0.  1.  0.]
 [ 0.  0.  0.  0.]
 [ 0.  0.  2.  0.]
 [ 0.  0.  0.  0.]
 [ 0.  0.  0.  0.]
 [ 0.  0.  0.  0.]
 [ 0.  0.  0.  0.]
 [ 0.  0.  0.  0.]]

專門針對您的數據,

ref = tf.Variable(tf.zeros(shape=[8,4], dtype=tf.float32))
changed_tensor = [[8.3356,    0.,        8.457685 ],
                  [0.,        6.103182,  8.602337 ],
                  [8.8974,    7.330564,  0.       ],
                  [0.,        3.8914037, 5.826657 ],
                  [8.8974,    0.,        8.283971 ],
                  [6.103182,  3.0614321, 5.826657 ],
                  [7.330564,  0.,        8.283971 ],
                  [6.103182,  3.8914037, 0.       ]]
updates = tf.reshape(changed_tensor, shape=[-1])
sparse_indices = tf.constant(
[[1, 1],
 [2, 1],
 [5, 1],
 [1, 2],
 [2, 2],
 [5, 2],
 [1, 3],
 [2, 3],
 [5, 3],
 [1, 2],
 [4, 2],
 [6, 2],
 [1, 3],
 [4, 3],
 [6, 3],
 [2, 2],
 [3, 2],
 [6, 2],
 [2, 3],
 [3, 3],
 [6, 3],
 [2, 2],
 [4, 2],
 [4, 2]])

update = tf.scatter_nd_update(ref, sparse_indices, updates)
with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())
  print sess.run(update)

Result:

[[ 0.          0.          0.          0.        ]
 [ 0.          8.3355999   0.          8.8973999 ]
 [ 0.          0.          6.10318184  7.33056402]
 [ 0.          0.          3.06143212  0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          8.45768547  8.60233688  0.        ]
 [ 0.          0.          5.82665682  8.28397083]
 [ 0.          0.          0.          0.        ]]

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM