简体   繁体   English

使用Python中的tensorflow按重量填充张量

[英]Pad a tensor by a weight with tensorflow in Python

I have a 3d tensor A of shape [n, ?, m] having one non-zero element along the third axis. 我有一个3d张量A的形状[n, ?, m]沿第三轴有一个非零元素。 For example 例如

A[0,0,:] = [0,0,0,1,0,0]
A[1,0,:] = [0,0,1,0,0,0]
A[0,1,:] = [0,1,0,0,0,0]
A[1,1,:] = [1,0,0,0,0,0]

I have a weight tensor w of shape (1,) . 我的重张量w形状的(1,)

I want to dilate the tensor A by the weight w such that I can transform the tensor A as below 我想通过权重w扩张张量A ,以便我可以如下变换张量A

A[0,0,:] = [0,0,w,1,w,0]
A[1,0,:] = [0,w,1,w,0,0]
A[0,1,:] = [w,1,w,0,0,0]
A[1,1,:] = [1,w,0,0,0,w]

Please note that the weight w is added adjacent to the nonzero element 1 and if it is at the border then we wrap the indexes around. 请注意,权重w是在非零元素1附近添加的,如果它在边界处,则我们将索引包裹起来。

How can I do that using tensorflow in python. 我怎么能在python中使用tensorflow来做到这tensorflow

EDIT: 编辑:

Here is a more general version that works for a padding vector with more than one element: 这是一个更通用的版本,适用于具有多个元素的填充向量:

import tensorflow as tf

def surround_nonzero(a, w):
    # Find non-zero positions
    idx = tf.where(tf.not_equal(a, 0))
    # A vector to shift the last value in the indices
    w_len = tf.shape(w, out_type=tf.int64)[0]
    shift1 = tf.concat([tf.zeros(tf.shape(idx)[-1] - 1, dtype=tf.int64), [1]], axis=0)
    shift_len = shift1 * tf.expand_dims(tf.range(1, w_len + 1), 1)
    # Shift last value of indices using module to wrap around
    a_shape = tf.shape(a, out_type=tf.int64)
    d = a_shape[-1]
    idx_exp = tf.expand_dims(idx, 1)
    idx_prev_exp = (idx_exp - shift_len) % d
    idx_next_exp = (idx_exp + shift_len) % d
    # Reshape shifted indices
    a_rank = tf.rank(a)
    idx_prev = tf.reshape(idx_prev_exp, [-1, a_rank])
    idx_next = tf.reshape(idx_next_exp, [-1, a_rank])
    # Take non-zero values
    nonzero = tf.gather_nd(a, idx)
    # Tile wrapping value twice the number of non-zero values
    n = tf.shape(nonzero)[0]
    w2n = tf.tile(w, [2 * n])
    # Make full index and values for scattering with non-zero values and wrapping value
    idx_full = tf.concat([idx, idx_prev, idx_next], axis=0)
    values_full = tf.concat([nonzero, w2n], axis=0)
    # Make output tensor with scattering
    return tf.scatter_nd(idx_full, values_full, a_shape)

# Test
with tf.Graph().as_default():
    A = tf.constant([[[0, 0, 0, 0, 0, 1, 0, 0],
                      [0, 0, 1, 0, 0, 0, 0, 0]],
                     [[0, 0, 0, 0, 1, 0, 0, 0],
                      [1, 0, 0, 0, 0, 0, 0, 0]]],
                    dtype=tf.int32)
    w = tf.constant([2, 3, 4], dtype=tf.int32)
    out = surround_nonzero(A, w)
    with tf.Session() as sess:
        print(sess.run(out))

Output: 输出:

[[[4 0 4 3 2 1 2 3]
  [3 2 1 2 3 4 0 4]]

 [[0 4 3 2 1 2 3 4]
  [1 2 3 4 0 4 3 2]]]

As before, this assumes that the padding always "fits", and the behavior in cases where padding values would overlap is not guaranteed. 和以前一样,这假设填充总是“适合”,并且不保证填充值重叠的情况下的行为。


Here is a way to do that using tf.scatter_nd : 这是使用tf.scatter_nd执行此操作的tf.scatter_nd

import tensorflow as tf

def surround_nonzero(a, w):
    # Find non-zero positions
    idx = tf.where(tf.not_equal(a, 0))
    # A vector to shift the last value in the indices by one
    shift1 = tf.concat([tf.zeros(tf.shape(idx)[-1] - 1, dtype=tf.int64), [1]], axis=0)
    # Shift last value of indices using module to wrap around
    a_shape = tf.shape(a, out_type=tf.int64)
    d = a_shape[-1]
    idx_prev = (idx - shift1) % d
    idx_next = (idx + shift1) % d
    # Take non-zero values
    nonzero = tf.gather_nd(a, idx)
    # Tile wrapping value twice the number of non-zero values
    n = tf.shape(nonzero)[0]
    w2n = tf.tile(w, [2 * n])
    # Make full index and values for scattering with non-zero values and wrapping value
    idx_full = tf.concat([idx, idx_prev, idx_next], axis=0)
    values_full = tf.concat([nonzero, w2n], axis=0)
    # Make output tensor with scattering
    return tf.scatter_nd(idx_full, values_full, a_shape)

# Test
with tf.Graph().as_default():
    A = tf.constant([[[0, 0, 0, 1, 0, 0],
                      [0, 1, 0, 0, 0, 0]],
                     [[0, 0, 1, 0, 0, 0],
                      [1, 0, 0, 0, 0, 0]]],
                    dtype=tf.int32)
    w = tf.constant([2], dtype=tf.int32)
    out = surround_nonzero(A, w)
    with tf.Session() as sess:
        print(sess.run(out))

Output: 输出:

[[[0 0 2 1 2 0]
  [2 1 2 0 0 0]]

 [[0 2 1 2 0 0]
  [1 2 0 0 0 2]]]

Note this assumes each non-zero value is surrounded by zeros (as is your case). 请注意,这假设每个非零值都被零包围(就像您的情况一样)。 Otherwise, the scatter operation would find duplicate indices and the output would not be deterministic. 否则,分散操作将找到重复的索引,并且输出将不是确定性的。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM