How and why is FFT convolution faster than direct convolution?

I read that convolutions are faster when computed in the frequency domain because there it's "just" a matrix multiplication (in 2D), while in the time domain it's a lot of small matrix multiplications.
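What makes the frequency-domain route work is the convolution theorem: once both inputs are zero-padded to the full output size, the 2D FFT of their convolution equals the element-wise (Hadamard) product of their FFTs. So the middle step is an O(n²) element-wise product, not a matrix multiplication. A minimal NumPy sketch, assuming real-valued 2D inputs; `direct_conv2d` and `fft_conv2d` are hypothetical helper names, not from any library:

```python
import numpy as np

def direct_conv2d(image, kernel):
    """Full 2-D linear convolution computed directly: n^2 * m^2 multiply-adds."""
    n1, n2 = image.shape
    m1, m2 = kernel.shape
    out = np.zeros((n1 + m1 - 1, n2 + m2 - 1))
    # Each input pixel adds a scaled copy of the kernel to the output
    for i in range(n1):
        for j in range(n2):
            out[i:i + m1, j:j + m2] += image[i, j] * kernel
    return out

def fft_conv2d(image, kernel):
    """Same result via the convolution theorem: FFT -> element-wise product -> inverse FFT."""
    shape = (image.shape[0] + kernel.shape[0] - 1,
             image.shape[1] + kernel.shape[1] - 1)
    # Zero-pad both inputs to the full output size so that the (circular)
    # FFT convolution equals the linear one
    image_hat = np.fft.rfft2(image, s=shape)
    kernel_hat = np.fft.rfft2(kernel, s=shape)
    # The frequency-domain step is an element-wise product, not a matrix product
    return np.fft.irfft2(image_hat * kernel_hat, s=shape)

rng = np.random.default_rng(0)
image = rng.random((32, 32))
kernel = rng.random((5, 5))
print(np.allclose(direct_conv2d(image, kernel), fft_conv2d(image, kernel)))  # True
```

Both functions return the same (n + m - 1) × (n + m - 1) array up to floating-point rounding; only the amount of work differs.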

So I wrote this code, and it shows FFT convolution as more complex than "normal" convolution. Clearly something is wrong in my assumptions.

What is wrong?

from sympy import log, symbols, init_printing, lambdify
init_printing(use_latex='matplotlib')

import numpy as np
import matplotlib.pyplot as plt

def _complex_mult(n):
  """Complexity of a MatMul of two matrices of size (n, n)"""
  # see https://en.wikipedia.org/wiki/Matrix_multiplication_algorithm
  return n**2.5

def _complex_fft(n):
  """Complexity of fft and ifft"""
  # see https://en.wikipedia.org/wiki/Fast_Fourier_transform
  return n*log(n)

def fft_mult_fft(n, m):
  """Complexity of a convolution in the freq space.
  fft -> mult between M and kernel -> ifft
  """
  return _complex_fft(n) * 2 + _complex_mult(n)

def conv(n, m):
  """Complexity of a convolution in the time space.
  for every n of M, we execute a MatMul of 2 (m, m) matrices
  """
  return n*_complex_mult(m)


n = symbols('n') # size of M = (n, n)
m = symbols('m') # size of kernel = (m, m)

M = np.linspace(1, 1e3 + 1, 10)  # num must be an int; a float raises TypeError in recent NumPy
kernel_size = np.linspace(2, 7, 7-2+1)**2

fft = fft_mult_fft(n, m)
discrete = conv(n, m)

f1 = lambdify(n, fft, 'numpy')
f2 = lambdify([n, m], discrete, 'numpy')

fig, ax = plt.subplots(1, len(kernel_size), figsize=(30, 10))

f1_computed = f1(M)  # independent of m, so compute it only once

for i, size in enumerate(kernel_size):
  ax[i].plot(M, f1_computed, c='red', label='freq domain (fft)')
  ax[i].plot(M, f2(M, size), c='blue', label='time domain (normal)')
  ax[i].legend(loc='upper left')
  ax[i].set_title("kernel size = {}".format(size))
  ax[i].set_xlabel("Matrix size")
  ax[i].set_ylabel("Complexity")

plt.show()

And here is the output:

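[Chart: one panel per kernel size, each plotting "Complexity" against "Matrix size" for the freq domain (fft, red) and the time domain (normal, blue)]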

You are experiencing two well-known facts:

  • for small kernel sizes, the spatial (direct) approach is faster;

  • for large kernel sizes, the frequency (FFT) approach can be faster.

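Your kernels and images are too small to observe the benefits of the FFT: its fixed overhead (transforming the whole padded image and kernel) only pays off once the m² multiply-adds that the direct method spends on every output pixel become large.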
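To see where the crossover lies, here is a rough operation-count sketch. The FFT cost model (three 2-D FFTs of the zero-padded N × N arrays at about N²·log2(N²) operations each, plus N² element-wise products) is an assumption with illustrative constants; real FFT libraries have different constants. `direct_cost`, `fft_cost` and `crossover_kernel_size` are hypothetical names.

```python
import math

def direct_cost(n, m):
    """Multiply-adds for direct convolution of an n x n image with an m x m kernel."""
    # ~n^2 output pixels, m^2 multiply-adds each (border effects ignored)
    return n**2 * m**2

def fft_cost(n, m):
    """Rough cost of FFT convolution under an assumed, illustrative model.

    Both inputs are zero-padded to N x N with N = n + m - 1; we count three
    2-D FFTs (image, kernel, inverse) at ~N^2 * log2(N^2) each, plus N^2
    element-wise products.
    """
    N = n + m - 1
    return 3 * N**2 * math.log2(N**2) + N**2

def crossover_kernel_size(n):
    """Smallest kernel size m for which the FFT model is cheaper, or None."""
    for m in range(1, n + 1):
        if fft_cost(n, m) < direct_cost(n, m):
            return m
    return None

for m in (3, 31):
    print(m, direct_cost(512, m), round(fft_cost(512, m)))
print(crossover_kernel_size(512))
```

The direct cost grows with m², while the FFT cost barely depends on m, which is why small kernels favor the spatial approach and large ones the FFT. In practice, SciPy's `scipy.signal.convolve` with `method='auto'` makes this choice using its own heuristic estimate.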

As @user545424 pointed out, the problem was that for a "normal" convolution I was computing n*complexity(MatMul(kernel)) instead of n²*complexity(MatMul(kernel)): an (n, n) input has n² output positions, not n.
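The n² factor is simply the number of output positions. A minimal sketch of a direct "valid"-mode convolution that counts its multiply-adds, assuming square inputs as in the question; `direct_conv2d_valid` is a hypothetical name. Each output pixel costs an element-wise product plus a sum over an m × m patch, i.e. m² multiply-adds, which is fewer than the m^2.5 MatMul estimate used in the code.

```python
import numpy as np

def direct_conv2d_valid(image, kernel):
    """Direct 2-D convolution ('valid' region only) that also counts multiply-adds."""
    n, m = image.shape[0], kernel.shape[0]   # assumes square inputs, as in the question
    flipped = kernel[::-1, ::-1]             # convolution (unlike correlation) flips the kernel
    out_size = n - m + 1
    out = np.empty((out_size, out_size))
    ops = 0
    for i in range(out_size):                # one iteration per output row ...
        for j in range(out_size):            # ... and per output column: ~n^2 positions
            patch = image[i:i + m, j:j + m]
            out[i, j] = np.sum(patch * flipped)  # m^2 multiply-adds per output pixel
            ops += m * m
    return out, ops

image = np.arange(100, dtype=float).reshape(10, 10)
kernel = np.ones((3, 3))
out, ops = direct_conv2d_valid(image, kernel)
print(ops)  # (10 - 3 + 1)**2 * 3**2 = 576
```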

I finally got this (where n is the size of the input and m the size of the kernel):

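$$C_{\text{fft}}(n) = 2 \cdot 4 n^2 \log n + n^{2.5} \qquad C_{\text{direct}}(n, m) = n^2 \, m^{2.5}$$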


Here is the final code and the new charts.

from sympy import log, symbols, init_printing, lambdify
init_printing(use_latex='matplotlib')

import numpy as np
import matplotlib.pyplot as plt

def _complex_mult(n):
  """Complexity of a MatMul of a 2 matrices of size (n, n)"""
  # see https://en.wikipedia.org/wiki/Matrix_multiplication_algorithm
  return n**2.5

def _complex_fft(n):
  """Complexity of fft and ifft"""
  # see https://stackoverflow.com/questions/6514861/computational-complexity-of-the-fft-in-n-dimensions#comment37078975_6516856
  return 4*(n**2)*log(n)

def fft_mult_fft(n, m):
  """Complexity of a convolution in the freq space.
  fft -> mult between M and kernel -> ifft
  """
  return _complex_fft(n) * 2 + _complex_mult(n)

def conv(n, m):
  """Complexity of a convolution in the time space.
  for every n*n cell of M, we execute a MatMul of 2 (m, m) matrices
  """
  return n*n*_complex_mult(m)


n = symbols('n') # size of M = (n, n)
m = symbols('m') # size of kernel = (m, m)

M = np.linspace(1, 1e3 + 1, 10)  # num must be an int; a float raises TypeError in recent NumPy
kernel_size = np.linspace(2, 7, 7-2+1)**2

fft_symb = fft_mult_fft(n, m)
discrete_symb = conv(n, m)

fft_func = lambdify(n, fft_symb, 'numpy')
discrete_func = lambdify([n, m], discrete_symb, 'numpy')


fig, ax = plt.subplots(1, len(kernel_size), figsize=(30, 10))
fig.patch.set_facecolor('grey')

for i, size in enumerate(kernel_size):
  ax[i].plot(M, fft_func(M), c='red', label='freq domain (fft)')
  ax[i].plot(M, discrete_func(M, size), c='blue', label='time domain (normal)')
  ax[i].legend(loc='upper left')
  ax[i].set_title("kernel size = {}".format(size))
  ax[i].set_xlabel("Matrix size")
  ax[i].set_ylabel("Complexity")

plt.show()
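Dividing both cost functions of the final code by n² gives the kernel size at which the curves cross: m^2.5 = 8·ln(n) + √n. A small sketch that evaluates this for the code's own model (natural log, like sympy's `log`); `fft_model`, `direct_model` and `model_crossover` are hypothetical names:

```python
import math

def fft_model(n):
    """FFT-domain cost from the final code: 2 * 4 n^2 ln(n) + n^2.5."""
    return 2 * 4 * n**2 * math.log(n) + n**2.5

def direct_model(n, m):
    """Direct cost from the final code: n^2 * m^2.5."""
    return n**2 * m**2.5

def model_crossover(n):
    """Kernel size m at which both models cost the same.

    n^2 m^2.5 = 8 n^2 ln(n) + n^2.5  <=>  m^2.5 = 8 ln(n) + sqrt(n)
    """
    return (8 * math.log(n) + math.sqrt(n)) ** 0.4

print(model_crossover(1001))
```

For the largest plotted matrix size (n ≈ 1001) this gives m ≈ 6, so under this model kernels wider than about 6 already favor the FFT.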

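[Chart: new results, one panel per kernel size, each plotting "Complexity" against "Matrix size" for the freq domain (fft, red) and the time domain (normal, blue)]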
