简体   繁体   中英

Update Networkx graph embedded in PyQt5 Matplotlib

I created a GUI where the user can open tabs with a button. When they do, it adds the following tab:

# Tab page whose widgets come from data_tab_lib.Ui_data_tab (via setupUi) plus a
# matplotlib Canvas holding a single Axes (self.axe).
# NOTE(review): indentation below is as pasted; __init__ is meant to be inside the class.
class data_tab(QtWidgets.QWidget, data_tab_lib.Ui_data_tab):

#=========================================================================================
# Constructor
#=========================================================================================
def __init__(self, parent, title):
    """Build the tab page and add it to ``parent`` under the text ``title``.

    :param parent: tab widget receiving this page (addTab / setTabText are called on it).
    :param title: text displayed on the tab.
    """

    QtWidgets.QWidget.__init__(self, parent)
    self.setupUi(self)

    # initialize save data button icon
    icon = _createIcon("dataname_save")
    self.dataname_save.setIcon(icon)

    # NOTE(review): this passes the class ``data_tab``, not the instance ``self``, as the
    # canvas parent; the full reproducible example further down uses Canvas(self).
    self.canvas = Canvas(data_tab)
    # single subplot on this tab's own figure
    self.axe = self.canvas.figure.add_subplot(111)
    self.canvas.figure.subplots_adjust(left=0.025, top=0.965, bottom=0.040, right=0.975)
    # add the tab to the parent
    parent.addTab(self, "")

    # set text name
    parent.setTabText(parent.indexOf(self), title)

The parent is

self.core_tab = QtWidgets.QTabWidget(self.main_widget)  # tab widget that data_tab pages add themselves to

in the main window.

"Canvas" is defined as follows:

from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
import matplotlib.pyplot as plt

class Canvas(FigureCanvas):
    """Qt canvas backed by a figure created through pyplot.

    NOTE: ``plt.figure()`` registers the new figure with pyplot and makes it
    pyplot's *current* figure, so any drawing done through pyplot without an
    explicit ``ax=`` goes to the most recently created Canvas.
    """
    def __init__(self, parent=None):
        self.figure = plt.figure()
        FigureCanvas.__init__(self, self.figure)
        # attach the canvas to the given Qt parent widget
        self.setParent(parent)

When the initial button is clicked, this happens:

     # the data_tab constructor adds the new page to self.core_tab, titled dataName
     new_data_tab = data_tab(self.core_tab, dataName)

     # associated data to the variable
     associated_data = self._getDataAssociatedToVariable(dataName)

     #1. draw Graph
     self._drawDataGraph(dataName, associated_data, new_data_tab)

`dataName` is a string defined earlier. `_drawDataGraph` is defined as follows:

def _drawDataGraph(self, dataName, associatedData, dataWidget):
    """Build the producer -> consumer graph for ``dataName`` and draw it for ``dataWidget``.

    Edges (weight 1) go from each producer to each of its consumers: first for
    ``dataName``, then for every other variable in the drawing dictionary.

    NOTE(review): nx.draw_circular() below is called without ``ax=``, so it
    draws on pyplot's *current* figure, not necessarily the figure of
    ``dataWidget.canvas``, which is the only canvas repainted afterwards.
    """

    # 1. draw graph
    # presumably {variable: {"producer": name, "consumers": [names]}} -- inferred from the .get() calls below
    drawing_dictionary = self._getAllAssociatedVariablesToBeDrawn(dataName, associatedData)
    producer = drawing_dictionary.get(dataName).get("producer")
    consumers = drawing_dictionary.get(dataName).get("consumers")

    color_map = []
    DG = nx.DiGraph()
    DG.add_node(producer)
    for cons in consumers:
        DG.add_node(cons)
    edges_bunch = [(producer, cons, 1) for cons in consumers]
    DG.add_weighted_edges_from(edges_bunch)

    # add the graphs of every other variable to the same DiGraph
    for node in drawing_dictionary.keys():
        if node != dataName:
            DG.add_node(drawing_dictionary.get(node).get("producer"))
            for node_cons in drawing_dictionary.get(node).get("consumers"):
                DG.add_node(node_cons)

            other_edges_bunch = [(drawing_dictionary.get(node).get("producer"), node_cons, 1) for node_cons in
                                 drawing_dictionary.get(node).get("consumers")]
            DG.add_weighted_edges_from(other_edges_bunch)
    # Positional colouring: nodes iterate in insertion order, so the first
    # 1 + len(consumers) nodes (dataName's producer and consumers, assuming no
    # duplicate names) are yellow and nodes added for other variables are blue.
    for i in range(len(DG.nodes())):
        if i < 1 + len(consumers):
            color_map.append("#DCE46F")
        else:
            color_map.append("#6FA2E4")
    #pos = nx.spring_layout(DG, k=0.4, iterations=20)
    nx.draw_circular(DG, node_color=color_map, with_labels=True, font_size=8, node_size=1000, node_shape='o')
    dataWidget.canvas.draw()

I won't go through the function `_getAllAssociatedVariablesToBeDrawn`. It just returns a dictionary mapping each variable to its producer and its list of consumers, and it is not the problem here.

On creation (the initial click on the button), everything works fine: a NetworkX diagram is displayed correctly.

My problem is that I have another button with which I'd like to refresh all the graphs that are currently displayed:

def refreshFlowsDiagramFromDataTabs(self):
    """Clear and redraw the graph of every page in ``self.core_tab``.

    Assumes every page is a data_tab (it uses ``.axe``) and that each tab's
    text is the data name it displays.

    NOTE(review): ``axe.cla()`` clears each tab's own axes, but the redraw in
    _drawDataGraph goes to pyplot's current figure, so only one canvas ends
    up showing graphs.
    """

    # loop through all pages
    for tab_index in range(self.core_tab.count()):
        data_tab_widget = self.core_tab.widget(tab_index)
        data_tab_widget.axe.cla()

        # associated data to the variable (the tab text is the data name)
        data_name = self.core_tab.tabText(tab_index)

        associated_data = self._getDataAssociatedToVariable(data_name)

        # 1. draw graph
        self._drawDataGraph(data_name, associated_data, data_tab_widget)

Unfortunately, when the button is clicked, only the last graph (the one in the n-th tab) is redrawn. All the previous ones are blank, because they have been cleared by `data_tab_widget.axe.cla()`.

I tried commenting out `data_tab_widget.axe.cla()` to see what happens. In that case, graphs 1 to n-1 are of course not cleared, but the last one contains all the previous graphs: all n graphs are drawn on top of each other in the last tab.

I'm by no means an expert in matplotlib or networkx, so I don't understand what I'm doing wrong. It's probably something very simple, but I'd be glad of some help on the topic.

I hope the code I provided is enough; it should be.

EDIT

Below is a fully reproducible example: add some tabs and then refresh them to observe the bug.

import sys
from PyQt5 import QtCore, QtGui, QtWidgets
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
import matplotlib.pyplot as plt
import networkx as nx

class Canvas(FigureCanvas):
    """Qt canvas backed by a figure created through pyplot.

    NOTE: ``plt.figure()`` registers the new figure with pyplot and makes it
    pyplot's *current* figure, so any drawing done through pyplot without an
    explicit ``ax=`` goes to the most recently created Canvas.
    """
    def __init__(self, parent=None):
        self.figure = plt.figure()
        FigureCanvas.__init__(self, self.figure)
        # attach the canvas to the given Qt parent widget
        self.setParent(parent)

class data_tab(QtWidgets.QWidget):
    """Tab page holding one Canvas with a single Axes (``self.axe``).

    The constructor adds the page to ``parent`` (a QTabWidget) titled ``title``.
    """

    #=========================================================================================
    # Constructor
    #=========================================================================================
    def __init__(self, parent, title):

        QtWidgets.QWidget.__init__(self, parent)

        self.data_tab_glayout = QtWidgets.QGridLayout(self)
        self.data_tab_glayout.setObjectName("data_tab_glayout")
        self.canvas = Canvas(self)
        self.canvas.setObjectName("canvas")
        # NOTE(review): this layout is installed on the canvas itself and nothing
        # is ever added to it in the visible code.
        self.canvas_vlayout = QtWidgets.QVBoxLayout(self.canvas)
        self.canvas_vlayout.setObjectName("canvas_vlayout")
        self.data_tab_glayout.addWidget(self.canvas, 0, 0, 2, 1)

        # single subplot on this tab's own figure
        self.axe = self.canvas.figure.add_subplot(111)
        self.canvas.figure.subplots_adjust(left=0.025, top=0.965, bottom=0.040, right=0.975)
        # add the tab to the parent
        parent.addTab(self, "")

        # set text name
        parent.setTabText(parent.indexOf(self), title)


class spec_writer(QtWidgets.QMainWindow):
    """Main window: a tab widget of graph tabs plus "Add Tab" / "Refresh Tabs" buttons."""


    #=========================================================================================
    # Constructor
    #=========================================================================================
    def __init__(self, parent=None):
        QtWidgets.QMainWindow.__init__(self, parent)

        self.showMaximized()

        self.centralwidget = QtWidgets.QWidget(self)
        self.centralwidget.setObjectName("centralwidget")
        self.verticalLayout = QtWidgets.QVBoxLayout(self.centralwidget)
        self.core_tab = QtWidgets.QTabWidget(self.centralwidget)
        self.verticalLayout.addWidget(self.core_tab)
        self.add_tab_btn = QtWidgets.QPushButton(self.centralwidget)
        self.verticalLayout.addWidget(self.add_tab_btn)
        self.refresh_tab_btn = QtWidgets.QPushButton(self.centralwidget)
        self.verticalLayout.addWidget(self.refresh_tab_btn)
        self.setCentralWidget(self.centralwidget)

        self.add_tab_btn.setText("Add Tab")
        self.refresh_tab_btn.setText("Refresh Tabs")

        self.core_tab.setEnabled(True)
        self.core_tab.setTabShape(QtWidgets.QTabWidget.Rounded)
        self.core_tab.setElideMode(QtCore.Qt.ElideNone)
        self.core_tab.setDocumentMode(False)
        # NOTE(review): no tabCloseRequested handler is connected in this example,
        # so the close buttons shown on the tabs do nothing.
        self.core_tab.setTabsClosable(True)
        self.core_tab.setMovable(True)
        self.core_tab.setTabBarAutoHide(False)

        # number of tabs opened so far; used for titles and to pick data
        self.tab_counter = 0

        # (producer, consumers) pairs; tab N draws entry N % len(random_tabs)
        self.random_tabs = [("a", ["b", "c"]),
                            ("d", ["e", "f", "g"]),
                            ("h", ["i", "j", "k", "l"]),
                            ("m", ["n"]),
                            ("o", ["p", "q"]),
                            ("r", ["s", "t", "u", "v", "w", "x", "y", "z"])]

        self.add_tab_btn.clicked.connect(self.openRandomTab)
        self.refresh_tab_btn.clicked.connect(self.refreshAllTabs)

    def openRandomTab(self):
        """Open tab "test N", draw graph N % len(random_tabs) in it, and select it."""

        tab = data_tab(self.core_tab, "test " + str(self.tab_counter))
        self._drawDataGraph(self.tab_counter % len(self.random_tabs), tab)
        self.tab_counter += 1

        self.core_tab.setCurrentIndex(self.core_tab.indexOf(tab))


    def _drawDataGraph(self, tabNb, dataWidget):
        """Draw the producer -> consumers graph of ``random_tabs[tabNb]`` for ``dataWidget``.

        NOTE(review): nx.draw_circular() is called without ``ax=``, so it draws on
        pyplot's current figure (the most recently created Canvas), not on
        ``dataWidget.canvas``; this is the bug the question describes.
        """

        # 1. draw graph
        producer = self.random_tabs[tabNb][0]
        consumers = self.random_tabs[tabNb][1]

        color_map = []
        DG = nx.DiGraph()
        DG.add_node(producer)
        for cons in consumers:
            DG.add_node(cons)
        edges_bunch = [(producer, cons, 1) for cons in consumers]
        DG.add_weighted_edges_from(edges_bunch)
        # the first 1 + len(consumers) nodes (all of them here) are yellow
        for i in range(len(DG.nodes())):
            if i < 1 + len(consumers):
                color_map.append("#DCE46F")
            else:
                color_map.append("#6FA2E4")
        #pos = nx.spring_layout(DG, k=0.4, iterations=20)
        nx.draw_circular(DG, node_color=color_map, with_labels=True, font_size=8, node_size=1000, node_shape='o')
        dataWidget.canvas.draw_idle()


    def refreshAllTabs(self):
        """Clear every tab's axes and redraw its graph.

        NOTE(review): the data is chosen from the tab's current *position*, so
        after tabs are moved (setMovable(True)) a tab may get a different graph.
        """

        # loop through all pages and redraw the graph associated with each one
        for tab_index in range(self.core_tab.count()):
            data_tab_widget = self.core_tab.widget(tab_index)
            data_tab_widget.axe.cla()

            # draw graph
            self._drawDataGraph(tab_index % len(self.random_tabs), data_tab_widget)




# NOTE(review): overwrites the real command line with a single empty argument
# before QApplication parses it.
sys.argv = ['']
app = QtWidgets.QApplication(sys.argv)
cbtc_spec_writer = spec_writer()
cbtc_spec_writer.show()
# run the Qt event loop
app.exec_()

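You should not use pyplot if you are embedding matplotlib through the backends. pyplot only works on the *current* figure, which by default is the last one created. Every `Canvas` built with `plt.figure()` becomes the current figure, and `nx.draw_circular()` called without `ax=` draws there. As a result, all graphs end up on the last tab. Instead, create each `Figure` directly and pass the tab's axes explicitly with `ax=`.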

For example, the following code (reordered and cleaned up) shows how to implement it:

from PyQt5 import QtCore, QtWidgets

from matplotlib.backends.backend_qt5agg import FigureCanvas
from matplotlib.figure import Figure

import networkx as nx


class DataTab(QtWidgets.QWidget):
    """Tab page that owns its own matplotlib ``Figure`` with a single ``Axes``.

    The figure is created directly rather than through pyplot, so it never
    becomes pyplot's "current" figure; graphs must be drawn on it by passing
    ``ax=self.axes`` explicitly.
    """

    def __init__(self, parent=None):
        super().__init__(parent)

        self.canvas = FigureCanvas(Figure(figsize=(5, 3)))
        figure = self.canvas.figure
        self.axes = figure.add_subplot(111)
        figure.subplots_adjust(left=0.025, top=0.965, bottom=0.040, right=0.975)

        grid = QtWidgets.QGridLayout(self)
        grid.addWidget(self.canvas)


class Spec_Writer(QtWidgets.QMainWindow):
    """Main window: a tab widget of graph tabs plus "Add Tab" / "Refresh Tabs" buttons.

    Each tab remembers which entry of ``random_data`` it displays, so a refresh
    redraws the same graph even after tabs have been moved or closed.
    """

    def __init__(self, parent=None):
        QtWidgets.QMainWindow.__init__(self, parent)
        self.showMaximized()

        self.core_tab = QtWidgets.QTabWidget(
            tabShape=QtWidgets.QTabWidget.Rounded,
            elideMode=QtCore.Qt.ElideNone,
            documentMode=False,
            tabsClosable=True,
            movable=True,
            tabBarAutoHide=False,
        )
        self.add_tab_btn = QtWidgets.QPushButton("Add Tab")
        self.refresh_tab_btn = QtWidgets.QPushButton("Refresh Tabs")

        self.centralwidget = QtWidgets.QWidget()
        self.setCentralWidget(self.centralwidget)

        self.verticalLayout = QtWidgets.QVBoxLayout(self.centralwidget)
        self.verticalLayout.addWidget(self.core_tab)
        self.verticalLayout.addWidget(self.add_tab_btn)
        self.verticalLayout.addWidget(self.refresh_tab_btn)

        # Number of tabs opened so far; used for tab titles and to pick data.
        self.tab_counter = 0

        # (producer, consumers) pairs; each tab draws one of them.
        self.random_data = [
            ("a", ["b", "c"]),
            ("d", ["e", "f", "g"]),
            ("h", ["i", "j", "k", "l"]),
            ("m", ["n"]),
            ("o", ["p", "q"]),
            ("r", ["s", "t", "u", "v", "w", "x", "y", "z"]),
        ]

        self.add_tab_btn.clicked.connect(self.open_random_tab)
        self.refresh_tab_btn.clicked.connect(self.refresh_all_tabs)
        # QTabWidget only *emits* tabCloseRequested; removing the page is up to
        # us, otherwise the close buttons enabled by tabsClosable do nothing.
        self.core_tab.tabCloseRequested.connect(self._close_tab)

    def open_random_tab(self):
        """Open tab "test N", select it, and draw entry N % len(random_data) in it."""
        # Pick the data *before* incrementing the counter so that tab
        # "test N" shows entry N, matching what a refresh will redraw.
        data_index = self.tab_counter % len(self.random_data)
        tab = DataTab()
        # Remember the data on the tab itself: tabs are movable/closable, so
        # their position is not a stable key.
        tab.data_index = data_index
        index = self.core_tab.addTab(tab, "test {}".format(self.tab_counter))
        self.core_tab.setCurrentIndex(index)
        self.tab_counter += 1
        self._draw_graph(data_index, tab)

    def _draw_graph(self, index, tab):
        """Clear ``tab``'s axes and draw the producer -> consumers graph on it.

        :param index: position in ``random_data`` of the (producer, consumers) pair.
        :param tab: a DataTab; drawing targets ``tab.axes`` explicitly, never
            pyplot's "current" figure.
        """
        tab.axes.cla()
        producer, consumers = self.random_data[index]

        DG = nx.DiGraph()
        # Add nodes first so their order is fixed (producer, then consumers);
        # the colour map below is positional.
        DG.add_node(producer)
        DG.add_nodes_from(consumers)
        DG.add_weighted_edges_from([(producer, cons, 1) for cons in consumers])

        # The first 1 + len(consumers) nodes (producer and its consumers) get
        # the primary colour; any further node would get the secondary one.
        color_map = [
            "#DCE46F" if i < 1 + len(consumers) else "#6FA2E4"
            for i in range(DG.number_of_nodes())
        ]
        nx.draw_circular(
            DG,
            node_color=color_map,
            with_labels=True,
            font_size=8,
            node_size=1000,
            node_shape="o",
            ax=tab.axes,
        )
        tab.canvas.draw()

    def refresh_all_tabs(self):
        """Redraw the graph in every open tab."""
        for tab_index in range(self.core_tab.count()):
            tab = self.core_tab.widget(tab_index)
            # Tabs not created by open_random_tab have no data_index; fall back
            # to the position-based choice for them.
            data_index = getattr(
                tab, "data_index", tab_index % len(self.random_data)
            )
            self._draw_graph(data_index, tab)

    def _close_tab(self, index):
        """Remove the tab at ``index`` and schedule its page widget for deletion."""
        widget = self.core_tab.widget(index)
        self.core_tab.removeTab(index)
        # removeTab() does not delete the page widget itself.
        if widget is not None:
            widget.deleteLater()


def main():
    """Create the Qt application, show one Spec_Writer window and run the event loop."""
    qt_app = QtWidgets.QApplication([])
    main_window = Spec_Writer()
    main_window.show()
    qt_app.exec_()


if __name__ == "__main__":
    main()

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please credit the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM