
Matmul with different rank

I have three tensors:
X, shape (1, c, h, w), e.g. (1, 20, 40, 50)
Fx, shape (num, w, N), e.g. (1000, 50, 10)
Fy, shape (num, N, h), e.g. (1000, 10, 40)

What I want to compute is Fy * (X * Fx), where * means matmul:
X * Fx has shape (num, c, h, N), e.g. (1000, 20, 40, 10)
Fy * (X * Fx) has shape (num, c, N, N), e.g. (1000, 20, 10, 10)

I am currently using tf.tile and tf.expand_dims to do this,
but I think it uses a lot of memory (tile copies the data, right?) and it is slow.
I am looking for a faster way that uses less memory.

# X: (1, c, h, w)
# Fx: (num, w, N)
# Fy: (num, N, h)

c = tf.shape(X)[1]  # number of channels, read before X is tiled
X = tf.tile(X, [tf.shape(Fx)[0], 1, 1, 1])  # (num, c, h, w)
Fx_ex = tf.expand_dims(Fx, axis=1)  # (num, 1, w, N)
Fx_ex = tf.tile(Fx_ex, [1, c, 1, 1])  # (num, c, w, N)
tmp = tf.matmul(X, Fx_ex)  # (num, c, h, N)

Fy_ex = tf.expand_dims(Fy, axis=1)  # (num, 1, N, h)
Fy_ex = tf.tile(Fy_ex, [1, c, 1, 1])  # (num, c, N, h)
res = tf.matmul(Fy_ex, tmp) # (num, c, N, N)
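
For reference, the same computation can be sketched in NumPy, where matmul broadcasts size-1 batch axes without copying them. This is only an illustration with small stand-in sizes and is not TensorFlow code; whether tf.matmul broadcasts batch dimensions the same way depends on the TensorFlow version, so treat that as an assumption to check.

```python
import numpy as np

rng = np.random.default_rng(0)
num, c, h, w, N = 6, 3, 4, 5, 2  # small stand-in sizes, not the real ones

X = rng.random((1, c, h, w))
Fx = rng.random((num, w, N))
Fy = rng.random((num, N, h))

# Tiled version, mirroring the TensorFlow code above (materializes copies).
X_t = np.tile(X, (num, 1, 1, 1))             # (num, c, h, w)
Fx_t = np.tile(Fx[:, None], (1, c, 1, 1))    # (num, c, w, N)
Fy_t = np.tile(Fy[:, None], (1, c, 1, 1))    # (num, c, N, h)
res_tiled = Fy_t @ (X_t @ Fx_t)              # (num, c, N, N)

# Broadcast version: only a size-1 axis is inserted, nothing is tiled.
tmp = X @ Fx[:, None]                        # (num, c, h, N)
res_bcast = Fy[:, None] @ tmp                # (num, c, N, N)

print(res_bcast.shape, np.allclose(res_tiled, res_bcast))
```

Both versions produce the same (num, c, N, N) result; the broadcast one skips the three tiled intermediates.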

A case for the mythical einsum, I guess:

>>> import numpy as np
>>> X = np.random.rand(1, 20, 40, 50)
>>> Fx = np.random.rand(100, 50, 10)
>>> Fy = np.random.rand(100, 10, 40)
>>> np.einsum('nMh,uchw,nwN->ncMN', Fy, X, Fx).shape
(100, 20, 10, 10)

It should work almost the same in tf as in numpy (uppercase indices aren't allowed in some tf versions, as far as I saw). Admittedly, this exceeds a regex in unreadability if you've never seen the notation before.
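
Since uppercase indices may be rejected, here is a minimal NumPy sketch of the same contraction written with lowercase letters only, checked against a plain per-filter loop. The index names m and k (standing in for M and N) and the small sizes are my own choices for illustration; passing the same subscript string to tf.einsum is the assumption, and it is not verified here.

```python
import numpy as np

rng = np.random.default_rng(0)
num, c, h, w, N = 6, 3, 4, 5, 2  # small stand-in sizes

X = rng.random((1, c, h, w))
Fx = rng.random((num, w, N))
Fy = rng.random((num, N, h))

# n: filter index, c: channel, m/k: the two output axes of size N.
# u (size 1), h and w do not appear in the output, so they are summed over.
res = np.einsum('nmh,uchw,nwk->ncmk', Fy, X, Fx)

# Reference: Fy[i] @ X[0, ch] @ Fx[i] for every filter i and channel ch.
ref = np.stack([Fy[i] @ X[0] @ Fx[i] for i in range(num)])

print(res.shape, np.allclose(res, ref))
```

Reading the subscripts left to right: each operand gets one group of letters, repeated letters are matched up, and any letter missing from the part after `->` is summed away.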

For anyone else who may be interested:
I think @phg's answer may work.
But in my case num, h and w are dynamic, i.e. None,
so tf.einsum in TensorFlow r1.0 raises an error, because a single tensor has more than one None dimension.

Fortunately, there is an issue and a pull request
that seem to handle the case of more than one None dimension.
It requires building from source (master branch).
I will report the result after I rebuild TensorFlow.

BTW, tf.einsum only accepts lowercase indices.

Report
Yes, the newest version of TensorFlow (master branch) accepts dynamic shapes in tf.einsum,
and using tf.einsum gave a huge speed improvement. Really awesome.


 