
How to compute a weighted average of tensor A along an axis, with weights given by tensor B, in TensorFlow?

I am trying to apply a weighted-average scheme to the output of an RNN.
The RNN output is a tensor A of shape (a, b, c).
I can simply take tf.reduce_mean(A, axis=1) to get a tensor C of shape (a, c).

However, I want to take a weighted average of tensor A along axis=1.
The weights are given by a matrix B of shape (d, b).

For d = 1, I can do tf.tensordot(A, B, [1, 1]) to get the result, which has shape (a, c) once the trailing axis of size 1 is squeezed out.
For d = a, however, I am unable to compute the weighted average.

Can someone suggest a solution?

I don't quite see why B should have shape (d, b). If B contains the weights for a weighted average of A along a single axis, B only needs to be a vector of shape (b,), not a matrix.

If B is a vector, you can do:

C = tf.tensordot(A, B, [1, 0])

to get a tensor C of shape (a, c) that contains the weighted average of A along axis=1, using the weights in B (assuming the weights sum to 1).
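A minimal sketch of the vector case, assuming TensorFlow 2.x with eager execution. The sizes and weight values are made up for illustration and are not from the question:

```python
import numpy as np
import tensorflow as tf

# Illustrative sizes (assumptions): a=2 sequences, b=3 time steps, c=4 features.
A = tf.constant(np.arange(24, dtype=np.float32).reshape(2, 3, 4))

# One weight per time step; the weights sum to 1, so the result is an average.
B = tf.constant([0.2, 0.3, 0.5], dtype=tf.float32)

# Contract axis 1 of A (length b) with axis 0 of B (length b).
C = tf.tensordot(A, B, axes=[[1], [0]])

print(C.shape)  # (2, 4)
```

Every sequence in A is averaged with the same weights here, which is why this form cannot express per-sequence weights.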

Update:

You can do something like:

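weighted = A * B[:, :, None]

which multiplies A and B elementwise, where B of shape (a, b) stores the weight given to each time step of each example in A. Then:

C = tf.reduce_sum(weighted, axis=1)

gives the weighted average, provided the weights in each row of B sum to 1. (Using tf.reduce_mean here would additionally divide the weighted sum by b, which is only correct if each row of weights sums to b.)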

Since B is already normalized, the answer is

tf.reduce_sum(A * B[:, :, None], axis=1)
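If B is not guaranteed to be normalized, one option is to divide by the sum of the weights. The sketch below assumes TensorFlow 2.x with eager execution; `weighted_average` is a hypothetical helper name, and the sizes and values are made up for illustration:

```python
import numpy as np
import tensorflow as tf

def weighted_average(A, B):
    """Weighted average of A (a, b, c) over axis 1 with weights B (a, b).

    Hypothetical helper: it divides by the sum of each row of B,
    so B does not have to be normalized beforehand.
    """
    B = tf.cast(B, A.dtype)
    weighted_sum = tf.reduce_sum(A * B[:, :, None], axis=1)  # (a, c)
    total = tf.reduce_sum(B, axis=1, keepdims=True)          # (a, 1)
    return weighted_sum / total

# Illustrative data (assumptions): a=2, b=3, c=4.
A = tf.constant(np.arange(24, dtype=np.float32).reshape(2, 3, 4))
B = tf.constant([[1.0, 1.0, 2.0],   # unnormalized weights for example 0
                 [0.0, 0.0, 3.0]])  # all weight on the last step of example 1

C = weighted_average(A, B)
print(C.shape)  # (2, 4)
```

With weights that already sum to 1 in each row, the division is a no-op and the result matches the one-line tf.reduce_sum expression.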

Indexing with None adds a new dimension, a behavior inherited from NumPy. B[:, :, None] appends a trailing dimension, so the result has shape (a, b, 1). You can achieve the same thing with tf.expand_dims(B, axis=-1), whose name may be more self-explanatory.

A has shape (a, b, c), while B[:, :, None] has shape (a, b, 1). When they are multiplied, the expanded B is treated as if it had shape (a, b, c) too, with the last dimension holding c copies of the same value. This is called broadcasting.

Because of how broadcasting works, the same answer also works if B has shape (1, b); the single row of weights is then shared by all a examples.
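Since the indexing behavior comes from NumPy, the shapes involved can be checked with NumPy alone. This is only a sketch with made-up sizes and weights; the same indexing and broadcasting rules are assumed to carry over to the TensorFlow expression above:

```python
import numpy as np

a, b, c = 2, 3, 4  # illustrative sizes (assumptions)
A = np.arange(a * b * c, dtype=np.float32).reshape(a, b, c)

B_per_row = np.array([[0.2, 0.3, 0.5],
                      [1.0, 0.0, 0.0]], dtype=np.float32)  # shape (a, b)
B_shared = np.array([[0.2, 0.3, 0.5]], dtype=np.float32)   # shape (1, b)

# None indexing and expand_dims produce the same (a, b, 1) array.
assert np.array_equal(B_per_row[:, :, None], np.expand_dims(B_per_row, axis=-1))

# (a, b, 1) is broadcast against (a, b, c).
C_per_row = np.sum(A * B_per_row[:, :, None], axis=1)
# (1, b, 1) is broadcast against (a, b, c): one row of weights for every example.
C_shared = np.sum(A * B_shared[:, :, None], axis=1)

print(C_per_row.shape, C_shared.shape)  # (2, 4) (2, 4)
```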


 