Python 3D interpolation speedup

I have the following code, which I use to interpolate 3D volume data.

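import numpy as np
from scipy import interpolate

# volume is the 3D input array; result holds (x, y, z) sample coordinates in its last axis
Y, X, Z = np.shape(volume)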
xs = np.arange(0, X)
ys = np.arange(0, Y)
zs = np.arange(0, Z)

points = list(zip(np.ravel(result[:, :, :, 1]), np.ravel(result[:, :, :, 0]), np.ravel(result[:, :, :, 2])))
interp = interpolate.RegularGridInterpolator((ys, xs, zs), volume,
                                             bounds_error=False, fill_value=0, method='linear')
new_volume = interp(points)
new_volume = np.reshape(new_volume, (Y, X, Z))

This code takes about 37 seconds to execute on a 512x512x110 volume (about 29 million points), which is more than one microsecond per voxel. That is an unacceptable amount of time for me, especially since it already uses 4 cores. The call new_volume = interp(points) takes about 80% of the procedure's time, and the list creation takes almost all of the rest.

Is there any simple (or even more complex) way to make this computation faster? Or is there a good Python library that provides faster interpolation? My volume and points change on every call to this procedure.
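The list-creation step mentioned in the question can probably be avoided, because RegularGridInterpolator accepts an array of shape (N, 3) as well as a list of tuples. A minimal sketch, assuming result stores (x, y, z) in its last axis as the zip call above suggests; build_points is a hypothetical helper name and the small random result is only stand-in data:

```python
import numpy as np

def build_points(result):
    """Return an (N, 3) array of (y, x, z) coordinates without building a Python list."""
    # Reorder the last axis to (1, 0, 2), matching the zip(...) call in the question,
    # then flatten the three leading axes in C order (same order as np.ravel).
    return result[..., [1, 0, 2]].reshape(-1, 3)

# Stand-in data: a tiny 4x5x6 coordinate field (the real one is 512x512x110).
rng = np.random.default_rng(0)
result = rng.uniform(0, 4, size=(4, 5, 6, 3))

points = build_points(result)

# The original, slower construction, kept here only for comparison.
old_points = list(zip(np.ravel(result[:, :, :, 1]),
                      np.ravel(result[:, :, :, 0]),
                      np.ravel(result[:, :, :, 2])))
```

The array can then be passed directly as interp(points). This only targets the list-creation share of the runtime, not the interpolation call itself.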

Here is a slightly modified version of your Cython solution:

import numpy as np
cimport numpy as np
from libc.math cimport floor
from cython cimport boundscheck, wraparound, nonecheck, cdivision

# np.float was removed in NumPy 1.24; float64 is the same double-precision type
DTYPE = np.float64
ctypedef np.float64_t DTYPE_t

@boundscheck(False)
@wraparound(False)
@nonecheck(False)
def interp3D(DTYPE_t[:,:,::1] v, DTYPE_t[:,:,::1] xs, DTYPE_t[:,:,::1] ys, DTYPE_t[:,:,::1] zs):

    cdef int X, Y, Z
    X,Y,Z = v.shape[0], v.shape[1], v.shape[2]
    cdef np.ndarray[DTYPE_t, ndim=3] interpolated = np.zeros((X, Y, Z), dtype=DTYPE)

    _interp3D(&v[0,0,0], &xs[0,0,0], &ys[0,0,0], &zs[0,0,0], &interpolated[0,0,0], X, Y, Z)
    return interpolated

@cdivision(True)
cdef inline void _interp3D(DTYPE_t *v, DTYPE_t *x_points, DTYPE_t *y_points, DTYPE_t *z_points,
               DTYPE_t *result, int X, int Y, int Z):

    cdef:
        int i, x0, x1, y0, y1, z0, z1, dim
        DTYPE_t x, y, z, xd, yd, zd, c00, c01, c10, c11, c0, c1, c

    dim = X*Y*Z

    for i in range(dim):
        x = x_points[i]
        y = y_points[i]
        z = z_points[i]

        x0 = <int>floor(x)
        x1 = x0 + 1
        y0 = <int>floor(y)
        y1 = y0 + 1
        z0 = <int>floor(z)
        z1 = z0 + 1

        xd = (x-x0)/(x1-x0)
        yd = (y-y0)/(y1-y0)
        zd = (z-z0)/(z1-z0)

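        # All eight corners must lie inside the volume; raw pointers are not bounds-checked
        if x0 >= 0 and y0 >= 0 and z0 >= 0 and x1 < X and y1 < Y and z1 < Z: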
            c00 = v[Y*Z*x0+Z*y0+z0]*(1-xd) + v[Y*Z*x1+Z*y0+z0]*xd
            c01 = v[Y*Z*x0+Z*y0+z1]*(1-xd) + v[Y*Z*x1+Z*y0+z1]*xd
            c10 = v[Y*Z*x0+Z*y1+z0]*(1-xd) + v[Y*Z*x1+Z*y1+z0]*xd
            c11 = v[Y*Z*x0+Z*y1+z1]*(1-xd) + v[Y*Z*x1+Z*y1+z1]*xd

            c0 = c00*(1-yd) + c10*yd
            c1 = c01*(1-yd) + c11*yd

            c = c0*(1-zd) + c1*zd

        else:
            # Points outside the volume get the fill value 0
            c = 0

        result[i] = c 

The results are still identical to yours. With random grid data of size 60x60x60 I obtain the following timings:

SciPy's solution:                982ms
Your cython solution:            24.7ms
Above modified cython solution:  8.17ms

So it is about 3 times faster than your Cython solution (24.7 ms / 8.17 ms ≈ 3.0). Note that:

  1. Cython performs bounds checking on arrays by default. The code above switches it off with @boundscheck(False); remove that decorator if you want bounds checking back.
  2. The arrays passed to the above solution must be C-contiguous.
  3. If you want a parallel variant of the above code, replace range with prange (from cython.parallel) in the for loop.
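The C-contiguity requirement in note 2 can be met on the Python side before calling the compiled function. A minimal sketch, assuming the inputs may arrive as views or with a different dtype; as_c_double is a hypothetical helper name, not part of the answer's code:

```python
import numpy as np

def as_c_double(a):
    """Return a as a C-contiguous float64 array, copying only when necessary."""
    return np.ascontiguousarray(a, dtype=np.float64)

# A transposed float32 view is neither C-contiguous nor float64, so it would not
# match a DTYPE_t[:, :, ::1] argument of the Cython function.
v = np.arange(24, dtype=np.float32).reshape(2, 3, 4).transpose(2, 1, 0)
v_ready = as_c_double(v)

# Intended use (interp3D is the compiled function from the answer):
# new_volume = interp3D(as_c_double(volume), as_c_double(xs),
#                       as_c_double(ys), as_c_double(zs))
```

If an input is already a C-contiguous float64 array, np.ascontiguousarray returns it without copying, so the helper adds little overhead in the common case.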

Hope this helps.

I used Cython to accelerate this and implemented the following code:

import numpy as np
cimport numpy as np
from libc.math cimport ceil, floor

# np.float was removed in NumPy 1.24; float64 is the same double-precision type
DTYPE = np.float64
ctypedef np.float64_t DTYPE_t

def interp3(np.ndarray[DTYPE_t, ndim=3] x_grid, np.ndarray[DTYPE_t, ndim=3] y_grid,
    np.ndarray[DTYPE_t, ndim=3] z_grid, np.ndarray[DTYPE_t, ndim=3] v,
    np.ndarray[DTYPE_t, ndim=3] xs, np.ndarray[DTYPE_t, ndim=3] ys, 
    np.ndarray[DTYPE_t, ndim=3] zs):

    cdef int i
    cdef float x
    cdef float y
    cdef float z
    cdef int x0
    cdef int x1
    cdef int y0
    cdef int y1
    cdef int z0
    cdef int z1
    cdef float xd
    cdef float yd
    cdef float zd
    cdef float c00
    cdef float c01
    cdef float c10
    cdef float c11
    cdef float c0
    cdef float c1
    cdef float c
    cdef int X
    cdef int Y
    cdef int Z

    X, Y, Z = np.shape(x_grid)

    cdef np.ndarray[DTYPE_t, ndim=1] x_points = np.ravel(xs)
    cdef np.ndarray[DTYPE_t, ndim=1] y_points = np.ravel(ys)
    cdef np.ndarray[DTYPE_t, ndim=1] z_points = np.ravel(zs)
    cdef np.ndarray[DTYPE_t, ndim=1] result = np.empty((len(x_points)), dtype=DTYPE)

    for i in range(len(x_points)):
        x = x_points[i]
        y = y_points[i]
        z = z_points[i]

        x0 = int(floor(x))
        x1 = x0 + 1
        y0 = int(floor(y))
        y1 = y0 + 1
        z0 = int(floor(z))
        z1 = z0 + 1

        xd = (x-x0)/(x1-x0)
        yd = (y-y0)/(y1-y0)
        zd = (z-z0)/(z1-z0)

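        try:
            # Negative indices would wrap around silently, so reject them explicitly
            assert x0 >= 0 and y0 >= 0 and z0 >= 0
            c00 = v[x0, y0, z0]*(1-xd) + v[x1, y0, z0]*xd
            c01 = v[x0, y0, z1]*(1-xd) + v[x1, y0, z1]*xd
            c10 = v[x0, y1, z0]*(1-xd) + v[x1, y1, z0]*xd
            c11 = v[x0, y1, z1]*(1-xd) + v[x1, y1, z1]*xd

            c0 = c00*(1-yd) + c10*yd
            c1 = c01*(1-yd) + c11*yd

            c = c0*(1-zd) + c1*zd
        except (AssertionError, IndexError):
            # Points outside the volume get the fill value 0
            c = 0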

        result[i] = c

    cdef np.ndarray[DTYPE_t, ndim=3] interpolated = np.zeros((X, Y, Z), dtype=DTYPE)
    interpolated = np.reshape(result, (X, Y, Z))
    return interpolated  

This is my first time using Cython, so I have the following questions:

  1. How can I optimize this further?

  2. Is there an easy way to avoid the try and assert statements used to check array bounds? Ensuring the bounds with min/max combinations is slower than this try/assert approach.

Currently, it is around 8x faster than the original code posted above.
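For checking results against the Cython versions, the same trilinear formula can be written in plain NumPy, with a validity mask in place of try/assert. This is only a reference sketch, not a claim about speed; trilinear_numpy is a hypothetical name, and the code assumes every dimension of v has at least 2 samples and that the coordinates contain no NaNs:

```python
import numpy as np

def trilinear_numpy(v, xs, ys, zs):
    """Trilinear interpolation of v at (xs, ys, zs); 0 where a corner is outside v."""
    X, Y, Z = v.shape
    x, y, z = xs.ravel(), ys.ravel(), zs.ravel()

    x0 = np.floor(x).astype(np.intp)
    y0 = np.floor(y).astype(np.intp)
    z0 = np.floor(z).astype(np.intp)

    # A point is valid only if all eight surrounding corners lie inside the volume.
    valid = ((x0 >= 0) & (x0 + 1 < X) &
             (y0 >= 0) & (y0 + 1 < Y) &
             (z0 >= 0) & (z0 + 1 < Z))

    # Clip the indices so that indexing is always safe; invalid points are zeroed below.
    x0c = np.clip(x0, 0, X - 2)
    y0c = np.clip(y0, 0, Y - 2)
    z0c = np.clip(z0, 0, Z - 2)
    x1c, y1c, z1c = x0c + 1, y0c + 1, z0c + 1

    # Fractional offsets inside the cell (the grid spacing is 1).
    xd, yd, zd = x - x0, y - y0, z - z0

    c00 = v[x0c, y0c, z0c] * (1 - xd) + v[x1c, y0c, z0c] * xd
    c01 = v[x0c, y0c, z1c] * (1 - xd) + v[x1c, y0c, z1c] * xd
    c10 = v[x0c, y1c, z0c] * (1 - xd) + v[x1c, y1c, z0c] * xd
    c11 = v[x0c, y1c, z1c] * (1 - xd) + v[x1c, y1c, z1c] * xd

    c0 = c00 * (1 - yd) + c10 * yd
    c1 = c01 * (1 - yd) + c11 * yd
    c = c0 * (1 - zd) + c1 * zd

    return np.where(valid, c, 0.0).reshape(xs.shape)

# Small check on a linear field, which trilinear interpolation reproduces exactly.
v = np.fromfunction(lambda i, j, k: 2 * i + 3 * j + 5 * k + 1, (4, 5, 6))
xs = np.array([[[0.5, 1.25]]])
ys = np.array([[[2.0, 3.5]]])
zs = np.array([[[4.75, -1.0]]])   # the second point has z < 0, so it lies outside
out = trilinear_numpy(v, xs, ys, zs)
```

As in the try/assert version, a point whose upper corner falls outside the array (for example x exactly on the last grid plane) is treated as out of range and set to 0.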
