
How to pad a TF Tensor with zero columns, given a *list* of columns

In TensorFlow, I am trying to pad a tensor with zero columns, placing its existing columns at a given list of column positions.

How can I implement this in TensorFlow? I tried using assign and tf.scatter_nd, but ran into errors with both.

Here is a simple NumPy implementation:

import numpy as np

a_np = np.array([[1, 2],
                 [3, 4],
                 [5, 6]])
columns = [1, 5]
a_padded = np.zeros((3, 7))
a_padded[:, columns] = a_np  # column j of a_np goes to column columns[j]
print(a_padded)

## output ##

[[0. 1. 0. 0. 0. 2. 0.]
 [0. 3. 0. 0. 0. 4. 0.]
 [0. 5. 0. 0. 0. 6. 0.]]

I tried to do the same in TensorFlow:

a = tf.constant([[1, 2],
                 [3, 4], 
                 [5, 6]])
columns = [1, 5]
a_padded = tf.Variable(tf.zeros((3, 7)))
a_padded[:, columns].assign(a)

But this produces the following error:

TypeError: can only concatenate list (not "int") to list

I also tried using tf.scatter_nd:

a = tf.constant([[1, 2],
                 [3, 4], 
                 [5, 6]])
columns = [1, 5]
shape = tf.constant((3, 7))
tf.scatter_nd(columns, a, shape)

But this produces the following error:

InvalidArgumentError: Inner dimensions of output shape must match inner dimensions of updates shape. Output: [3,7] updates: [3,2] [Op:ScatterNd]
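One way to see why this call fails: according to the tf.scatter_nd documentation, `updates` must have shape `indices.shape[:-1] + shape[indices.shape[-1]:]`. The small pure-Python helper below (`required_updates_shape` is a hypothetical name, not a TensorFlow function) just evaluates that rule for a few index shapes. It is a sketch for checking shapes by hand, not part of the TensorFlow API.

```python
def required_updates_shape(indices_shape, output_shape):
    """Shape tf.scatter_nd expects for `updates`:
    indices.shape[:-1] + output_shape[indices.shape[-1]:]
    (hypothetical helper, only evaluates the documented shape rule)."""
    index_depth = indices_shape[-1]
    if index_depth > len(output_shape):
        raise ValueError("index depth exceeds the rank of the output")
    return tuple(indices_shape[:-1]) + tuple(output_shape[index_depth:])


# columns = [1, 5] has shape (2,), so it is read as ONE index of depth 2,
# i.e. the single element (1, 5) of the (3, 7) output: updates must be a scalar.
print(required_updates_shape((2,), (3, 7)))       # ()

# Depth-1 indices of shape (2, 1) into a (7, 3) output select whole rows.
print(required_updates_shape((2, 1), (7, 3)))     # (2, 3)

# (row, column) pairs of shape (3, 2, 2) into a (3, 7) output.
print(required_updates_shape((3, 2, 2), (3, 7)))  # (3, 2), the shape of a
```

So with `updates` of shape (3, 2), the indices must either address individual (row, column) pairs or the problem must be transposed so that whole rows are scattered.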

Here is a solution using the TensorFlow 1.x graph-mode API (tf.Session, tf.scatter_nd_update):

import tensorflow as tf  # TensorFlow 1.x

tf.reset_default_graph()
a = tf.constant([[1, 2], [3, 4], [5, 6]], dtype=tf.int32)
columns = tf.constant([1, 5], dtype=tf.int32)
a_padded = tf.Variable(tf.zeros((3, 7), dtype=tf.int32))
# indices[i, j] = [i, columns[j]], the target (row, column) of a[i, j]; shape (3, 2, 2)
indices = tf.stack(tf.meshgrid(tf.range(tf.shape(a_padded)[0]), columns, indexing='ij'), axis=-1)
update_cols = tf.scatter_nd_update(a_padded, indices, a)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run(update_cols))
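In TensorFlow 2.x with eager execution, tf.Session and tf.scatter_nd_update are only available through tf.compat.v1. A minimal sketch of the same meshgrid approach, assuming TF 2.x, swaps the variable update for tf.tensor_scatter_nd_update, which returns a new tensor instead of mutating a tf.Variable. The output width of 7 is hard-coded as in the question.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

a = tf.constant([[1, 2], [3, 4], [5, 6]], dtype=tf.int32)
columns = tf.constant([1, 5], dtype=tf.int32)
num_rows = a.shape[0]  # static shape, a Python int in eager mode
num_cols = 7           # width of the padded output, as in the question

# indices[i, j] = [i, columns[j]], shape (3, 2, 2), matching a's shape (3, 2)
rows = tf.range(num_rows, dtype=tf.int32)
indices = tf.stack(tf.meshgrid(rows, columns, indexing='ij'), axis=-1)

# Start from zeros and write a[i, j] at (i, columns[j])
zeros = tf.zeros((num_rows, num_cols), dtype=a.dtype)
a_padded = tf.tensor_scatter_nd_update(zeros, indices, a)
print(a_padded.numpy())
```

Because the indices carry both the row and the column, no transposing is needed here.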

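(OP here) I managed to find a solution using tf.scatter_nd. The trick was to align the dimensions of a, the column indices and the output shape: with depth-1 indices, tf.scatter_nd writes whole slices along the first axis (rows), so I transpose a, scatter its rows into a (7, 3) tensor, and transpose the result back to (3, 7).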

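import numpy as np
import tensorflow as tf

a_np = np.array([[1, 2],
                 [3, 4],
                 [5, 6]])

# Work in the transposed layout, where the target columns become rows
a = tf.constant(a_np.T)                                      # shape (2, 3)
columns = tf.constant(np.array([[1, 5]]).T.astype('int32'))  # shape (2, 1): one depth-1 index per row
shape = tf.constant((7, 3))                                  # transposed output shape (int32, same dtype as indices)
a_padded = tf.transpose(tf.scatter_nd(columns, a, shape))    # transpose back to shape (3, 7)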
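To see what the transposes achieve, the sketch below emulates that tf.scatter_nd call in plain NumPy, assuming the column indices are distinct (tf.scatter_nd sums updates for duplicate indices, while NumPy assignment keeps only the last one). With depth-1 indices every index picks a whole row of the (7, 3) output, so row columns[j] receives column j of a_np, and transposing back gives the same matrix as the question's `a_padded[:, columns] = a_np`.

```python
import numpy as np

a_np = np.array([[1, 2],
                 [3, 4],
                 [5, 6]])
columns = [1, 5]

# Emulate tf.scatter_nd(indices, updates, shape) for index depth 1:
# each index selects one whole row of the (7, 3) output.
indices = np.array([[1, 5]]).T            # shape (2, 1)
out_t = np.zeros((7, 3), dtype=a_np.dtype)
out_t[indices[:, 0]] = a_np.T             # row columns[j] <- column j of a_np
a_padded = out_t.T                        # back to shape (3, 7)

# Reference result from the plain NumPy approach in the question
reference = np.zeros((3, 7), dtype=a_np.dtype)
reference[:, columns] = a_np

print(a_padded)
print(np.array_equal(a_padded, reference))  # True
```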
