簡體   English   中英

如何加速這個用 Python 編寫的程序?

[英]How can I speed up this program written in Python?

以下程序是 Python 中行進廣場問題的解決方案:

from typing import List

def GetCaseId(Point_A_data: float, Point_B_data: float,
              Point_C_data: float, Point_D_data: float,
              threshold):
    caseId = 0
    if (Point_A_data >= threshold):
        caseId |= 1
    if (Point_B_data >= threshold):
        caseId |= 2
    if (Point_C_data >= threshold):
        caseId |= 4
    if (Point_D_data >= threshold):
        caseId |= 8
    return caseId


def GetLines(Point_A: List[float], Point_B: List[float], Point_C: List[float], Point_D: List[float],
             a: float, b: float, c: float, d: float,
             threshold: float):
    lines = []
    caseId = GetCaseId(a, b, c, d, threshold)

    if caseId in (0, 15):
        return []

    if caseId in (1, 14, 10):
        pX = (Point_A[0] + Point_B[0]) / 2
        pY = Point_B[1]
        qX = Point_D[0]
        qY = (Point_A[1] + Point_D[1]) / 2

        line = (pX, pY, qX, qY)

        lines.append(line)

    if caseId in (2, 13, 5):
        pX = (Point_A[0] + Point_B[0]) / 2
        pY = Point_A[1]
        qX = Point_C[0]
        qY = (Point_A[1] + Point_D[1]) / 2

        line = (pX, pY, qX, qY)

        lines.append(line)

    if caseId in (3, 12):
        pX = Point_A[0]
        pY = (Point_A[1] + Point_D[1]) / 2
        qX = Point_C[0]
        qY = (Point_B[1] + Point_C[1]) / 2

        line = (pX, pY, qX, qY)

        lines.append(line)

    if caseId in (4, 11, 10):
        pX = (Point_C[0] + Point_D[0]) / 2
        pY = Point_D[1]
        qX = Point_B[0]
        qY = (Point_B[1] + Point_C[1]) / 2

        line = (pX, pY, qX, qY)

        lines.append(line)

    elif caseId in (6, 9):
        pX = (Point_A[0] + Point_B[0]) / 2
        pY = Point_A[1]
        qX = (Point_C[0] + Point_D[0]) / 2
        qY = Point_C[1]

        line = (pX, pY, qX, qY)

        lines.append(line)

    elif caseId in (7, 8, 5):
        pX = (Point_C[0] + Point_D[0]) / 2
        pY = Point_C[1]
        qX = Point_A[0]
        qY = (Point_A[1] + Point_D[1]) / 2

        line = (pX, pY, qX, qY)

        lines.append(line)

    return lines


def marching_square(x_int_list, y_int_list, data_2d_list, threshold_list):
    linesList = []

    Height = len(y_int_list)  # rows
    Width = len(x_int_list)  # cols

    if ((Width == len(data_2d_list[0])) and (Height == len(data_2d_list))):

        for j in range(Height - 1):  # rows
            for i in range(Width - 1):  # cols
                point_a_data_float = data_2d_list[j + 1][i]
                point_b_data_float = data_2d_list[j + 1][i + 1]
                point_c_data_float = data_2d_list[j][i + 1]
                point_d_data_float = data_2d_list[j][i]

                point_A = [x_int_list[i], y_int_list[j + 1]]
                point_B = [x_int_list[i + 1], y_int_list[j + 1]]
                point_C = [x_int_list[i + 1], y_int_list[j]]
                point_D = [x_int_list[i], y_int_list[j]]

                for threshold_item in threshold_list:
                    list = GetLines(point_A, point_B, point_C, point_D,
                                    point_a_data_float, point_b_data_float, point_c_data_float, point_d_data_float,
                                    threshold_item)
                    linesList = linesList + list
    else:
        raise AssertionError

    return [linesList]

此源代碼的問題是 - 生成輸出需要很長時間。

即使用以下驅動程序:

import drawSvg as draw_svg

N_int = 800
N2_float = N_int / 8
x_int_vector = [i for i in range(N_int)]
y_int_vector = [i for i in range(N_int)]

matrix_256x256 = [[(math.sin(i / N2_float) * math.sin(j / N2_float)) for i in range(N_int)] for j in range(N_int)]

fill = "#2591a3"
drawing = draw_svg.Drawing(N_int, N_int, displayInline=False)

threshold_float_list = [0.2, 0.4, 0.6, 0.8]
collection = marching_square(x_int_vector, y_int_vector, matrix_256x256, threshold_float_list)
for line_set in collection:
    for line in line_set:
        drawing.append(draw_svg.Line(line[0], line[1], line[2], line[3], stroke='red'))
     # END of line
# END of line_set
drawing.saveSvg('example.svg') 

該代碼在實際使用中變得非常緩慢。

我怎樣才能加速代碼?

NB marching_square()的簽名不得更改。

獲得了約 10 倍的加速

  1. 刪除了最大瓶頸的擴展列表
  2. numba應用於第二個瓶頸的GetCaseId
from typing import List
import numba
import functools
import operator

@numba.jit(nopython=True)
def GetCaseId(Point_A_data: float, Point_B_data: float,
              Point_C_data: float, Point_D_data: float,
              threshold):
    caseId = 0
    if (Point_A_data >= threshold):
        caseId |= 1
    if (Point_B_data >= threshold):
        caseId |= 2
    if (Point_C_data >= threshold):
        caseId |= 4
    if (Point_D_data >= threshold):
        caseId |= 8
    return caseId


def GetLines(Point_A: List[float], Point_B: List[float], Point_C: List[float], Point_D: List[float],
             a: float, b: float, c: float, d: float,
             threshold: float):
    lines = []
    caseId = GetCaseId(a, b, c, d, threshold)

    if caseId in (0, 15):
        return None

    if caseId in (1, 14, 10):
        pX = (Point_A[0] + Point_B[0]) / 2
        pY = Point_B[1]
        qX = Point_D[0]
        qY = (Point_A[1] + Point_D[1]) / 2

        line = (pX, pY, qX, qY)

        lines.append(line)

    if caseId in (2, 13, 5):
        pX = (Point_A[0] + Point_B[0]) / 2
        pY = Point_A[1]
        qX = Point_C[0]
        qY = (Point_A[1] + Point_D[1]) / 2

        line = (pX, pY, qX, qY)

        lines.append(line)

    if caseId in (3, 12):
        pX = Point_A[0]
        pY = (Point_A[1] + Point_D[1]) / 2
        qX = Point_C[0]
        qY = (Point_B[1] + Point_C[1]) / 2

        line = (pX, pY, qX, qY)

        lines.append(line)

    if caseId in (4, 11, 10):
        pX = (Point_C[0] + Point_D[0]) / 2
        pY = Point_D[1]
        qX = Point_B[0]
        qY = (Point_B[1] + Point_C[1]) / 2

        line = (pX, pY, qX, qY)

        lines.append(line)

    elif caseId in (6, 9):
        pX = (Point_A[0] + Point_B[0]) / 2
        pY = Point_A[1]
        qX = (Point_C[0] + Point_D[0]) / 2
        qY = Point_C[1]

        line = (pX, pY, qX, qY)

        lines.append(line)

    elif caseId in (7, 8, 5):
        pX = (Point_C[0] + Point_D[0]) / 2
        pY = Point_C[1]
        qX = Point_A[0]
        qY = (Point_A[1] + Point_D[1]) / 2

        line = (pX, pY, qX, qY)

        lines.append(line)

    return lines


def marching_square(x_int_list, y_int_list, data_2d_list, threshold_list):
    linesList = []

    Height = len(y_int_list)  # rows
    Width = len(x_int_list)  # cols

    if ((Width == len(data_2d_list[0])) and (Height == len(data_2d_list))):

        for j in range(Height - 1):  # rows
            for i in range(Width - 1):  # cols
                point_a_data_float = data_2d_list[j + 1][i]
                point_b_data_float = data_2d_list[j + 1][i + 1]
                point_c_data_float = data_2d_list[j][i + 1]
                point_d_data_float = data_2d_list[j][i]

                point_A = [x_int_list[i], y_int_list[j + 1]]
                point_B = [x_int_list[i + 1], y_int_list[j + 1]]
                point_C = [x_int_list[i + 1], y_int_list[j]]
                point_D = [x_int_list[i], y_int_list[j]]

                for threshold_item in threshold_list:
                    list = GetLines(point_A, point_B, point_C, point_D,
                                    point_a_data_float, point_b_data_float, point_c_data_float, point_d_data_float,
                                    threshold_item)
                    if list:
                        linesList.append(list)

    else:
        raise AssertionError

    return functools.reduce(operator.iconcat, linesList, [])

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM