
TensorFlow: how to batch-matmul a batch tensor by a weight variable?

I have a batch with the following shape:

 [?,227,227]

And the following weight variable:

 weight_tensor = tf.truncated_normal([227,227],**{'stddev':0.1,'mean':0.0})

 weight_var = tf.Variable(weight_tensor)

But when I call tf.batch_matmul:

 matrix = tf.batch_matmul(prev_net_2d,weight_var)

it fails with the following error:

ValueError: Shapes (?,) and () must have the same rank


So my question is: how do I do this?

How can I have a single 2D weight variable that multiplies each individual 227x227 image, so that each output is also 227x227? The flattened version of this operation exhausts my resources, and in the flattened form the gradient does not update the weights correctly.
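Because the same 227x227 weight is applied to every image, one option (not from the original thread) is to fold the batch dimension into the row dimension, do one ordinary 2D matmul against the shared weight, and unfold the result. This never flattens an image into a 51529-long vector, and gradients flow through tf.reshape and tf.matmul normally. A minimal sketch, assuming TensorFlow 2.x with eager execution; the helper name batch_times_shared_weight and the random stand-in batch are hypothetical.

```python
import tensorflow as tf

def batch_times_shared_weight(x, w):
    """Hypothetical helper: compute x[b] @ w for every b, with x: [batch, n, m], w: [m, k]."""
    n = tf.shape(x)[1]
    m = tf.shape(x)[2]
    k = tf.shape(w)[1]
    # Stack all rows of all images into one [batch * n, m] matrix ...
    rows = tf.reshape(x, [-1, m])
    # ... multiply every row by the shared weight in a single 2D matmul ...
    out = tf.matmul(rows, w)
    # ... and restore the [batch, n, k] layout.
    return tf.reshape(out, [-1, n, k])

x = tf.random.normal([4, 227, 227])  # stand-in for the [?, 227, 227] batch
w = tf.Variable(tf.random.truncated_normal([227, 227], mean=0.0, stddev=0.1))
y = batch_times_shared_weight(x, w)
print(y.shape)  # (4, 227, 227)
```

Row i of image b ends up as row b * n + i of the stacked matrix, so multiplying the stacked matrix by w is the same as multiplying each image by w separately.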


Alternatively, how do I split the incoming tensor along the batch dimension (?) so that I can run tf.matmul on each slice with my weight variable?
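For the split-and-multiply variant, tf.map_fn slices a tensor along axis 0 and applies a function to each slice, which also works when the batch size is only known at run time (tf.unstack, by contrast, needs a statically known batch size). A minimal sketch, assuming TensorFlow 2.x with eager execution and a random stand-in batch:

```python
import tensorflow as tf

w = tf.Variable(tf.random.truncated_normal([227, 227], mean=0.0, stddev=0.1))
x = tf.random.normal([3, 227, 227])  # stand-in for the [?, 227, 227] batch

# tf.map_fn hands each [227, 227] image to the lambda and stacks the results.
y = tf.map_fn(lambda img: tf.matmul(img, w), x)
print(y.shape)  # (3, 227, 227)
```

This runs one small matmul per image, so it is usually slower than a single batched multiply, but it maps directly onto the "split, multiply, reassemble" idea.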

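tf.batch_matmul expects both arguments to have the same batch dimensions; the error compares the input's batch shape (?,) with the 2D weight's empty batch shape (). You could therefore tile the weights along a new first dimension so they match: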

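import tensorflow as tf
weight_tensor = tf.truncated_normal([227, 227], mean=0.0, stddev=0.1)
weight_var = tf.Variable(weight_tensor)
# batch_size is read at run time, so the unknown (?) batch dimension is handled
batch_size = tf.shape(prev_net_2d)[0]
weight_var_batch = tf.tile(tf.expand_dims(weight_var, axis=0), [batch_size, 1, 1])
# tf.matmul multiplies each [227, 227] slice by its own copy of the weights
matrix = tf.matmul(prev_net_2d, weight_var_batch)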
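Tiling does not create independent copies to train: the gradient of tf.tile sums the contributions of every copy back into the single [227, 227] variable. The sketch below checks this by differentiating a toy loss (a plain sum, chosen only for illustration) through the tiled multiply. It assumes TensorFlow 2.x with eager execution and uses a random stand-in batch.

```python
import tensorflow as tf

w = tf.Variable(tf.random.truncated_normal([227, 227], mean=0.0, stddev=0.1))
x = tf.random.normal([5, 227, 227])  # stand-in for the [?, 227, 227] batch

with tf.GradientTape() as tape:
    batch_size = tf.shape(x)[0]
    w_batch = tf.tile(tf.expand_dims(w, axis=0), [batch_size, 1, 1])
    y = tf.matmul(x, w_batch)  # [5, 227, 227]
    loss = tf.reduce_sum(y)    # toy loss, for illustration only

# One gradient of shape [227, 227]: the per-copy gradients are summed by tf.tile's gradient.
grad = tape.gradient(loss, w)
print(grad.shape)  # (227, 227)
```

For this particular loss, grad[k, j] equals the sum of x[b, i, k] over all b and i, so every copy of the weight contributes to one shared update.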

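Note that tf.batch_matmul no longer exists: since TensorFlow 1.0, tf.matmul accepts batched (rank 3 and higher) inputs directly, which is why the code above uses tf.matmul.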
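A further option, not part of the original answer, is tf.einsum, which contracts the last axis of every image with the first axis of the shared weight without materializing batch_size copies of it. A minimal sketch, assuming TensorFlow 2.x with eager execution and a random stand-in batch:

```python
import tensorflow as tf

w = tf.Variable(tf.random.truncated_normal([227, 227], mean=0.0, stddev=0.1))
x = tf.random.normal([2, 227, 227])  # stand-in for the [?, 227, 227] batch

# 'bij,jk->bik': for each batch entry b, multiply the [i, j] image by the [j, k] weight.
y = tf.einsum('bij,jk->bik', x, w)
print(y.shape)  # (2, 227, 227)
```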
