
TensorFlow sparse tensor row-wise mask

I have a TensorFlow sparse tensor and a boolean mask over the rows of its dense form.

For example:

tf.SparseTensor(
    indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]],
    values=tf.constant([3, 40, 9, 5, 6, 4], tf.int64),
    dense_shape=[2, 10]
)

Given the mask [True, False], I would like to keep only the first row, as boolean masking would in the dense form, but without converting the sparse tensor to dense. The expected result is:

tf.SparseTensor(
    indices=[[0, 0], [0, 1], [0, 2]],
    values=tf.constant([3, 40, 9], tf.int64),
    dense_shape=[1, 10]
)

How can I use the mask to directly filter the sparse tensor?
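For reference, the target is what tf.boolean_mask does along axis 0 on the dense form. The sketch below converts to dense only to produce that expected result for comparison; it is exactly the conversion the question wants to avoid, so treat it as a check rather than a solution.

```python
import tensorflow as tf

x = tf.SparseTensor(
    indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]],
    values=tf.constant([3, 40, 9, 5, 6, 4], tf.int64),
    dense_shape=[2, 10])
mask = tf.constant([True, False])

# Dense reference: keep the rows of the dense form where mask is True.
dense_result = tf.boolean_mask(tf.sparse.to_dense(x), mask, axis=0)
# Converting back gives the SparseTensor the question asks for
# (tf.sparse.from_dense drops zeros; all values here are non-zero).
expected = tf.sparse.from_dense(dense_result)
print(expected)
```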

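Maybe try combining tf.boolean_mask and tf.sparse.slice. Note that tf.sparse.slice extracts a single rectangular region, so this only works when the rows you keep form one contiguous block: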

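import tensorflow as tf

x = tf.SparseTensor(
    indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]],
    values=tf.constant([3, 40, 9, 5, 6, 4], tf.int64),
    dense_shape=[2, 10])

mask = tf.constant([True, False])
indices = x.indices
# Look up the mask value of each stored entry's row (works for any number of rows).
entry_mask = tf.gather(mask, indices[:, 0])
kept = tf.boolean_mask(indices, entry_mask, axis=0)
# Start the slice at the first kept row and at column 0, so no columns are cut off.
start = tf.concat([kept[:1, 0], tf.zeros([1], tf.int64)], axis=0)
# Size: number of kept rows x all columns, computed without densifying x.
size = tf.stack([tf.reduce_sum(tf.cast(mask, tf.int64)), x.dense_shape[1]])
result = tf.sparse.slice(x, start=start, size=size)
print(result)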
SparseTensor(indices=tf.Tensor(
[[0 0]
 [0 1]
 [0 2]], shape=(3, 2), dtype=int64), values=tf.Tensor([ 3 40  9], shape=(3,), dtype=int64), dense_shape=tf.Tensor([ 1 10], shape=(2,), dtype=int64))

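Alternatively, with tf.sparse.retain you can apply a boolean mask directly to the sparse tensor. It expects one boolean per stored entry, so the row mask is first expanded by each entry's row index: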

import tensorflow as tf

x = tf.SparseTensor(
    indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]],
    values=tf.constant([3, 40, 9, 5, 6, 4], tf.int64),
    dense_shape=[2, 10])

mask = tf.constant([True, False])
# Expand the row mask to one boolean per stored entry, looked up by the entry's row.
entry_mask = tf.gather(mask, x.indices[:, 0])
result = tf.sparse.retain(x, entry_mask)
print(result)

However, this approach does not change the dense_shape, which stays [2, 10]:

SparseTensor(indices=tf.Tensor(
[[0 0]
 [0 1]
 [0 2]], shape=(3, 2), dtype=int64), values=tf.Tensor([ 3 40  9], shape=(3,), dtype=int64), dense_shape=tf.Tensor([ 2 10], shape=(2,), dtype=int64))

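If exactly one row is kept, as here, you can use tf.sparse.reduce_sum on result to collapse the first axis and get a sparse tensor with dense_shape [1, 10]. With several kept rows this would add them together, so it is not a general fix: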

tf.sparse.reduce_sum(result, axis=0, keepdims=True, output_is_sparse=True)
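For a mask that keeps several, possibly non-contiguous rows, here is a minimal sketch, assuming a 2-D SparseTensor and a mask whose length equals dense_shape[0] (the helper name sparse_row_mask is hypothetical). It retains the entries as above, renumbers each surviving row by how many kept rows precede it, and shrinks dense_shape accordingly, without ever building a dense tensor.

```python
import tensorflow as tf


def sparse_row_mask(sp, mask):
    """Hypothetical helper: keep the rows of a 2-D SparseTensor where mask is True."""
    mask = tf.convert_to_tensor(mask, tf.bool)
    # One boolean per stored entry: is that entry's row kept?
    entry_mask = tf.gather(mask, sp.indices[:, 0])
    kept = tf.sparse.retain(sp, entry_mask)
    # New id of each old row = number of kept rows before it (cumulative count - 1).
    new_row_ids = tf.cumsum(tf.cast(mask, tf.int64)) - 1
    new_rows = tf.gather(new_row_ids, kept.indices[:, 0])
    # Replace the row column of the indices, keep the column indices as they are.
    new_indices = tf.concat([new_rows[:, None], kept.indices[:, 1:]], axis=1)
    # Rows: number of True entries (shape [1]); columns: unchanged.
    new_shape = tf.concat(
        [tf.reduce_sum(tf.cast(mask, tf.int64), keepdims=True), sp.dense_shape[1:]],
        axis=0)
    return tf.SparseTensor(new_indices, kept.values, new_shape)


x = tf.SparseTensor(
    indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]],
    values=tf.constant([3, 40, 9, 5, 6, 4], tf.int64),
    dense_shape=[2, 10])
result = sparse_row_mask(x, [True, False])
print(result)
```

Because the renumbering is monotonic, the kept entries stay in their original order, so the result keeps the same index ordering as the input.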
