簡體   English   中英

TensorFlow(或Numpy)中的高級廣播

[英]Advanced broadcasting in TensorFlow (or Numpy)

在TensorFlow我有個秩為2張量M形狀的(矩陣) [D, D]和秩-3張量T形狀的[D, D, D]

我需要將它們組合成一個新的矩陣R ,如下所示:元素R[a, b+ca]由所有元素T[a, b, c]*M[b, c]的總和給出,其中b+ca是常數(其中b+ca必須在0到D-1 )。

創建R一種無效方法是在索引上嵌套for循環,並檢查b+ca是否不超過D-1 (例如numpy):

R = np.zeros([D,D])

for a in range(D):
  for b in range(D):
    for c in range(D):
      if 0 <= b+c-a < D:
        R[a, b+c-a] += T[a, b, c]*M[b, c]

但我想使用廣播和/或其他更有效的方法。

我該如何實現?

您可以向量化該計算,如下所示:

import numpy as np

np.random.seed(0)
D = 10
M = np.random.rand(D, D)
T = np.random.rand(D, D, D)
# Original calculation
R = np.zeros([D, D])
for a in range(D):
    for b in range(D):
        for c in range(D):
            if 0 <= b + c - a < D:
                R[a, b + c - a] += T[a, b, c] * M[b, c]
# Vectorized calculation
tm = T * M
a = np.arange(D)[:, np.newaxis, np.newaxis]
b, c = np.ogrid[:D, :D]
col_idx = b + c - a
m = (col_idx >= 0) & (col_idx < D)
row_idx = np.tile(a, [1, D, D])
R2 = np.zeros([D, D])
np.add.at(R2, (row_idx[m], col_idx[m]), tm[m])
# Check result
print(np.allclose(R, R2))
# True

另外,您可以考慮使用Numba加速循環:

import numpy as np
import numba as nb

@nb.njit
def calculation_nb(T, M, D):
    tm = T * M
    R = np.zeros((D, D), dtype=tm.dtype)
    for a in nb.prange(D):
      for b in range(D):
        for c in range(max(a - b, 0), min(D + a - b, D)):
          R[a, b + c - a] += tm[a, b, c]
    return R

print(np.allclose(R, calculation_nb(T, M, D)))
# True

在一些快速測試中,即使沒有並行化,這也比NumPy快得多。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM