簡體   English   中英

在 pyqtgraph 中更高效地更改散點圖顏色的方法

[英]More efficient way to change scatter plot colors in pyqtgraph

背景:我正在創建一個 GUI,用來顯示 4 個六邊形像素數組中不斷更新的當前值。 我認為最簡單的方法是在 pyqtgraph 中按該布局創建散點圖,並根據傳感器數據更新各點的填充(畫筆)顏色。

方法:我基本上照搬了 pyqtgraph 的示例腳本 ScatterPlotItem.py 和 ScatterPlotSpeedTest.py,並按照我的特定布局做了調整。

問題:性能非常慢,比我預期的慢得多,因為 pyqtgraph 的示例腳本能以 1000+ fps 運行,而我的腳本目前只有約 4-7 fps。 這相當令人驚訝,因為我以為只改變點的顏色應該很快。 我還在學習 pyqtgraph,目前在 update 函數中使用 ScatterPlotItem.setBrush(),但它似乎很慢(我認為這就是更新緩慢的根源)。 是否有更好/更快的方法來更新散點圖項目的填充顏色?

這是我正在使用的腳本:

import numpy as np
import pyqtgraph as pg
from pyqtgraph.Qt import QtWidgets, QtCore
from time import perf_counter


# This function generates the hexgonal array x's and y's in the required ordering.
# It is ugly, but works for now.
def drawHexGridLoop2(origin, depth, apothem, padding):
    """Generate hexagon centre coordinates ring by ring around ``origin``.

    The origin comes first, followed by each surrounding ring in turn.  A
    ring starts at the point ``2 * apothem`` to the right of the previous
    ring's first point.  The remaining points are found by stepping
    ``2 * apothem`` away from an anchor point of the previous ring, rotating
    the step direction by -60 degrees per point.  When a candidate lands on a
    point of an already finished ring, the anchor advances to the next point
    of the previous ring and the step is retried.

    Args:
        origin: ``(x, y)`` pair used as the centre point.
        depth: Number of rings including the centre.  ``depth=1`` yields only
            the origin; ring ``d`` contributes ``6 * d`` points.
        apothem: Half of the centre-to-centre step between neighbours.
        padding: Unused; kept so existing callers keep working.

    Returns:
        A tuple ``(xs, ys, labels)`` of flat lists in generation order.
        ``labels`` holds the 1-based point numbers as strings.  Every computed
        coordinate is rounded to 8 decimal places.
    """
    ang60 = np.deg2rad(60)
    step = 2 * apothem
    xs = [[origin[0]]]
    ys = [[origin[1]]]
    labels = [['1']]
    labelN = 2
    # Coordinates of all *finished* rings.  A set gives O(1) duplicate checks
    # instead of re-flattening every ring for each candidate point.
    placed = {(origin[0], origin[1])}
    for d in range(1, depth):
        prevXs, prevYs = xs[-1], ys[-1]
        # First point of the ring: shifted right from the previous ring's
        # first point, which also serves as the initial anchor.
        thisXArr = [round(prevXs[0] + step, 8)]
        thisYArr = [round(prevYs[0], 8)]
        thisLabelArr = [str(labelN)]
        labelN += 1
        anchorN = 0
        anchorX, anchorY = prevXs[0], prevYs[0]
        loc = 1
        n = 1
        while n < d * 6:
            thisX = round(anchorX + step * np.cos(-1 * ang60 * loc), 8)
            thisY = round(anchorY + step * np.sin(-1 * ang60 * loc), 8)
            if (thisX, thisY) in placed:
                # Candidate collides with an inner ring: move the anchor to
                # the next point of the previous ring and retry this slot.
                anchorN += 1
                anchorX = prevXs[anchorN]
                anchorY = prevYs[anchorN]
                loc -= 1
                continue
            thisXArr.append(thisX)
            thisYArr.append(thisY)
            thisLabelArr.append(str(labelN))
            labelN += 1
            loc += 1
            n += 1
        xs.append(thisXArr)
        ys.append(thisYArr)
        labels.append(thisLabelArr)
        # Only completed rings take part in the duplicate check, so the set
        # is extended after the ring is finished.
        placed.update(zip(thisXArr, thisYArr))
    flatXs = [x for ring in xs for x in ring]
    flatYs = [y for ring in ys for y in ring]
    flatLabels = [label for ring in labels for label in ring]
    return flatXs, flatYs, flatLabels


