
Square nodes in networkx with matplotlib

Before starting, I should point out that this question may look like a duplicate of this one, but the solution there simply doesn't run in Python 3 with a current version of networkx: the set of shapes doesn't get constructed, among other problems.

I have a networkx graph that I draw using matplotlib. Here is the code for it:

class Vert:
  
    # default constructor
    def __init__(self, name, size, edges):
        self.name = name
        self.size = size
        self.edges = edges


import networkx as nx
import matplotlib.pyplot as plt

nodes = []
nodes.append(Vert('A', 1, ['B', 'C']))
nodes.append(Vert('B', 3, ['D']))
nodes.append(Vert('C', 4, ['D']))
nodes.append(Vert('D', 7, []))
nodes.append(Vert('Y', 64, []))

G = nx.DiGraph()
for v  in nodes:
    G.add_node(v.name, s='v')
    for e in v.edges:
        G.add_edge(v.name, e)

node_sizes = [V.size * 100 for V in nodes]
shapes = set((aShape[1]["s"] for aShape in G.nodes(data = True)))

nx.draw(G, font_weight='bold', with_labels = True, node_size=node_sizes, node_shape= shapes)

#plt.savefig('plot.png', bbox_inches='tight')
plt.show()

I need some of the nodes to have a square or maybe triangular shape. How do I do this?

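The code in the old answer fails to run because add_path() has changed since the post was written: it is no longer called as a method on the graph but as the module-level function nx.add_path(), which takes the graph as its first argument. I edited the answer in the older question, but the edit won't show immediately since I don't yet have edit approval privileges.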

If you replace

G.add_path([0,2,5])
G.add_path([1,4,3,0])
G.add_path([2,4,0,5])

with

nx.add_path(G, [0,2,5])
nx.add_path(G, [1,4,3,0])
nx.add_path(G, [2,4,0,5])

then I believe it should run successfully.
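
As a quick check of the replacement above, here is a minimal sketch that builds the same three paths with nx.add_path(). It assumes an undirected nx.Graph() as the container, which is my assumption here rather than something stated in this post.

```python
import networkx as nx

# Assumption: an undirected graph; swap in nx.DiGraph() if direction matters.
G = nx.Graph()

# nx.add_path is a module-level function: the graph comes first,
# followed by the sequence of nodes to connect in order.
nx.add_path(G, [0, 2, 5])
nx.add_path(G, [1, 4, 3, 0])
nx.add_path(G, [2, 4, 0, 5])

print(sorted(G.nodes()))
print(G.number_of_edges())
```

Each call adds an edge between every pair of consecutive nodes in the list, creating any nodes that don't exist yet.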


EDIT: In response to comment below, here is an example of working code that takes into account both shape and size. It isn't particularly clean or consistent in style, but it combines the method from the previous SO question with the questioner's data generation scheme.

The key part is changing from using nx.draw() to separately drawing all parts of the graph using nx.draw_networkx_nodes(), nx.draw_networkx_edges(), and nx.draw_networkx_labels(). See the networkx drawing docs for full details. This change allows for drawing each set of nodes that has a different shape with a different call of nx.draw_networkx_nodes().

I did some rather inelegant things to adjust the plot, including adjusting plt.xlim, plt.ylim, and the spacing argument (k) of nx.layout.spring_layout().

The code below gives the following plot:

[Image: network_viz_w_shapes_sizes]

class Vert:
  
    # default constructor
    def __init__(self, name, size, edges):
        self.name = name
        self.size = size
        self.edges = edges


import networkx as nx
import matplotlib.pyplot as plt

nodes = []
nodes.append(Vert('A', 1, ['B', 'C']))
nodes.append(Vert('B', 3, ['D']))
nodes.append(Vert('C', 4, ['D']))
nodes.append(Vert('D', 7, []))
nodes.append(Vert('Y', 64, []))

G = nx.DiGraph()

for v in nodes:
    # Assign a triangle ('v') to names whose character code is even
    # and a square ('s') to names whose character code is odd.
    if ord(v.name) % 2 == 0:
        G.add_node(v.name, size=v.size, shape='v')
    else:
        G.add_node(v.name, size=v.size, shape='s')
    for e in v.edges:
        G.add_edge(v.name, e)

# Collect the distinct shapes used in the graph.
shapes = {data['shape'] for _, data in G.nodes(data=True)}
pos = nx.layout.spring_layout(G, k=2)  # Make k larger to space out nodes more.

for shape in shapes:
    # Select the nodes that share this shape, together with their attributes.
    matching = [(n, d) for n, d in G.nodes(data=True) if d['shape'] == shape]
    nodelist = [n for n, _ in matching]
    sizes = [100 * d['size'] for _, d in matching]
    # Draw this subset at the positions already computed by the layout.
    nx.draw_networkx_nodes(G,
                           pos,
                           node_shape=shape,
                           nodelist=nodelist,
                           node_size=sizes)

# Draw the edges between the nodes, then the node labels.
nx.draw_networkx_edges(G, pos)
nx.draw_networkx_labels(G, pos)

plt.xlim(-2, 2) # Expand limits if large nodes spill over plot.
plt.ylim(-2, 2) # Expand limits if large nodes spill over plot.

plt.show()
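
If the per-shape filtering feels repetitive, the grouping step can be pulled out into a small helper. The sketch below is only one way to do it: group_by_shape and its default_shape parameter are hypothetical names of my own, not part of networkx, and the factor of 100 just mirrors the scaling used in the code above.

```python
from collections import defaultdict

import networkx as nx


def group_by_shape(G, default_shape='o'):
    """Hypothetical helper: map each shape to (node names, scaled sizes).

    Nodes without a 'shape' attribute fall back to default_shape, and
    nodes without a 'size' attribute are treated as size 1 (both are
    assumptions made for this sketch).
    """
    groups = defaultdict(lambda: ([], []))
    for node, data in G.nodes(data=True):
        names, sizes = groups[data.get('shape', default_shape)]
        names.append(node)
        sizes.append(100 * data.get('size', 1))
    return dict(groups)


# Small stand-in graph using the same attribute names as the code above.
G = nx.DiGraph()
G.add_node('A', size=1, shape='s')
G.add_node('B', size=3, shape='v')
G.add_node('C', size=4, shape='s')
G.add_edge('A', 'B')
G.add_edge('A', 'C')

groups = group_by_shape(G)
for shape, (names, sizes) in groups.items():
    print(shape, names, sizes)
```

Each entry can then be passed to a separate nx.draw_networkx_nodes() call, using the shape as node_shape, the names as nodelist, and the sizes as node_size, which keeps every node's size aligned with its own name.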
