Efficient way to calculate 3D matrix multiplication using numpy

How can I efficiently write and calculate this multiplication using numpy:

# w, h: block strides (the shapes below imply w = h = B_SIZE)
for k in range(K):
    for i in range(SIZE):
        for j in range(SIZE):
            for i_b in range(B_SIZE):
                for j_b in range(B_SIZE):
                    for k_b in range(k + 1):
                        data[k, i * w + i_b, j * h + j_b] += arr1[k_b, i_b, j_b] * arr2[k_b, i, j]

For example:

SIZE, B_SIZE = 32, 8
arr1.shape -> (8, 8, 8)
arr2.shape -> (8, 32, 32)
data.shape -> (K, 256, 256)

Thank you.

You can use Numba for this kind of non-trivial case and reorder the loops to make efficient use of the CPU cache. Here is an example:

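import numba as nb

# K, SIZE, B_SIZE, w and h are read as module-level globals;
# Numba treats them as constants fixed at compile time.
@nb.njit
def compute(data, arr1, arr2):
    for k in range(K):
        for k_b in range(k + 1):
            for i in range(SIZE):
                for j in range(SIZE):
                    # Load the arr2 element once and reuse it for the whole block.
                    tmp = arr2[k_b, i, j]
                    for i_b in range(B_SIZE):
                        for j_b in range(B_SIZE):
                            data[k, i * w + i_b, j * h + j_b] += arr1[k_b, i_b, j_b] * tmp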

If you only run this operation once, you can pre-compile the Numba code by providing the types of the arrays. If K is large, you can parallelize the code by using @nb.njit(parallel=True) and writing for k in nb.prange(K) rather than for k in range(K). This should be several orders of magnitude faster than the pure-Python loops.
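
Since the question asks for NumPy specifically, here is a rough sketch of a pure-NumPy alternative. It is not part of the answer above and rests on assumptions: that data starts as zeros, that w and h equal the two block dimensions of arr1 (as the example shapes suggest), and that K is at most arr1.shape[0]. Under those assumptions each data[k] is the running sum over k_b <= k of per-slice block products, so an einsum followed by a cumulative sum should give the same result. The function name compute_numpy is made up for this illustration.

```python
import numpy as np

def compute_numpy(arr1, arr2, K):
    """Sketch of a vectorized equivalent of the nested loops.

    Assumes data starts as zeros, w == arr1.shape[1], h == arr1.shape[2],
    and K <= arr1.shape[0].
    """
    _, b_rows, b_cols = arr1.shape
    _, s_rows, s_cols = arr2.shape
    # blocks[k, i, a, j, b] = arr2[k, i, j] * arr1[k, a, b]
    blocks = np.einsum('kij,kab->kiajb', arr2[:K], arr1[:K])
    # Merge (i, a) into row i * b_rows + a and (j, b) into column j * b_cols + b.
    blocks = blocks.reshape(K, s_rows * b_rows, s_cols * b_cols)
    # data[k] accumulates blocks[0] through blocks[k].
    return np.cumsum(blocks, axis=0)
```

With the example shapes, data = compute_numpy(arr1, arr2, 8) would have shape (8, 256, 256). The intermediate array is the same size as data, so memory use should stay modest for sizes like these.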
