简体   繁体   English

3D SparseTensor 矩阵乘以 2D 张量:InvalidArgumentError:张量“a_shape”必须有 2 个元素 [Op:SparseTensorDenseMatMul]

[英]3D SparseTensor matrix multiplication with 2D Tensor :InvalidArgumentError: Tensor 'a_shape' must have 2 elements [Op:SparseTensorDenseMatMul]

Hi I am tryigng to do a 3D SparseTensor matrix multiplication with 2D Tensor.嗨,我正在尝试使用 2D 张量进行 3D SparseTensor 矩阵乘法。 Here is a toy example:这是一个玩具示例:

3D tensor matrix multiplication with 2D Tensor 3D 张量矩阵与二维张量相乘

import tensorflow as tf
import numpy as np

a = np.array([[[1., 0., 2., 0.],
              [3., 0., 0., 4.]]])
b = (np.array([1., 2.])[:,np.newaxis]).T

a_t = tf.constant(a)
b_t = tf.constant(b)
    
tf.matmul(b_t,a_t)

<tf.Tensor: shape=(1, 1, 4), dtype=float64, numpy=array([[[7., 0., 2., 8.]]])>

3D SparseTensor matrix multiplication with 2D Tensor 3D SparseTensor 矩阵乘法与二维张量

import tensorflow as tf
import numpy as np

a = np.array([[[1., 0., 2., 0.],
              [3., 0., 0., 4.]]])
b = (np.array([1., 2.])[:,np.newaxis]).T

a_t = tf.constant(a)
b_t = tf.constant(b)

a_s = tf.sparse.from_dense(a_t)

tf.sparse.sparse_dense_matmul(b_t,a_s)

InvalidArgumentError: Tensor 'a_shape' must have 2 elements [Op:SparseTensorDenseMatMul]

Can you please help me to sort out this error?你能帮我解决这个错误吗?

The documentation says that the rank of the dense matrix must be 2. 文档说密集矩阵的秩必须是 2。

Multiply SparseTensor (or dense Matrix) (of rank 2) "A" by dense matrix将稀疏张量(或密集矩阵)(等级 2)“A”乘以密集矩阵

Te first argument of the call of this function is:此 function 调用的第一个参数是:

SparseTensor (or dense Matrix) A, of rank 2.稀疏张量(或稠密矩阵)A,秩为 2。

So you can do what you're trying to do, at least not this way.所以你可以做你想做的事,至少不是这样。 Seems like you don't really need the three dimensions, as one of them is just one.似乎您并不真的需要三个维度,因为其中之一只是一个。 If you squeeze it out, it will work:如果你把它挤出来,它会起作用:

a_t = tf.constant(tf.squeeze(a))
b_t = tf.constant(b)

a_s = tf.sparse.from_dense(a_t)

tf.sparse.sparse_dense_matmul(b_t,a_s)
<tf.Tensor: shape=(1, 4), dtype=float64, numpy=array([[7., 0., 2., 8.]])>

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM