簡體   English   中英

如何在TensorFlow中遍歷張量的元素?

[英]How to loop through the elements of a tensor in TensorFlow?

我想使用TensorFlow的tf.image.rot90()創建一個函數batch_rot90(batch_of_images) ,后者一次只獲取一個圖像,前者一次應獲取n個圖像(shape = [n,x,y ,F])。

因此,自然而然地,應該只遍歷批處理中的所有圖像,然后逐個旋轉它們。 在numpy中,它看起來像:

def batch_rot90(batch):
  for i in range(batch.shape[0]):
    batch_of_images[i] = rot90(batch[i,:,:,:])
  return batch

如何在TensorFlow中完成? 使用tf.while_loop我得到了他的幫助:

batch = tf.placeholder(tf.float32, shape=[2, 256, 256, 4])    
def batch_rot90(batch, k, name=''):
      i = tf.constant(0)
      def cond(batch, i):
        return tf.less(i, tf.shape(batch)[0])
      def body(im, i):
        batch[i] = tf.image.rot90(batch[i], k)
        i = tf.add(i, 1)
        return batch, i  
      r = tf.while_loop(cond, body, [batch, i])
      return r

但是,不允許對im[i]進行賦值,並且我對r返回的結果感到困惑。

我意識到使用tf.batch_to_space()在這種特殊情況下可能有解決方法,但我相信也可以通過某種循環來實現。

tf中有一個map函數,它將起作用:

def batch_rot90(batch, k, name=''):
  fun = lambda x: tf.images.rot90(x, k = 1)
  return = tf.map_fn(fun, batch)

更新的答案:

x = tf.placeholder(tf.float32, shape=[2, 3])

def cond(batch, output, i):
    return tf.less(i, tf.shape(batch)[0])

def body(batch, output, i):
    output = output.write(i, tf.add(batch[i], 10))
    return batch, output, i + 1

# TensorArray is a data structure that support dynamic writing
output_ta = tf.TensorArray(dtype=tf.float32,
               size=0,
               dynamic_size=True,
               element_shape=(x.get_shape()[1],))
_, output_op, _  = tf.while_loop(cond, body, [x, output_ta, 0])
output_op = output_op.stack()

with tf.Session() as sess:
    print(sess.run(output_op, feed_dict={x: [[1, 2, 3], [0, 0, 0]]}))

我認為您應該考慮使用tf.scatter_update來批量更新一個圖像,而不是使用batch[i] = ... 請參閱此鏈接以獲取詳細信息。 對於您的情況,建議將正文的第一行更改為:

 tf.scatter_update(batch, i, tf.image.rot90(batch[i], k)) 

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM