簡體   English   中英

3-dim x 2-dim 上的張量點

[英]tensordot on 3-dim x 2-dim

我有tensor_A具有形狀(batch_size, x_1, x_2)tensor_B具有形狀(x_2, x_3) 我想將tensor_A每個元素與tensor_B點乘。 一個不使用 tensordot 的例子是這樣的:

product_tensor = np.zeros((batch_size, x_1, x_3))
for i in range(batch_size):
    product_tensor[i] = np.dot(tensor_A[i], tensor_B)

我無法弄清楚axes參數的參數應該是什么。 根據我的閱讀, axes=1表示點積,但我不知道它是將 A 的前 2 個軸與 B 相乘還是將 A 的后 2 個軸與 B 相乘。

我嘗試過tf.tensordot(tensor_A, tensor_B[None, :, :, :], axes=1)沒有成功,因為它似乎將tensor_A重塑為形狀(batch_size * x_1, x_2)tensor_B形狀(1, x_2 * x_3

幫助將不勝感激!

這應該會給你想要的結果:

import numpy  as np

a = np.array([
    [[1, 2, 3], [4, 5, 6]],
    [[1, 2, 3], [4, 5, 6]],
    [[1, 2, 3], [4, 5, 6]],
    [[1, 2, 3], [4, 5, 6]],
    [[1, 2, 3], [4, 5, 6]],
    [[1, 2, 3], [4, 5, 6]],
])
b = np.array([[1,2,3,4], [5,6,7,8], [9,10,11,12]])
print('a.shape = ', a.shape)
print('b.shape = ', b.shape)


# tensordot
c_tensordot = np.tensordot(a, b, axes=(1))

# loop method with dot
c_loop = np.empty([a.shape[0], a.shape[1], b.shape[1]])
for i in range(0,a.shape[0]):
    c_loop[i] = np.dot(a[i], b)


print('c_tensordot = ', c_tensordot)
print('c_loop      = ', c_loop)

print('c_tensordot.shape = ', c_tensordot.shape)
print('c_loop.shape      = ', c_loop.shape)

print('\nAre arrays equal: ', np.array_equal(c_tensordot, c_loop))

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM