简体   繁体   中英

Pad a tensor by a weight with tensorflow in Python

I have a 3-D tensor A of shape [n, ?, m] that has one non-zero element along the third axis at each [i, j] position. For example:

A[0,0,:] = [0,0,0,1,0,0]
A[1,0,:] = [0,0,1,0,0,0]
A[0,1,:] = [0,1,0,0,0,0]
A[1,1,:] = [1,0,0,0,0,0]

I have a weight tensor w of shape (1,).

I want to dilate A by the weight w, transforming it as follows:

A[0,0,:] = [0,0,w,1,w,0]
A[1,0,:] = [0,w,1,w,0,0]
A[0,1,:] = [w,1,w,0,0,0]
A[1,1,:] = [1,w,0,0,0,w]

Note that w is placed on both sides of the non-zero element 1; if that element is at the border, the indices wrap around to the other end of the axis.

How can I do this with TensorFlow in Python?

EDIT:

Here is a more general version that works for a padding vector with more than one element. Element w[k] is placed at distance k + 1 on each side of every non-zero value:

import tensorflow as tf

def surround_nonzero(a, w):
    """Surround every non-zero entry of ``a`` with the values of ``w`` along the last axis.

    For each non-zero element at last-axis position ``j``, ``w[k]`` is written
    at positions ``j - (k + 1)`` and ``j + (k + 1)``, wrapping around the last
    axis. The non-zero element itself is kept in place.

    Args:
        a: Tensor of rank >= 1.
        w: 1-D tensor of padding values. It must have the same dtype as ``a``,
            because it is concatenated with the gathered non-zero values.

    Returns:
        A tensor with the same shape and dtype as ``a``.

    Note:
        Assumes the padding around different non-zero entries never overlaps.
        ``tf.scatter_nd`` sums updates that share an index, so overlapping
        positions would receive summed values.
    """
    # Indices of the non-zero entries, shape [N, rank] (int64).
    idx = tf.where(tf.not_equal(a, 0))
    w_len = tf.shape(w, out_type=tf.int64)[0]
    # Unit vector selecting the last index component: [0, ..., 0, 1].
    shift1 = tf.concat([tf.zeros(tf.shape(idx)[-1] - 1, dtype=tf.int64), [1]], axis=0)
    # Row k shifts the last component by k + 1; shape [w_len, rank].
    shift_len = shift1 * tf.expand_dims(tf.range(1, w_len + 1), 1)
    a_shape = tf.shape(a, out_type=tf.int64)
    idx_exp = tf.expand_dims(idx, 1)  # [N, 1, rank], broadcasts against shift_len
    # Wrap each index component modulo its OWN dimension size. Only the last
    # component is shifted; the others are already in range, so this leaves
    # them unchanged. Reducing them modulo the last dimension instead would
    # corrupt them whenever a leading index is >= the last dimension's size.
    idx_prev_exp = (idx_exp - shift_len) % a_shape
    idx_next_exp = (idx_exp + shift_len) % a_shape
    # Flatten [N, w_len, rank] -> [N * w_len, rank], ordered by (entry, k).
    a_rank = tf.rank(a)
    idx_prev = tf.reshape(idx_prev_exp, [-1, a_rank])
    idx_next = tf.reshape(idx_next_exp, [-1, a_rank])
    nonzero = tf.gather_nd(a, idx)
    # Tiling w 2N times lines up w[k] with the (entry, k) order used above,
    # once for the "previous" side and once for the "next" side.
    n = tf.shape(nonzero)[0]
    w2n = tf.tile(w, [2 * n])
    idx_full = tf.concat([idx, idx_prev, idx_next], axis=0)
    values_full = tf.concat([nonzero, w2n], axis=0)
    return tf.scatter_nd(idx_full, values_full, a_shape)

# Test: pad with a three-element vector on both sides of each non-zero value.
graph = tf.Graph()
with graph.as_default():
    A = tf.constant(
        [
            [[0, 0, 0, 0, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0]],
            [[0, 0, 0, 0, 1, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0, 0]],
        ],
        dtype=tf.int32,
    )
    w = tf.constant([2, 3, 4], dtype=tf.int32)
    out = surround_nonzero(A, w)
with tf.Session(graph=graph) as sess:
    result = sess.run(out)
print(result)

Output:

[[[4 0 4 3 2 1 2 3]
  [3 2 1 2 3 4 0 4]]

 [[0 4 3 2 1 2 3 4]
  [1 2 3 4 0 4 3 2]]]

As with the single-value version below, this assumes that the padding always "fits". The behavior when padding values would overlap is not guaranteed.


Here is a way to do that using tf.scatter_nd:

import tensorflow as tf

def surround_nonzero(a, w):
    """Surround every non-zero entry of ``a`` with ``w`` along the last axis.

    For each non-zero element at last-axis position ``j``, the value of ``w``
    is written at positions ``j - 1`` and ``j + 1``, wrapping around the last
    axis. The non-zero element itself is kept in place.

    Args:
        a: Tensor of rank >= 1.
        w: Tensor of shape ``(1,)`` holding the padding value. It must have
            the same dtype as ``a``, because it is concatenated with the
            gathered non-zero values.

    Returns:
        A tensor with the same shape and dtype as ``a``.

    Note:
        Assumes each non-zero value is surrounded by zeros.
        ``tf.scatter_nd`` sums updates that share an index, so overlapping
        positions would receive summed values.
    """
    # Indices of the non-zero entries, shape [N, rank] (int64).
    idx = tf.where(tf.not_equal(a, 0))
    # Unit vector selecting the last index component: [0, ..., 0, 1].
    shift1 = tf.concat([tf.zeros(tf.shape(idx)[-1] - 1, dtype=tf.int64), [1]], axis=0)
    a_shape = tf.shape(a, out_type=tf.int64)
    # Wrap each index component modulo its OWN dimension size. Only the last
    # component is shifted; the others are already in range, so this leaves
    # them unchanged. Reducing them modulo the last dimension instead would
    # corrupt them whenever a leading index is >= the last dimension's size.
    idx_prev = (idx - shift1) % a_shape
    idx_next = (idx + shift1) % a_shape
    nonzero = tf.gather_nd(a, idx)
    # One copy of w per neighbor: two per non-zero value.
    n = tf.shape(nonzero)[0]
    w2n = tf.tile(w, [2 * n])
    idx_full = tf.concat([idx, idx_prev, idx_next], axis=0)
    values_full = tf.concat([nonzero, w2n], axis=0)
    return tf.scatter_nd(idx_full, values_full, a_shape)

# Test: pad with a single value on both sides of each non-zero value.
graph = tf.Graph()
with graph.as_default():
    A = tf.constant(
        [
            [[0, 0, 0, 1, 0, 0], [0, 1, 0, 0, 0, 0]],
            [[0, 0, 1, 0, 0, 0], [1, 0, 0, 0, 0, 0]],
        ],
        dtype=tf.int32,
    )
    w = tf.constant([2], dtype=tf.int32)
    out = surround_nonzero(A, w)
with tf.Session(graph=graph) as sess:
    result = sess.run(out)
print(result)

Output:

[[[0 0 2 1 2 0]
  [2 1 2 0 0 0]]

 [[0 2 1 2 0 0]
  [1 2 0 0 0 2]]]

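Note that this assumes each non-zero value is surrounded by zeros (as in your case). Otherwise, the scatter operation would encounter duplicate indices. tf.scatter_nd sums updates that share an index, so overlapping positions would receive the sum of the colliding values instead of a single value.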

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please cite the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM