
How to update a specific set of indices of a multi-dimensional tensor in TensorFlow

I have a multi-dimensional tensor of shape [1,32,32,155], and I want to update the slice [:,:,:,0:27].

In PyTorch, one would simply use index assignment, i.e. feat[:,:,:,0:27] = upFeat, where upFeat has shape [1,32,32,27]. Index assignment is currently not supported in TensorFlow, so my first attempt was the following:

     feat_ch = tf.unstack(feat, axis=3)
     feat_ch[0:self.ncIn] = tf.unstack(upFeat, axis=3 )
     feat = tf.stack(feat_ch, axis=3)

Here feat is the [1,32,32,155] tensor and upFeat is the [1,32,32,27] tensor. The idea was to unstack feat along the channel dimension, so that feat_ch is a list of 155 tensors of shape [1,32,32], do the same with upFeat, and then replace the first 27 entries of the feat_ch list with the 27 entries from upFeat. Finally, stacking the list gives a [1,32,32,155] tensor again, this time with the first 27 channels updated.

However, I am not sure whether it does what I want, so I started looking into alternative ways to do the update.
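One way to check the unstack/stack approach is to compare its result against a reference built with NumPy slice assignment. This is only a minimal sketch, assuming TF 2.x eager execution. The random tensors and the names up_feat and nc_in are stand-ins for the question's upFeat and self.ncIn.

```python
import numpy as np
import tensorflow as tf

# Stand-in tensors with the shapes from the question (values are arbitrary).
feat = tf.random.uniform([1, 32, 32, 155])
up_feat = tf.random.uniform([1, 32, 32, 27])
nc_in = 27  # assumed value of self.ncIn

# The unstack / replace / stack approach.
feat_ch = tf.unstack(feat, axis=3)              # list of 155 tensors, each [1, 32, 32]
feat_ch[0:nc_in] = tf.unstack(up_feat, axis=3)  # swap in the first 27 channels
stacked = tf.stack(feat_ch, axis=3)             # back to [1, 32, 32, 155]

# Reference result built with NumPy slice assignment on a writable copy.
expected = feat.numpy().copy()
expected[:, :, :, 0:nc_in] = up_feat.numpy()

matches = np.array_equal(stacked.numpy(), expected)
print(stacked.shape, matches)
```

If matches is True, the first 27 channels were replaced and the remaining channels were left untouched.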

TensorFlow has a method, tf.tensor_scatter_nd_update, which seems to be exactly what I want. However, I find it hard to wrap my head around. What I have tried so far is:

    i1, i2, i3, i4 = tf.meshgrid(tf.range(1), tf.range(32), tf.range(32),
                                 tf.range(27), indexing="ij")  # shape [1,32,32,27]
    feat = tf.tensor_scatter_nd_update(feat, i1, upFeat)

The idea was to create a mesh grid of the same shape as upFeat, such that each element corresponds to an index of feat that I wish to update. This does not work, however, and throws the following error:

The inner -23 dimensions of output.shape=[1,32,32,155] must match the inner 1 dimensions of updates.shape=[1,32,32,27]: Shapes must be equal rank, but are 0 and 1

Am I misunderstanding it? Why does it not work? How would one update an N-D tensor?

Thanks

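Use slice and concat, putting the new values first so that they become channels 0 to 26:

    import tensorflow as tf

    feat = tf.random.uniform([1, 32, 32, 155])
    updates = tf.zeros([1, 32, 32, 27])
    # New first 27 channels, followed by the untouched channels 27..154.
    result = tf.concat((updates, feat[:, :, :, 27:]), -1)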
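For the tf.tensor_scatter_nd_update route from the question, the indices argument needs one full coordinate per updated element, i.e. shape [1,32,32,27,4], rather than a single meshgrid output. Passing i1 alone makes its last dimension (27) be read as the index depth, which is probably where the "-23" in the error comes from (4 - 27). The following is a minimal sketch, assuming TF 2.x eager execution. The tensors are stand-ins, and up_feat is a hypothetical name for the question's upFeat.

```python
import tensorflow as tf

# Stand-in tensors with the shapes from the question.
feat = tf.random.uniform([1, 32, 32, 155])
up_feat = tf.zeros([1, 32, 32, 27])

# One coordinate grid per axis, each of shape [1, 32, 32, 27].
grids = tf.meshgrid(tf.range(1), tf.range(32), tf.range(32), tf.range(27),
                    indexing="ij")

# Stack the grids on a new last axis so that every entry is a full
# (batch, row, col, channel) index into feat.
indices = tf.stack(grids, axis=-1)  # shape [1, 32, 32, 27, 4]

# updates must have shape indices.shape[:-1], which is [1, 32, 32, 27] here.
scattered = tf.tensor_scatter_nd_update(feat, indices, up_feat)

# Cross-check against slice and concat.
concatenated = tf.concat([up_feat, feat[:, :, :, 27:]], axis=-1)
same = bool(tf.reduce_all(tf.equal(scattered, concatenated)))
print(indices.shape, scattered.shape, same)
```

Both routes should give the same tensor. For a contiguous channel range like this one, slice and concat is the shorter option, while the scatter version also covers index sets that are not contiguous.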
