PyTorch: Row-wise Dot Product

Suppose I have two tensors:

import torch

a = torch.randn(10, 1000, 1, 4)
b = torch.randn(10, 1000, 6, 4)

Here the third dimension indexes the vectors, and the last dimension holds each vector's 4 components.

I want the dot product of each of the 6 vectors in b with the single vector in a at the same position in the first two dimensions.

To illustrate, this is what I mean:

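# i, j index the first two dimensions; v indexes the 6 vectors in b
# (the original loop reused the name b as a loop variable, shadowing the tensor)
dots = torch.empty(10, 1000, 6, 1)
for i in range(10):
    for j in range(1000):
        for v in range(6):
            dots[i, j, v] = torch.dot(b[i, j, v], a[i, j, 0])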

How can I do this with vectorized torch operations instead of Python loops?

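Multiply the tensors elementwise (a broadcasts along the third dimension, from size 1 to size 6) and sum over the last dimension:

import torch

a = torch.randn(10, 1000, 1, 4)
b = torch.randn(10, 1000, 6, 4)

# a * b has shape (10, 1000, 6, 4); summing over the last dim gives the dot products
c = torch.sum(a * b, dim=-1)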

print(c.shape)

torch.Size([10, 1000, 6])

To get the (10, 1000, 6, 1) shape used for dots in the question, add the trailing dimension back:

c = c.unsqueeze(-1)
print(c.shape)

torch.Size([10, 1000, 6, 1])
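For reference, here is a sketch of a few equivalent formulations, each checked against the question's loop on smaller tensors. The helper dots_loop and the reduced sizes (3, 5) are illustrative choices, not part of the original answer. Passing keepdim=True to torch.sum avoids the separate unsqueeze; b @ a.transpose(-1, -2) multiplies each (6, 4) matrix by a (4, 1) column and yields the trailing size-1 dimension directly; torch.einsum spells out the contraction index by index.

```python
import torch


def dots_loop(a, b):
    # Hypothetical reference helper: the explicit loop from the question
    n0, n1, n_vec, _ = b.shape
    out = torch.empty(n0, n1, n_vec, 1)
    for i in range(n0):
        for j in range(n1):
            for v in range(n_vec):
                out[i, j, v] = torch.dot(b[i, j, v], a[i, j, 0])
    return out


# Smaller sizes than in the question so the loop finishes quickly
a = torch.randn(3, 5, 1, 4)
b = torch.randn(3, 5, 6, 4)

# 1) Broadcasted multiply + sum, keeping the reduced dim (no unsqueeze needed)
c_sum = torch.sum(a * b, dim=-1, keepdim=True)

# 2) Batched matrix product: (..., 6, 4) @ (..., 4, 1) -> (..., 6, 1)
c_mm = b @ a.transpose(-1, -2)

# 3) einsum: drop a's size-1 vector dim, contract over the last index k
c_ein = torch.einsum('ijvk,ijk->ijv', b, a.squeeze(2)).unsqueeze(-1)

ref = dots_loop(a, b)
for c in (c_sum, c_mm, c_ein):
    # Small atol because summation order differs between the methods
    print(c.shape, torch.allclose(c, ref, atol=1e-5))
```

All three avoid Python-level loops; the loop version is kept only as the reference for the comparison.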
