[英]using tf.where() to select 3d tensor by 2d conditions & replacing elements in a 2d indices with keys and values
标题中有2个问题。 我对这两个问题感到困惑,因为tensorflow是一种静态编程语言(我真的很想回到pytorch或chainer)。
我举两个例子。 请以张量流代码或提供相关功能链接回答我。
1)tf.where()
data0 = tf.zeros([2, 3, 4], dtype = tf.float32)
data1 = tf.ones([2, 3, 4], dtype = tf.float32)
cond = tf.constant([[0, 1, 1], [1, 0, 0]])
# cond.shape == (2, 3)
# tf.where() works for 1d condition with 2d data,
# but not for 2d indices with 3d tensor
# currently, what I am doing is:
# cond = tf.stack([cond] * 4, 2)
data = tf.where(cond > 0, data1, data0)
# data should be [[0., 1., 1.], [1., 0., 0.]]
(我不知道如何将cond广播到3d张量)
2)改变二维张量中的元素
# all dtype == tf.int64
t2d = tf.Variable([[0, 1, 2], [3, 4, 5]])
k, v = tf.constant([[0, 2], [1, 0]]), tf.constant([-2, -3])
# TODO: change values at positions k to v
# I cannot do [t2d.copy()[i] = j for i, j in k, v]
t3d == [[[0, 1, -2], [3, 4, 5]],
[[0, 1, 2], [-3, 4, 5]]]
提前非常感谢您。 XD
这是两个截然不同的问题,应该以同样的方式发布,但是无论如何。
1)
是的,如果它们不同,则需要将所有输入手动广播到[ tf.where
]( https://www.tensorflow.org/api_docs/python/tf/where] 。为了有价值,有一个(旧)的公开问题 ,但到目前为止尚未实现隐式广播。您可以tf.stack
您的建议使用tf.stack
,尽管tf.tile
可能会更明显(并且可以节省内存,尽管我不确定如何确实实现):
cond = tf.tile(tf.expand_dims(cond, -1), (1, 1, 4))
或简单地使用tf.broadcast_to
:
cond = tf.broadcast_to(tf.expand_dims(cond, -1), tf.shape(data1))
2)
这是一种方法:
import tensorflow as tf
t2d = tf.constant([[0, 1, 2], [3, 4, 5]])
k, v = tf.constant([[0, 2], [1, 0]]), tf.constant([-2, -3])
# Tile t2d
n = tf.shape(k)[0]
t2d_tile = tf.tile(tf.expand_dims(t2d, 0), (n, 1, 1))
# Add aditional coordinate to index
idx = tf.concat([tf.expand_dims(tf.range(n), 1), k], axis=1)
# Make updates tensor
s = tf.shape(t2d_tile)
t2d_upd = tf.scatter_nd(idx, v, s)
# Make updates mask
upd_mask = tf.scatter_nd(idx, tf.ones_like(v, dtype=tf.bool), s)
# Make final tensor
t3d = tf.where(upd_mask, t2d_upd, t2d_tile)
# Test
with tf.Session() as sess:
print(sess.run(t3d))
输出:
[[[ 0 1 -2]
[ 3 4 5]]
[[ 0 1 2]
[-3 4 5]]]
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.