Multiplying a dense vector by a sparse matrix in TensorFlow

Is there an easy way to multiply a sparse matrix by a dense tensor in TensorFlow? I tried:

def sparse_mult(sparse_mat,dense_vec):
    vec = tf.zeros(dense_vec.shape, dense_vec.dtype)
    indices = sparse_mat.indices
    values = sparse_mat.values
    with tf.Session() as sess:
        num_vals = sess.run(tf.size(values))
    for i in range(num_vals):
        vec[indices[i,0]] += values[i] * dense_vec[indices[i,1]]
    return vec

But I get "TypeError: 'Tensor' object does not support item assignment." I also tried:

def sparse_mult(sparse_mat,dense_vec):
    vec = tf.zeros(dense_vec.shape, dense_vec.dtype)
    indices = sparse_mat.indices
    values = sparse_mat.values
    with tf.Session() as sess:
        num_vals = sess.run(tf.size(values))
    for i in range(num_vals):
        vec = vec[indices[i,0]].assign(vec[indices[i,0]] + values[i] * dense_vec[indices[i,1]])
    return vec

and got "ValueError: Sliced assignment is only supported for variables." Turning vec into a variable with vec = tf.get_variable('vec', initializer=tf.zeros(dense_vec.shape, dense_vec.dtype)) gives the same error. Is there a way to do this that is not too memory-intensive?
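For comparison, the accumulation the loop above is trying to express (y[row] += value * x[col] for every stored entry) works as written on ordinary Python lists; the errors come from tf.Tensor objects being immutable. Below is a minimal pure-Python sketch, not TensorFlow code, assuming COO-style (row, col) index pairs like a SparseTensor's indices; coo_matvec is a hypothetical helper name.

```python
def coo_matvec(indices, values, dense_vec, num_rows):
    """Multiply a COO-format sparse matrix by a dense vector (plain Python).

    indices:   list of (row, col) pairs for the stored entries
    values:    list of the corresponding stored values
    dense_vec: list whose length equals the matrix's column count
    num_rows:  number of rows in the sparse matrix
    """
    # Python lists are mutable, so in-place accumulation is allowed here
    result = [0.0] * num_rows
    for (row, col), value in zip(indices, values):
        # Each stored entry contributes value * x[col] to output row `row`
        result[row] += value * dense_vec[col]
    return result


# Same matrix and vector as in the answer below: a[0,0] = 1, a[3,4] = 10
print(coo_matvec([(0, 0), (3, 4)], [1.0, 10.0], [2.0, 3.0, 4.0, 5.0, 6.0], 5))
# prints [2.0, 0.0, 0.0, 60.0, 0.0]
```

This loop touches only the stored entries, so its cost and extra memory scale with the number of nonzeros plus the output length, not with the full matrix size.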

You should use tf.sparse_tensor_dense_matmul(), which was designed for exactly this purpose; you don't need to write your own function. This code (tested):

import tensorflow as tf  # TensorFlow 1.x API (uses tf.Session)

# 5x5 sparse matrix with two stored entries: a[0,0] = 1 and a[3,4] = 10
a = tf.SparseTensor(indices=[[0, 0], [3, 4]],
                    values=tf.constant([1.0, 10]),
                    dense_shape=[5, 5])
# Dense 5x1 column vector
vec = tf.constant([[2.0], [3], [4], [5], [6]])
# Sparse-times-dense product, shape [5, 1]
c = tf.sparse_tensor_dense_matmul(a, vec)

with tf.Session() as sess:
    res = sess.run(c)
    print(res)

will output:

[[ 2.]
 [ 0.]
 [ 0.]
 [60.]
 [ 0.]]
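To show why this kind of product needs little memory, here is a minimal vectorized sketch of the same computation in NumPy (not TensorFlow, and not a description of how tf.sparse_tensor_dense_matmul is implemented internally), assuming COO-style indices and values like the SparseTensor above; sparse_dense_matvec is a hypothetical helper name. It gathers the vector entry each stored value multiplies, then sums the products that share a row index with np.bincount, so the matrix is never densified. The same gather-then-segment-sum pattern can be written in TensorFlow with tf.gather and tf.math.unsorted_segment_sum.

```python
import numpy as np


def sparse_dense_matvec(indices, values, dense_vec, num_rows):
    """Vectorized COO sparse-matrix times dense-vector product.

    Extra memory is O(nnz + num_rows); the sparse matrix is never densified.
    """
    # Shape (nnz, 2): one (row, col) pair per stored entry
    idx = np.asarray(indices, dtype=np.int64).reshape(-1, 2)
    rows, cols = idx[:, 0], idx[:, 1]
    # Gather: pick the vector entry that each stored value multiplies
    products = np.asarray(values, dtype=float) * np.asarray(dense_vec, dtype=float)[cols]
    # Segment sum: add together the products that share a row index
    return np.bincount(rows, weights=products, minlength=num_rows)


res = sparse_dense_matvec([[0, 0], [3, 4]], [1.0, 10.0],
                          [2.0, 3.0, 4.0, 5.0, 6.0], 5)
print(res)  # 1-D array with values 2, 0, 0, 60, 0, matching the output above
```

In TensorFlow 2.x the op from the answer is exposed as tf.sparse.sparse_dense_matmul and runs eagerly, without a tf.Session.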


For reference, my first answer referred to tf.sparse_matmul(), which, somewhat confusingly, multiplies two dense matrices using an algorithm specially designed for matrices with many zero values. It will choke on SparseTensor arguments.
