I have the following TensorFlow tensors.
import numpy as np
import tensorflow as tf
tensor1 = tf.constant(np.random.randint(0,255, (2,512,512,1)), dtype='int32') # values in [0,254]; randint's upper bound is exclusive
tensor2 = tf.constant(np.random.randint(0,255, (2,512,512,1)), dtype='int32') # values in [0,254]; randint's upper bound is exclusive
tensor3 = tf.keras.backend.flatten(tensor1) # 1-D, length 2*512*512 = 524288
tensor4 = tf.keras.backend.flatten(tensor2) # 1-D, length 2*512*512 = 524288
tensor5 = tf.constant(np.random.randint(0,255, (255,255)), dtype='int32') # 255x255 lookup table, values in [0,254]
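A small NumPy check of the value ranges, since they decide whether every value can serve as an index into tensor5. It assumes only NumPy's documented behaviour that `randint`'s upper bound is exclusive; the names `values`, `table` and `out_of_bounds` are illustrative.

```python
import numpy as np

values = np.random.randint(0, 255, (2, 512, 512, 1))
# randint excludes its upper bound, so every value lies in [0, 254].
print(values.min() >= 0, values.max() <= 254)

table = np.zeros((255, 255), dtype=np.int32)  # same shape as tensor5
largest = table[254, 254]  # 254 is the largest valid index on each axis

# Values that truly reached 255 (e.g. randint(0, 256)) would need a
# (256, 256) table; indexing a (255, 255) table at 255 fails:
try:
    table[255, 255]
    out_of_bounds = False
except IndexError:
    out_of_bounds = True
print(out_of_bounds)
```

So with the current `randint(0, 255)` calls, every value in tensor3 and tensor4 is a valid row or column index of tensor5.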
I wish to use the values stored in tensor3 and tensor4 to form index tuples and look up the element at each tuple's position in tensor5. For example, if the 0th elements are tensor3[0] = 5 and tensor4[0] = 99, the tuple is (5, 99), and I want the value of tensor5[5, 99]. I want to do this for every element of tensor3 and tensor4 in a batched way, i.e. without a Python loop over range(len(tensor3)). I tried the following:
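To pin down the intended semantics, here is a NumPy sketch on smaller stand-in arrays (the names `table`, `rows` and `cols` are illustrative, not from the question). It compares the per-element loop to NumPy's element-wise paired ("fancy") indexing, which is the batched behaviour being asked for.

```python
import numpy as np

rng = np.random.RandomState(0)
table = rng.randint(0, 255, (255, 255))  # plays the role of tensor5
rows = rng.randint(0, 255, 128)          # plays the role of tensor3
cols = rng.randint(0, 255, 128)          # plays the role of tensor4

# Loop version (what the question wants to avoid): one lookup per index i.
looped = np.array([table[rows[i], cols[i]] for i in range(len(rows))])

# Batched version: NumPy pairs rows[i] with cols[i] element-wise.
batched = table[rows, cols]

print(batched.shape)                    # (128,)
print(np.array_equal(looped, batched))  # True
```

Note that the result has one entry per index pair, so its shape is (len(rows),).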
tensor6 = tensor5[tensor3[0],tensor4[0]]
But tensor6 has shape (255,255), whereas I was hoping to get a tensor of shape (len(tensor3),len(tensor3)). I wanted to evaluate tensor5 at all possible locations in len(tensor3), that is at (0,0), ..., (1000,1000), ..., (2000,2000), ...

I am using TensorFlow version 1.12.0. How can I achieve this?
I managed to get something working in TensorFlow 1.12; please let me know whether this is the expected approach:
import tensorflow as tf
print(tf.__version__)
import numpy as np
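tensor1 = tf.constant(np.random.randint(0,255, (2,512,512,1)), dtype='int32') # values in [0,254]; randint's upper bound is exclusive
tensor2 = tf.constant(np.random.randint(0,255, (2,512,512,1)), dtype='int32') # values in [0,254]; randint's upper bound is exclusive
tensor3 = tf.keras.backend.flatten(tensor1)
tensor4 = tf.keras.backend.flatten(tensor2)
tensor5 = tf.constant(np.random.randint(0,255, (255,255)), dtype='int32') # 255x255 lookup table, values in [0,254]
# map_fn slices both tensors along axis 0, so fn receives (tensor3[i], tensor4[i])
elems = (tensor3, tensor4)
# Look up tensor5[tensor3[i], tensor4[i]] for each i. dtype is required because the
# output (a single int32 tensor) has a different structure from elems (a tuple of two)
a = tf.map_fn(lambda x: tensor5[x[0], x[1]], elems, dtype=tf.int32)
print(tf.Session().run(a)) # 1-D result with one value per index pair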
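As an alternative to the per-element map_fn (a different technique, named here plainly), `tf.gather_nd` can do all the lookups in a single op by stacking the two index vectors into (row, col) pairs. This is a sketch, not a benchmarked recommendation: the helper names `batched_lookup` and `evaluate` are hypothetical, `tf.reshape(x, [-1])` stands in for `tf.keras.backend.flatten`, and `evaluate` is an assumed convenience so the same code runs in TF 1.x graph mode (via `tf.Session`) or TF 2.x eager mode.

```python
import numpy as np
import tensorflow as tf

def batched_lookup(table, rows, cols):
    """Return table[rows[i], cols[i]] for every i without a per-element loop.

    table: 2-D tensor (e.g. tensor5); rows, cols: 1-D int32 tensors of equal
    length (e.g. tensor3 and tensor4).
    """
    # Pair the coordinates into shape (N, 2): row i is (rows[i], cols[i]).
    indices = tf.stack([rows, cols], axis=1)
    # gather_nd reads one element of `table` per index pair -> shape (N,).
    return tf.gather_nd(table, indices)

def evaluate(t):
    """Fetch a tensor's value in TF 2.x eager mode or TF 1.x graph mode."""
    if tf.executing_eagerly():
        return t.numpy()
    with tf.Session() as sess:  # TF 1.x (e.g. 1.12) graph mode
        return sess.run(t)

# Same data as in the question; tf.reshape(x, [-1]) flattens in the same
# row-major order as tf.keras.backend.flatten.
arr1 = np.random.randint(0, 255, (2, 512, 512, 1))
arr2 = np.random.randint(0, 255, (2, 512, 512, 1))
arr5 = np.random.randint(0, 255, (255, 255))
tensor3 = tf.reshape(tf.constant(arr1, dtype=tf.int32), [-1])
tensor4 = tf.reshape(tf.constant(arr2, dtype=tf.int32), [-1])
tensor5 = tf.constant(arr5, dtype=tf.int32)

result = evaluate(batched_lookup(tensor5, tensor3, tensor4))
print(result.shape)  # (524288,)
```

The values should match the map_fn output element for element; gather_nd simply avoids building a per-element loop in the graph.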
Based on the comment below, I'd like to add an explanation of the map_fn used in the code. Without eager execution (TensorFlow 1.12 runs in graph mode by default), a Python for loop cannot iterate over a tensor, so map_fn serves as (roughly) the in-graph equivalent of a for loop.
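A map_fn call takes the following parameters: operation_performed, input_arguments and optional_dtype. Under the hood, map_fn builds a loop (a tf.while_loop) over the first dimension of input_arguments, which must be a tensor or a tuple/list of tensors sharing the same first dimension, as (tensor3, tensor4) does here. For each slice it applies operation_performed and stacks the results. optional_dtype declares the output type; it is needed here because the function returns a single int32 tensor while the input is a tuple of two tensors. For further clarification, please refer to the docs.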
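The slicing-and-stacking behaviour described above can be modelled in plain NumPy. `toy_map_fn` below is a hypothetical, simplified illustration of the semantics only; it is not TensorFlow's implementation and ignores options such as parallel_iterations.

```python
import numpy as np

def toy_map_fn(fn, elems):
    """Hypothetical, simplified model of tf.map_fn's semantics.

    elems is an array or a tuple of arrays sharing their first dimension;
    fn is called once per index i along that dimension, and the results
    are stacked into a new array.
    """
    parts = elems if isinstance(elems, tuple) else (elems,)
    n = parts[0].shape[0]
    if any(p.shape[0] != n for p in parts):
        raise ValueError("all elems must share the same first dimension")
    outputs = []
    for i in range(n):
        # Slice every input along axis 0, as map_fn unpacks its elems.
        x = tuple(p[i] for p in parts) if isinstance(elems, tuple) else parts[0][i]
        outputs.append(fn(x))
    return np.stack(outputs)

table = np.arange(9).reshape(3, 3)  # stand-in for tensor5
rows = np.array([0, 1, 2, 2])       # stand-in for tensor3
cols = np.array([2, 0, 1, 2])       # stand-in for tensor4
out = toy_map_fn(lambda x: table[x[0], x[1]], (rows, cols))
print(out)  # [2 3 7 8]
```

Because the tuple input is sliced in lockstep, the lambda receives one (row, col) pair per step, exactly as in the map_fn call of the answer.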
The names I gave the arguments reflect my own interpretation and are not used in the official docs, where they are called fn, elems and dtype. :)