簡體   English   中英

在張量流中交換張量的元素

[英]Swap elements of tensor in tensorflow

我在嘗試交換長度可變的張量元素時遇到了令人驚訝的困難。 據我了解,切片賦值僅支持變量,因此在運行以下代碼時,出現錯誤ValueError: Sliced assignment is only supported for variables

def add_noise(tensor):
  length = tf.size(tensor)

  i = tf.random_uniform((), 0, length-2, dtype=tf.int32)
  aux = tensor[i]
  tensor = tensor[i].assign(tensor[i+1])
  tensor = tensor[i+1].assign(aux)

  return tensor

with tf.Session() as sess:
  tensor = tf.convert_to_tensor([0, 1, 2, 3, 4, 5, 6], dtype=tf.int32)
  print sess.run(add_noise(tensor))

如何交換張量中的元素?

您可以使用 TensorFlow分散函數scatter_nd來交換tensor元素。 您還可以在單​​個scatter操作中實現多個交換。

tensor = tf.convert_to_tensor([0, 1, 2, 3, 4, 5, 6], dtype=tf.int32)  # input
# let's swap 1st and 4th elements, and also 5th and 6th elements (in terms of 0-based indexing)
indices = tf.constant([[0], [4], [2], [3], [1], [5], [6]])  # indices mentioning the swapping pattern
shape = tf.shape(tensor)  # shape of the scattered_tensor, zeros will be injected if output shape is greater than input shape
scattered_tensor = tf.scatter_nd(indices, tensor, shape)

with tf.Session() as sess:
  print sess.run(scattered_tensor)
  # [0 4 2 3 1 6 5]

張量一旦被定義就是不可變的,另一方面,變量是可變的。 你需要的是一個 TensorFlow 變量。你應該改變這一行:

tensor = tf.convert_to_tensor(l, dtype=tf.int32)

到以下。

tensor = tf.Variable(l, dtype=tf.int32, name='my_variable')

如果您需要為矩陣的每一行交換不同的對,請嘗試此操作。

def batch_column_swap(base_tensor, from_col_index, to_col_index):
  batch_size = tf.shape(base_tensor)[0]
  batch_range = tf.range(batch_size)
  from_col_index = tf.stack([batch_range, from_col_index], axis=1)
  to_col_index = tf.stack([batch_range, to_col_index], axis=1)
  indices = tf.concat([to_col_index, from_col_index], axis=0)
  from_col = tf.gather_nd(base_tensor, from_col_index)
  to_col = tf.gather_nd(base_tensor, to_col_index)
  updates = tf.concat([from_col, to_col], axis=0)
  return tf.tensor_scatter_nd_update(base_tensor, indices, updates)

base_tensor = [[0, 1], [2, 3]]
from_col_idx = [0, 1]
to_col_idx = [1, 0]
with tf.Session() as sess:
  print(batch_column_swap(base_tensor, from_col_idx, to_col_idx))

# array([[1, 0],
#        [3, 2]], dtype=int32)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM