[英]tensordot on 3-dim x 2-dim
我有tensor_A
具有形狀(batch_size, x_1, x_2)
和tensor_B
具有形狀(x_2, x_3)
我想將tensor_A
每個元素與tensor_B
點乘。 一個不使用 tensordot 的例子是這樣的:
product_tensor = np.zeros((batch_size, x_1, x_3))
for i in range(batch_size):
product_tensor[i] = np.dot(tensor_A[i], tensor_B)
我無法弄清楚axes
參數的參數應該是什么。 根據我的閱讀, axes=1
表示點積,但我不知道它是將 A 的前 2 個軸與 B 相乘還是將 A 的后 2 個軸與 B 相乘。
我嘗試過tf.tensordot(tensor_A, tensor_B[None, :, :, :], axes=1)
沒有成功,因為它似乎將tensor_A
重塑為形狀(batch_size * x_1, x_2)
和tensor_B
形狀(1, x_2 * x_3
。
幫助將不勝感激!
這應該會給你想要的結果:
import numpy as np
a = np.array([
[[1, 2, 3], [4, 5, 6]],
[[1, 2, 3], [4, 5, 6]],
[[1, 2, 3], [4, 5, 6]],
[[1, 2, 3], [4, 5, 6]],
[[1, 2, 3], [4, 5, 6]],
[[1, 2, 3], [4, 5, 6]],
])
b = np.array([[1,2,3,4], [5,6,7,8], [9,10,11,12]])
print('a.shape = ', a.shape)
print('b.shape = ', b.shape)
# tensordot
c_tensordot = np.tensordot(a, b, axes=(1))
# loop method with dot
c_loop = np.empty([a.shape[0], a.shape[1], b.shape[1]])
for i in range(0,a.shape[0]):
c_loop[i] = np.dot(a[i], b)
print('c_tensordot = ', c_tensordot)
print('c_loop = ', c_loop)
print('c_tensordot.shape = ', c_tensordot.shape)
print('c_loop.shape = ', c_loop.shape)
print('\nAre arrays equal: ', np.array_equal(c_tensordot, c_loop))
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.