繁体   English   中英

来自密集张量 Tensorflow 的稀疏张量(矩阵)

[英]Sparse Tensor (matrix) from a dense Tensor Tensorflow

我正在创建一个卷积稀疏自动编码器,我需要将一个充满值的 4D 矩阵(其形状为[samples, N, N, D] )转换为一个稀疏矩阵。

对于每个样本,我有 D NxN 个特征图。 我想将每个 NxN 特征映射转换为稀疏矩阵,最大值映射为 1,所有其他映射为 0。

我不想在运行时执行此操作,而是在 Graph 声明期间执行此操作(因为我需要使用生成的稀疏矩阵作为其他图形操作的输入),但我不明白如何获取索引以构建稀疏矩阵。

您可以使用tf.wheretf.gather_nd来做到这一点:

import numpy as np
import tensorflow as tf

# Make a tensor from a constant
a = np.reshape(np.arange(24), (3, 4, 2))
a_t = tf.constant(a)
# Find indices where the tensor is not zero
idx = tf.where(tf.not_equal(a_t, 0))
# Make the sparse tensor
# Use tf.shape(a_t, out_type=tf.int64) instead of a_t.get_shape()
# if tensor shape is dynamic
sparse = tf.SparseTensor(idx, tf.gather_nd(a_t, idx), a_t.get_shape())
# Make a dense tensor back from the sparse one, only to check result is correct
dense = tf.sparse_tensor_to_dense(sparse)
# Check result
with tf.Session() as sess:
    b = sess.run(dense)
np.all(a == b)
>>> True

将密集 numpy 数组转换为 tf.SparseTensor 的简单代码:

def denseNDArrayToSparseTensor(arr):
  idx  = np.where(arr != 0.0)
  return tf.SparseTensor(np.vstack(idx).T, arr[idx], arr.shape)

tf.sparse.from_dense从 1.15 开始就有tf.sparse.from_dense 示例:

In [1]: import tensorflow as tf

In [2]: x = tf.eye(3) * 5

In [3]: x
Out[3]: 
<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[5., 0., 0.],
       [0., 5., 0.],
       [0., 0., 5.]], dtype=float32)>

应用tf.sparse.from_dense

In [4]: y = tf.sparse.from_dense(x)

In [5]: y.values
Out[5]: <tf.Tensor: shape=(3,), dtype=float32, numpy=array([5., 5., 5.], dtype=float32)>

In [6]: y.indices
Out[6]: 
<tf.Tensor: shape=(3, 2), dtype=int64, numpy=
array([[0, 0],
       [1, 1],
       [2, 2]])>

通过应用tf.sparse.to_dense验证身份:

In [7]: tf.sparse.to_dense(y) == x
Out[7]: 
<tf.Tensor: shape=(3, 3), dtype=bool, numpy=
array([[ True,  True,  True],
       [ True,  True,  True],
       [ True,  True,  True]])>

注意有一个内置的功能在了contrib (采取

在 TF 2.3 Tensorflow Probability 中有一个函数

import tensorflow_probability as tfp

tfp.math.dense_to_sparse(x, ignore_value=None, name=None)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM