
How to pad an irregularly shaped tensor in TensorFlow (the pure TF way)?

    test_tensor = [[2], [1, 2, 3], [4, 5]]  # irregular shape
    # Is there a TF function (ideally a better, faster one) that pads this into a dense tensor with a default value?
    # For example: test_tensor ==> dense tensor: [[2, -1, -1], [1, 2, 3], [4, 5, -1]]

P.S. Please don't use pure Python or NumPy.

I need to add this operation to my TF model graph, so I think it has to be done entirely with TF ops.

I am assuming that your starting tensor is a SparseTensor. (I don't think it is possible to have a dense tensor with an "irregular shape". If the input is not a tensor at all, then you don't need a "pure TF way".)

Use the following:

    dense = tf.sparse_tensor_to_dense(sparse_tensor_input, default_value=-1)

https://www.tensorflow.org/api_docs/python/tf/sparse_tensor_to_dense
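A minimal sketch of this approach, assuming TensorFlow 2.x, where `tf.sparse_tensor_to_dense` is exposed as `tf.sparse.to_dense` (and as `tf.compat.v1.sparse_tensor_to_dense`). The SparseTensor is hand-built from the question's example rows; `sparse_input` is a hypothetical name.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

# The irregular rows [[2], [1, 2, 3], [4, 5]] stored as a SparseTensor:
# each [row, col] index holds one element; unlisted positions are "missing".
sparse_input = tf.sparse.SparseTensor(
    indices=[[0, 0], [1, 0], [1, 1], [1, 2], [2, 0], [2, 1]],
    values=[2, 1, 2, 3, 4, 5],
    dense_shape=[3, 3],
)

# Fill every missing position with -1 (TF 2.x name of tf.sparse_tensor_to_dense).
dense = tf.sparse.to_dense(sparse_input, default_value=-1)
print(dense.numpy().tolist())  # [[2, -1, -1], [1, 2, 3], [4, 5, -1]]
```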

If you want your dense tensor to have a different shape from the input sparse tensor, you can change the shape of the sparse tensor before calling this function, or use the lower-level function https://www.tensorflow.org/api_docs/python/tf/sparse_to_dense .
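A sketch of the first option, again assuming TensorFlow 2.x: `tf.sparse.reset_shape` enlarges the `dense_shape` of a sparse tensor while leaving its indices and values unchanged, so densifying afterwards pads every row to the new width. The width of 5 is an arbitrary choice for illustration.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

sparse_input = tf.sparse.SparseTensor(
    indices=[[0, 0], [1, 0], [1, 1], [1, 2], [2, 0], [2, 1]],
    values=[2, 1, 2, 3, 4, 5],
    dense_shape=[3, 3],
)

# Grow the dense shape from 3 x 3 to 3 x 5; the new shape must be >= the
# old one in every dimension. Indices and values stay the same.
wider = tf.sparse.reset_shape(sparse_input, new_shape=[3, 5])

# Densify: the two extra columns are filled with the default value.
dense = tf.sparse.to_dense(wider, default_value=-1)
print(dense.numpy().tolist())
# [[2, -1, -1, -1, -1], [1, 2, 3, -1, -1], [4, 5, -1, -1, -1]]
```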

I recently ran into the same issue and solved it with a small trick (posting it in case someone else has the same problem, even though this question was asked almost a year ago).

TensorFlow cannot convert this nested list into a dense tensor, so I first join each row into a comma-separated string, giving a string list like a = ['2','1,2,3','4,5'].

Then I split these strings with pure TF ops. Here is a simple version:

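    import tensorflow as tf  # TensorFlow 1.x API (tf.string_split, tf.Session)

    def pad_length(sequence, limited_len, pad_value=0):
        # Split each comma-separated string into a SparseTensor of string tokens
        seq_sparse = tf.string_split(sequence, ',')
        # Parse the tokens to int32 and densify; missing positions get pad_value
        seq_dense = tf.sparse_to_dense(
            seq_sparse.indices, seq_sparse.dense_shape,
            tf.string_to_number(seq_sparse.values, out_type=tf.int32),
            default_value=pad_value
        )
        # Truncate rows longer than limited_len
        seq_slice = tf.strided_slice(seq_dense, [0, 0], [tf.shape(sequence)[0], limited_len])
        # Pad on the right so every row has exactly limited_len columns
        pad_dense = tf.pad(seq_slice, paddings=[[0, 0], [0, limited_len - tf.shape(seq_slice)[1]]],
                           constant_values=pad_value)
        return pad_dense

    a = tf.convert_to_tensor(['2', '1,2,3', '4,5'])
    b = pad_length(a, 3)  # pass pad_value=-1 to get the output asked for in the question

    with tf.Session() as sess:
        result = sess.run(b)

    # result => array([[2, 0, 0],
    #                  [1, 2, 3],
    #                  [4, 5, 0]], dtype=int32)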
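On TensorFlow 2.x the same string trick gets shorter, because `tf.strings.split` returns a `RaggedTensor` and `RaggedTensor.to_tensor` can both pad and truncate to a target shape. A minimal sketch, assuming TF 2.2 or later (for the `shape` argument of `to_tensor`); `pad_length_v2` and its `pad_value` parameter are hypothetical names, not part of the answer above.

```python
import tensorflow as tf  # assumes TensorFlow 2.2+ (eager execution)

def pad_length_v2(sequence, limited_len, pad_value=-1):
    # Split each comma-separated string into a RaggedTensor of string tokens.
    tokens = tf.strings.split(sequence, sep=',')
    # Parse the tokens to int32 while keeping the ragged row structure.
    numbers = tf.ragged.map_flat_values(tf.strings.to_number, tokens, out_type=tf.int32)
    # Pad short rows with pad_value and truncate long ones to limited_len columns.
    return numbers.to_tensor(default_value=pad_value, shape=[None, limited_len])

b = pad_length_v2(tf.constant(['2', '1,2,3', '4,5']), 3)
print(b.numpy().tolist())  # [[2, -1, -1], [1, 2, 3], [4, 5, -1]]
```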

Cheers!

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint, please credit the site URL or the original address. For any questions, contact: yoyou2525@163.com.

 