# Function to create the scatter plot in each viewbox.
# Adapted from ScatterPlotItem.py
def createArray(w):
    """Fill view box ``w`` with one hexagon-symbol scatter plot.

    The hexagon centres come from ``drawHexGridLoop2((0, 0), 14, 1e-6, 0)``;
    every spot uses the module-level ``hexSize`` as its size.

    Returns:
        ``(w, scatter_item, spots, xs, ys)`` where ``spots`` is the list of
        spot dictionaries handed to ``addPoints``.
    """
    xs, ys, _ = drawHexGridLoop2((0, 0), 14, 1e-6, 0)
    spots = [
        {
            'pos': (x, y),
            'size': hexSize,
            'pen': {'color': 'w', 'width': 2},
            'brush': pg.intColor(10, 10),
            'symbol': 'h',
        }
        for x, y in zip(xs, ys)
    ]
    scatter = pg.ScatterPlotItem(
        pxMode=False,  # Set pxMode=False to allow spots to transform with the view
        hoverable=True,
        hoverPen=pg.mkPen('g'),
        hoverSize=hexSize,
    )
    scatter.addPoints(spots)
    w.addItem(scatter)
    return w, scatter, spots, xs, ys

# Size used by createArray() for every spot ('size') and for the hover
# highlight (hoverSize).
hexSize = 2.2e-6
app = pg.mkQApp("Scatter Plot Item Example")
mw = QtWidgets.QMainWindow()
mw.resize(800, 800)
view = pg.GraphicsLayoutWidget()  ## GraphicsView with GraphicsLayout inserted by default
mw.setCentralWidget(view)
mw.show()
mw.setWindowTitle('pyqtgraph example: ScatterPlot')
view.ci.setBorder((50, 50, 100))

## create four areas to add plots (two per row, two rows)
w1 = view.addViewBox()
w1.setAspectLocked()
w2 = view.addViewBox()
w2.setAspectLocked()
view.nextRow()
w3 = view.addViewBox()
w3.setAspectLocked()
w4 = view.addViewBox()
w4.setAspectLocked()

# Create the scatter plots.
# Each plot keeps its own spots list under its own name; previously all four
# were unpacked into ``spots1`` and overwrote one another.
w1, s1, spots1, xs, ys = createArray(w1)
w2, s2, spots2, xs, ys = createArray(w2)
w3, s3, spots3, xs, ys = createArray(w3)
w4, s4, spots4, xs, ys = createArray(w4)

# Create the color map.
# Adapted from https://github.com/pyqtgraph/pyqtgraph/issues/1712#issuecomment-819745370
nPts = 255
colormap = pg.colormap.get('cividis')
valueRange = np.linspace(0, 255, num=nPts)
colors = colormap.getLookupTable(0, 1, nPts=nPts)

# This is really slow!
# State shared with update(): the smoothed frame rate (None until the first
# frame has been timed) and the timestamp of the previous update() call.
fps = None
lastTime = perf_counter()
def update():
    """Recolour all four scatter plots with random values and show the FPS.

    Draws 547 random integers, maps them to rows of ``colors`` through
    ``valueRange`` and passes that colour array to ``setBrush`` on each
    scatter item.  The window title is then set to a smoothed frames-per-second
    figure stored in the module-level ``fps``.
    """
    global fps, lastTime
    z = np.random.randint(0, 255, size=547)
    brushes = colors[np.searchsorted(valueRange, z)]
    for scatter in (s1, s2, s3, s4):
        scatter.setBrush(brushes)  # Is there a faster way to do this?

    now = perf_counter()
    dt = now - lastTime
    lastTime = now
    instantaneous = 1.0 / dt
    if fps is None:
        # First timed frame: nothing to blend with yet.
        fps = instantaneous
    else:
        # Blend the new sample in with a weight proportional to the frame
        # time, clipped to [0, 1].
        weight = np.clip(dt * 3., 0, 1)
        fps = fps * (1 - weight) + instantaneous * weight
    mw.setWindowTitle('%0.2f fps' % fps)

    
# Call update() on every timer timeout.
timer = QtCore.QTimer()
timer.timeout.connect(update)
# NOTE(review): an interval of 0 presumably makes the timer fire whenever the
# event loop is idle, so update() runs as often as possible - confirm against
# the Qt QTimer documentation.
timer.start(0)

# Only start the application when this file is executed as a script.
# NOTE(review): pg.exec() presumably enters the Qt event loop - confirm
# against the pyqtgraph documentation.
if __name__ == '__main__':
    pg.exec()

您的解決方案的問題在於,您把 colors(顏色)數組直接傳給了 setBrush 方法。 在這種情況下,PyQtGraph 必須在每一幀重新生成畫筆,這就是瓶頸所在。

為了加快速度,您可以預先創建一個畫筆查找表,其中存放的是畫筆對象而不是顏色。 然後在 update 函數中,您只需生成索引數組,並據此從查找表中取出畫筆組成新的畫筆列表。 這樣 PyQtGraph 會直接使用這些現成的畫筆,而不必在每一幀重新創建。
性能從約 7fps 提高到 35fps。

這是您的示例修改後的版本:

import time

import numpy as np
import pyqtgraph as pg
from pyqtgraph.Qt import QtWidgets, QtCore, QtGui
from time import perf_counter


# This function generates the hexgonal array x's and y's in the required ordering.
# It is ugly, but works for now.
def drawHexGridLoop2(origin, depth, apothem, padding):
    """Generate hexagon centre coordinates ring by ring around ``origin``.

    The origin comes first, followed by each surrounding ring in turn.  A
    ring starts at the point ``2 * apothem`` to the right of the previous
    ring's first point.  The remaining points are found by stepping
    ``2 * apothem`` away from an anchor point of the previous ring, rotating
    the step direction by -60 degrees per point.  When a candidate lands on a
    point of an already finished ring, the anchor advances to the next point
    of the previous ring and the step is retried.

    Args:
        origin: ``(x, y)`` pair used as the centre point.
        depth: Number of rings including the centre.  ``depth=1`` yields only
            the origin; ring ``d`` contributes ``6 * d`` points.
        apothem: Half of the centre-to-centre step between neighbours.
        padding: Unused; kept so existing callers keep working.

    Returns:
        A tuple ``(xs, ys, labels)`` of flat lists in generation order.
        ``labels`` holds the 1-based point numbers as strings.  Every computed
        coordinate is rounded to 8 decimal places.
    """
    ang60 = np.deg2rad(60)
    step = 2 * apothem
    xs = [[origin[0]]]
    ys = [[origin[1]]]
    labels = [['1']]
    labelN = 2
    # Coordinates of all *finished* rings.  A set gives O(1) duplicate checks
    # instead of re-flattening every ring for each candidate point.
    placed = {(origin[0], origin[1])}
    for d in range(1, depth):
        prevXs, prevYs = xs[-1], ys[-1]
        # First point of the ring: shifted right from the previous ring's
        # first point, which also serves as the initial anchor.
        thisXArr = [round(prevXs[0] + step, 8)]
        thisYArr = [round(prevYs[0], 8)]
        thisLabelArr = [str(labelN)]
        labelN += 1
        anchorN = 0
        anchorX, anchorY = prevXs[0], prevYs[0]
        loc = 1
        n = 1
        while n < d * 6:
            thisX = round(anchorX + step * np.cos(-1 * ang60 * loc), 8)
            thisY = round(anchorY + step * np.sin(-1 * ang60 * loc), 8)
            if (thisX, thisY) in placed:
                # Candidate collides with an inner ring: move the anchor to
                # the next point of the previous ring and retry this slot.
                anchorN += 1
                anchorX = prevXs[anchorN]
                anchorY = prevYs[anchorN]
                loc -= 1
                continue
            thisXArr.append(thisX)
            thisYArr.append(thisY)
            thisLabelArr.append(str(labelN))
            labelN += 1
            loc += 1
            n += 1
        xs.append(thisXArr)
        ys.append(thisYArr)
        labels.append(thisLabelArr)
        # Only completed rings take part in the duplicate check, so the set
        # is extended after the ring is finished.
        placed.update(zip(thisXArr, thisYArr))
    flatXs = [x for ring in xs for x in ring]
    flatYs = [y for ring in ys for y in ring]
    flatLabels = [label for ring in labels for label in ring]
    return flatXs, flatYs, flatLabels


# Function to create the scatter plot in each viewbox.
# Adapted from ScatterPlotItem.py
def createArray(w):
    """Fill view box ``w`` with one hexagon-symbol scatter plot.

    The hexagon centres come from ``drawHexGridLoop2((0, 0), 14, 1e-6, 0)``;
    every spot uses the module-level ``hexSize`` as its size.

    Returns:
        ``(w, scatter_item, spots, xs, ys)`` where ``spots`` is the list of
        spot dictionaries handed to ``addPoints``.
    """
    xs, ys, _ = drawHexGridLoop2((0, 0), 14, 1e-6, 0)
    spots = [
        {
            'pos': (x, y),
            'size': hexSize,
            'pen': {'color': 'w', 'width': 2},
            'brush': pg.intColor(10, 10),
            'symbol': 'h',
        }
        for x, y in zip(xs, ys)
    ]
    scatter = pg.ScatterPlotItem(
        pxMode=False,  # Set pxMode=False to allow spots to transform with the view
        hoverable=True,
        hoverPen=pg.mkPen('g'),
        hoverSize=hexSize,
    )
    scatter.addPoints(spots)
    w.addItem(scatter)
    return w, scatter, spots, xs, ys


# Size used by createArray() for every spot ('size') and for the hover
# highlight (hoverSize).
hexSize = 2.2e-6
app = pg.mkQApp("Scatter Plot Item Example")
mw = QtWidgets.QMainWindow()
mw.resize(800, 800)
view = pg.GraphicsLayoutWidget()  ## GraphicsView with GraphicsLayout inserted by default
mw.setCentralWidget(view)
mw.show()
mw.setWindowTitle('pyqtgraph example: ScatterPlot')
view.ci.setBorder((50, 50, 100))

## create four areas to add plots (two per row, two rows)
w1 = view.addViewBox()
w1.setAspectLocked()
w2 = view.addViewBox()
w2.setAspectLocked()
view.nextRow()
w3 = view.addViewBox()
w3.setAspectLocked()
w4 = view.addViewBox()
w4.setAspectLocked()

# Create the scatter plots.
# Each plot keeps its own spots list under its own name; previously all four
# were unpacked into ``spots1`` and overwrote one another.
w1, s1, spots1, xs, ys = createArray(w1)
w2, s2, spots2, xs, ys = createArray(w2)
w3, s3, spots3, xs, ys = createArray(w3)
w4, s4, spots4, xs, ys = createArray(w4)

# Create the color map.
# Adapted from https://github.com/pyqtgraph/pyqtgraph/issues/1712#issuecomment-819745370
nPts = 255
colormap = pg.colormap.get('cividis')
valueRange = np.linspace(0, 255, num=nPts)
colors = colormap.getLookupTable(0, 1, nPts=nPts)

# *** Create brushes lookup table
# One QBrush per lookup-table colour, built once so update() can reuse the
# same brush objects on every frame.
brushes_table = [QtGui.QBrush(QtGui.QColor(*color)) for color in colors]

# State shared with update(): the smoothed frame rate (None until the first
# frame has been timed) and the timestamp of the previous update() call.
fps = None
lastTime = perf_counter()


def update():
    """Recolour all four scatter plots with random values and show the FPS.

    Draws 547 random integers and uses each one as an index into
    ``brushes_table``, so the brushes passed to ``setBrush`` are the ones
    created up front.  The window title is then set to a smoothed
    frames-per-second figure stored in the module-level ``fps``.
    """
    global fps, lastTime
    z = np.random.randint(0, 255, size=547)
    # *** Create array of already created brushes from lookup table
    brushes = [brushes_table[index] for index in z]
    for scatter in (s1, s2, s3, s4):
        scatter.setBrush(brushes)  # Now update is much faster!

    now = perf_counter()
    dt = now - lastTime
    lastTime = now
    instantaneous = 1.0 / dt
    if fps is None:
        # First timed frame: nothing to blend with yet.
        fps = instantaneous
    else:
        # Blend the new sample in with a weight proportional to the frame
        # time, clipped to [0, 1].
        weight = np.clip(dt * 3., 0, 1)
        fps = fps * (1 - weight) + instantaneous * weight
    mw.setWindowTitle('%0.2f fps' % fps)


# Call update() on every timer timeout.
timer = QtCore.QTimer()
timer.timeout.connect(update)
# NOTE(review): an interval of 0 presumably makes the timer fire whenever the
# event loop is idle, so update() runs as often as possible - confirm against
# the Qt QTimer documentation.
timer.start(0)

# Only start the application when this file is executed as a script.
# NOTE(review): pg.exec() presumably enters the Qt event loop - confirm
# against the pyqtgraph documentation.
if __name__ == '__main__':
    pg.exec()

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM