How to construct a Tensor with values at given indices inside graph mode?

I have a 1D tensor of length 128 with logits in it. For a custom loss, I'm trying to replace the 3 highest values with 1.0 and replace the rest with 0.0. This is inside a @tf.function, so I can't convert it to numpy and do the manipulation there.

I've come up with:

top_3 = tf.math.top_k(code, k=3)
indices = top_3.indices    
updates = tf.ones_like(indices)
new_code = tf.scatter_nd(indices, updates, tf.constant([128]))

But it gives me the error:

ValueError: Dimensions [3,1) of input[shape=[?]] = [] must match dimensions [0,1) of updates[shape=[3]] = [3]: Shapes must be equal rank, but are 0 and 1 for '{{node ScatterNd}} = ScatterNd[T=DT_INT32, Tindices=DT_INT32](TopKV2:1, ones_like_1, Const_3)' with input shapes: [3], [3], [1].

which I don't understand, because indices has length 3, and so does updates. What's the problem?

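The problem is the shape of indices. tf.scatter_nd (and tf.tensor_scatter_nd_update) treat the last dimension of indices as the coordinate into the output. A 1-D indices of shape [3] is therefore read as one coordinate into a rank-3 tensor, not as three positions in a rank-1 tensor. Reshaping it to [3, 1] (here with indices[..., None]) fixes that. Because the result below is built on top of tf.zeros_like(code), the updates must also be float32 to match its dtype. Try: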

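import tensorflow as tf

# Example 1-D logits of length 128.
code = tf.random.normal((128,))
top_3 = tf.math.top_k(code, k=3)
indices = top_3.indices  # shape [3], dtype int32

# Float ones so the updates match the dtype of code.
updates = tf.ones_like(indices, dtype=tf.float32)
# Start from all zeros, then write 1.0 at the top-3 positions.
new_code = tf.zeros_like(code)
# indices[..., None] has shape [3, 1]: three indices into a rank-1 tensor.
new_code = tf.tensor_scatter_nd_update(new_code, indices[..., None], updates)
print(new_code)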
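The question is specifically about graph mode, so the sketch below runs two approaches inside @tf.function. The first is the question's original tf.scatter_nd call with the index-shape and dtype fixes applied. The second builds the mask by summing tf.one_hot rows instead of scattering. The function names top3_mask_scatter and top3_mask_one_hot are hypothetical. The fixed [128] float32 input signature is an assumption based on the question's description of the tensor.

```python
import tensorflow as tf

# Assumed input: a 1-D float32 tensor of 128 logits, as described in the question.
SPEC = [tf.TensorSpec(shape=[128], dtype=tf.float32)]

@tf.function(input_signature=SPEC)
def top3_mask_scatter(code):
    # Hypothetical helper: the question's tf.scatter_nd approach, fixed.
    indices = tf.math.top_k(code, k=3).indices      # shape [3], int32
    updates = tf.ones([3], dtype=code.dtype)        # float ones, not int
    # indices[:, None] has shape [3, 1]; positions not listed become 0.0.
    return tf.scatter_nd(indices[:, None], updates, shape=tf.shape(code))

@tf.function(input_signature=SPEC)
def top3_mask_one_hot(code):
    # Hypothetical helper: an alternative that avoids scatter ops entirely.
    indices = tf.math.top_k(code, k=3).indices      # shape [3], int32
    # one_hot gives shape [3, 128]; summing over axis 0 yields a [128] mask.
    rows = tf.one_hot(indices, depth=tf.shape(code)[0], dtype=code.dtype)
    return tf.reduce_sum(rows, axis=0)

code = tf.random.normal((128,))
mask_a = top3_mask_scatter(code)
mask_b = top3_mask_one_hot(code)
print(tf.reduce_sum(mask_a).numpy(), tf.reduce_sum(mask_b).numpy())
```

The one_hot variant produces a clean 0/1 mask because top_k returns distinct indices, so each position receives at most one 1.0.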
