简体   繁体   English

使用tf.where()通过2d条件选择3d张量并用键和值替换2d索引中的元素

[英]using tf.where() to select 3d tensor by 2d conditions & replacing elements in a 2d indices with keys and values

There are 2 questions in the title. 标题中有2个问题。 I am confused by both questions because tensorflow is such a static programming language (I really want to go back to either pytorch or chainer). 我对这两个问题感到困惑,因为tensorflow是一种静态编程语言(我真的很想回到pytorch或chainer)。

I give 2 examples. 我举两个例子。 please answer me in tensorflow codes or providing the relevant function links. 请以张量流代码或提供相关功能链接回答我。

1) tf.where() 1)tf.where()

data0 = tf.zeros([2, 3, 4], dtype = tf.float32)
data1 = tf.ones([2, 3, 4], dtype = tf.float32)
cond = tf.constant([[0, 1, 1], [1, 0, 0]])
# cond.shape == (2, 3)
# tf.where() works for 1d condition with 2d data, 
# but not for 2d indices with 3d tensor
# currently, what I am doing is:
#    cond = tf.stack([cond] * 4, 2)
data = tf.where(cond > 0, data1, data0)
# data should be [[0., 1., 1.], [1., 0., 0.]]

(I don't know how to broadcast cond to 3d tensor) (我不知道如何将cond广播到3d张量)

2) change element in 2d tensor 2)改变二维张量中的元素

# all dtype == tf.int64
t2d = tf.Variable([[0, 1, 2], [3, 4, 5]])
k, v = tf.constant([[0, 2], [1, 0]]), tf.constant([-2, -3])
# TODO: change values at positions k to v
# I cannot do [t2d.copy()[i] = j for i, j in k, v]
t3d == [[[0, 1, -2], [3, 4, 5]],
        [[0, 1, 2], [-3, 4, 5]]]

Thank you so much in advance. 提前非常感谢您。 XD XD

This are two quite different questions, and they should probably have been posted as such, but anyway. 这是两个截然不同的问题,应该以同样的方式发布,但是无论如何。

1) 1)

Yes, you need to manually broadcast all the inputs to [ tf.where ]( https://www.tensorflow.org/api_docs/python/tf/where] if they are different. For what is worth, there is an (old) open issue about it , but so far implicit broadcasting it has not been implemented. You can use tf.stack like you suggest, although tf.tile would probably be more obvious (and may save memory, although I'm not sure how it is implemented really): 是的,如果它们不同,则需要将所有输入手动广播到[ tf.where ]( https://www.tensorflow.org/api_docs/python/tf/where] 。为了有价值,有一个(旧)的公开问题 ,但到目前为止尚未实现隐式广播。您可以tf.stack您的建议使用tf.stack ,尽管tf.tile可能会更明显(并且可以节省内存,尽管我不确定如何确实实现):

cond = tf.tile(tf.expand_dims(cond, -1), (1, 1, 4))

Or simply with tf.broadcast_to : 或简单地使用tf.broadcast_to

cond = tf.broadcast_to(tf.expand_dims(cond, -1), tf.shape(data1))

2) 2)

This is one way to do that: 这是一种方法:

import tensorflow as tf

t2d = tf.constant([[0, 1, 2], [3, 4, 5]])
k, v = tf.constant([[0, 2], [1, 0]]), tf.constant([-2, -3])
# Tile t2d
n = tf.shape(k)[0]
t2d_tile = tf.tile(tf.expand_dims(t2d, 0), (n, 1, 1))
# Add aditional coordinate to index
idx = tf.concat([tf.expand_dims(tf.range(n), 1), k], axis=1)
# Make updates tensor
s = tf.shape(t2d_tile)
t2d_upd = tf.scatter_nd(idx, v, s)
# Make updates mask
upd_mask = tf.scatter_nd(idx, tf.ones_like(v, dtype=tf.bool), s)
# Make final tensor
t3d = tf.where(upd_mask, t2d_upd, t2d_tile)
# Test
with tf.Session() as sess:
    print(sess.run(t3d))

Output: 输出:

[[[ 0  1 -2]
  [ 3  4  5]]

 [[ 0  1  2]
  [-3  4  5]]]

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM