繁体   English   中英

使用 numpy 和 fft 在 for 循环中进行多核处理

[英]multi core processing in for loop using numpy and fft

我使用 numpy 和 fft 计算了向量。 我使用了 numpy 广播方法和 for 循环。 两种方法的速度相似。 如何使用多核和 numpy 和 fft 计算向量?

import numpy as np
from numpy.fft import fft, ifft

num_row, num_col = 6000, 13572

ss = np.ones((num_row, num_col), dtype=np.complex128)
sig = np.random.standard_normal(num_col) * 1j * np.random.standard_normal(num_col)

# for loop    
for idx in range(num_row):
    ss[idx, :] = ifft(fft(ss[idx, :]) * sig)

# broadcast
ss = ifft(fft(ss, axis=1) * sig, axis=1)

结果

loop : 10.798867464065552 sec
broadcast : 11.298897981643677 sec

您可以将axis参数用于fftifft ,以及广播:

ss = np.ones((num_row, num_col), dtype=np.complex128)
sig = np.random.standard_normal(num_col) * 1j * np.random.standard_normal(num_col)

ss = ifft(fft(ss, axis=1) * sig, axis=1)

我比较了广播、线程池和循环。 在这种情况下,ThreadPool 的性能最好。

# %% Import
# Standard library imports
import time
from multiprocessing.pool import ThreadPool

# Third party imports
from numpy import zeros, complex128, allclose
from numpy.fft import fft, ifft
from numpy.random import standard_normal


# %% Generate data
n_row, n_col = 6000, 13572

ss = standard_normal((n_row, n_col)) + 1j * standard_normal((n_row, n_col))
sig = standard_normal(n_col) + 1j * standard_normal(n_col)
ss_loop = zeros((n_row, n_col), dtype=complex128)
ss_thread = zeros((n_row, n_col), dtype=complex128)

# %% Loop processing
start_time = time.time()
for idx in range(n_row):
    ss_loop[idx, :] = ifft(fft(ss[idx, :]) * sig)
print(f'loop elapsed time : {time.time() - start_time}')

# %% Broadcast processing
start_time = time.time()
ss_broad = ifft(fft(ss, axis=1) * sig, axis=1)
print(f'broadcast elapsed time : {time.time() - start_time}')


# %% ThreadPool processing
def filtering(idx_thread):
    ss_thread[idx_thread, :] = ifft(fft(ss[idx_thread, :]) * sig)


start_time = time.time()
pool = ThreadPool()
pool.map(filtering, range(n_row))
print(f'ThreadPool elapsed time : {time.time() - start_time}')


# %% Verify result
if allclose(ss_thread, ss_broad, rtol=1.e-8):
    print('ThreadPool Correct')

if allclose(ss_loop, ss_broad, rtol=1.e-8):
    print('Loop Correct')

结果

loop elapsed time : 5.102990627288818
broadcast elapsed time : 4.520442008972168
ThreadPool elapsed time : 1.6988463401794434

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM