How and why is FFT convolution faster than direct convolution?
I read that convolutions are faster when computed in the frequency domain because it's "just" a matrix multiplication (in 2D), while in the time domain it's a lot of small matrix multiplications.
So I wrote the code below, and it shows FFT convolution being more complex than "normal" convolution. Clearly something is wrong in my assumptions.
from sympy import exp, log, symbols, init_printing, lambdify
init_printing(use_latex='matplotlib')
import numpy as np
import matplotlib.pyplot as plt

def _complex_mult(n):
    """Complexity of a MatMul of two matrices of size (n, n)"""
    # see https://en.wikipedia.org/wiki/Matrix_multiplication_algorithm
    return n**2.5

def _complex_fft(n):
    """Complexity of fft and ifft"""
    # see https://en.wikipedia.org/wiki/Fast_Fourier_transform
    return n*log(n)

def fft_mult_fft(n, m):
    """Complexity of a convolution in the freq space.
    fft -> mult between M and kernel -> ifft
    """
    return _complex_fft(n) * 2 + _complex_mult(n)

def conv(n, m):
    """Complexity of a convolution in the time space.
    for every n of M, we execute a MatMul of 2 (m, m) matrices
    """
    return n*_complex_mult(m)

n = symbols('n') # size of M = (n, n)
m = symbols('m') # size of kernel = (m, m)
M = np.linspace(1, 1e3+1, 10) # the number of samples must be an int, not 1e1
kernel_size = np.linspace(2, 7, 7-2+1)**2
fft = fft_mult_fft(n, m)
discrete = conv(n, m)
f1 = lambdify(n, fft, 'numpy')
f2 = lambdify([n, m], discrete, 'numpy')
fig, ax = plt.subplots(1, len(kernel_size), figsize=(30, 10))
f1_computed = f1(M) # independent of m, so compute it only once
for i, size in enumerate(kernel_size):
    ax[i].plot(M, f1_computed, c='red', label='freq domain (fft)')
    ax[i].plot(M, f2(M, size), c='blue', label='time domain (normal)')
    ax[i].legend(loc='upper left')
    ax[i].set_title("kernel size = {}".format(size))
    ax[i].set_xlabel("Matrix size")
    ax[i].set_ylabel("Complexity")
And here is the output: [plot not preserved: one panel per kernel size, complexity against matrix size, FFT in red and time domain in blue]
You are experiencing two well-known facts:
for small kernel sizes, the spatial approach is faster;
for large kernel sizes, the frequency approach can be faster.
Your kernels and images are too small to show the benefits of the FFT.
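To make the two approaches concrete, here is a minimal NumPy sketch of both, assuming a "full" (zero-padded) 2D convolution. The function names direct_conv2d and fft_conv2d are made up for this illustration. Note that the frequency-domain step is an element-by-element product of the two spectra rather than a matrix multiplication, and that each output cell of the direct method costs about m*m multiply-adds, so the usual textbook counts are roughly n²·m² for the direct method against a multiple of N²·log N for the FFT method, with N = n + m - 1.

```python
import numpy as np

def direct_conv2d(image, kernel):
    """Full 2D convolution by explicit summation (about n*n*m*m multiply-adds)."""
    n_rows, n_cols = image.shape
    m_rows, m_cols = kernel.shape
    out = np.zeros((n_rows + m_rows - 1, n_cols + m_cols - 1))
    # Each kernel tap contributes a scaled, shifted copy of the image.
    for i in range(m_rows):
        for j in range(m_cols):
            out[i:i + n_rows, j:j + n_cols] += kernel[i, j] * image
    return out

def fft_conv2d(image, kernel):
    """Full 2D convolution via the convolution theorem."""
    shape = (image.shape[0] + kernel.shape[0] - 1,
             image.shape[1] + kernel.shape[1] - 1)
    # Zero-pad both inputs to the output size, multiply the spectra
    # element by element, then transform back to the spatial domain.
    spectrum = np.fft.rfft2(image, s=shape) * np.fft.rfft2(kernel, s=shape)
    return np.fft.irfft2(spectrum, s=shape)

rng = np.random.default_rng(0)
image = rng.standard_normal((64, 64))
kernel = rng.standard_normal((9, 9))

# Both routes should agree up to floating-point rounding.
print(np.allclose(direct_conv2d(image, kernel), fft_conv2d(image, kernel)))
```

Timing the two functions while growing the kernel is one way to look for the crossover described above; where it falls depends on the machine and the FFT implementation.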
As @user545424 pointed out, the problem was that I was computing n*complexity(MatMul(kernel)) instead of n²*complexity(MatMul(kernel)) for a "normal" convolution.
I finally get this (where n is the size of the input and m the size of the kernel):
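The formula itself did not survive in this copy of the post. Reading it back from the functions fft_mult_fft and conv in the final code, the model appears to be the following; the labels C_fft and C_direct are introduced here for convenience and are not the author's notation.

```latex
\begin{align}
C_{\text{fft}}(n)       &= 2 \cdot 4\, n^{2} \log n + n^{2.5} \\
C_{\text{direct}}(n, m) &= n^{2}\, m^{2.5}
\end{align}
```

Under this model the FFT route does not depend on m at all, which is why it wins as soon as the kernel is large enough.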
Here is the final code and the new charts.
from sympy import exp, log, symbols, init_printing, lambdify
init_printing(use_latex='matplotlib')
import numpy as np
import matplotlib.pyplot as plt

def _complex_mult(n):
    """Complexity of a MatMul of two matrices of size (n, n)"""
    # see https://en.wikipedia.org/wiki/Matrix_multiplication_algorithm
    return n**2.5

def _complex_fft(n):
    """Complexity of fft and ifft"""
    # see https://stackoverflow.com/questions/6514861/computational-complexity-of-the-fft-in-n-dimensions#comment37078975_6516856
    return 4*(n**2)*log(n)

def fft_mult_fft(n, m):
    """Complexity of a convolution in the freq space.
    fft -> mult between M and kernel -> ifft
    """
    return _complex_fft(n) * 2 + _complex_mult(n)

def conv(n, m):
    """Complexity of a convolution in the time space.
    for every n*n cell of M, we execute a MatMul of 2 (m, m) matrices
    """
    return n*n*_complex_mult(m)

n = symbols('n') # size of M = (n, n)
m = symbols('m') # size of kernel = (m, m)
M = np.linspace(1, 1e3+1, 10) # the number of samples must be an int, not 1e1
kernel_size = np.linspace(2, 7, 7-2+1)**2
fft_symb = fft_mult_fft(n, m)
discrete_symb = conv(n, m)
fft_func = lambdify(n, fft_symb, 'numpy')
discrete_func = lambdify([n, m], discrete_symb, 'numpy')
fig, ax = plt.subplots(1, len(kernel_size), figsize=(30, 10))
fig.patch.set_facecolor('grey')
for i, size in enumerate(kernel_size):
    ax[i].plot(M, fft_func(M), c='red', label='freq domain (fft)')
    ax[i].plot(M, discrete_func(M, size), c='blue', label='time domain (normal)')
    ax[i].legend(loc='upper left')
    ax[i].set_title("kernel size = {}".format(size))
    ax[i].set_xlabel("Matrix size")
    ax[i].set_ylabel("Complexity")
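Since the charts are not preserved here, a quick numeric check of the same model may help. The sketch below re-implements the two cost expressions from fft_mult_fft and conv with plain floats and searches for the smallest kernel size at which the model says the FFT route becomes cheaper. The helper crossover_kernel_size and its max_m cap are hypothetical, added only for this illustration, and the result describes the model rather than measured run time.

```python
import math

def fft_cost(n):
    """Modelled cost of the frequency-domain route (same formula as fft_mult_fft)."""
    return 2 * 4 * n**2 * math.log(n) + n**2.5

def direct_cost(n, m):
    """Modelled cost of the time-domain route (same formula as conv)."""
    return n * n * m**2.5

def crossover_kernel_size(n, max_m=1000):
    """Smallest integer m at which the model rates the FFT route as cheaper.

    Returns None if no such m exists up to max_m (an arbitrary search cap).
    """
    for m in range(1, max_m + 1):
        if direct_cost(n, m) > fft_cost(n):
            return m
    return None

for n in (10, 100, 1000):
    print(n, crossover_kernel_size(n))
```

Because the model's FFT cost grows only slowly relative to n², the crossover kernel size it predicts stays small even for large inputs.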