
Connect nodes in a graph when one attribute is the same (NetworkX)

I want to create a graph which automatically adds edges between nodes if one particular attribute is the same. Nodes in my graph represent students. I am adding two attributes to my nodes: university_id and full_name. I want to add an edge between two people only if they go to the same university.

I've been looking at this solution: NetworkX: add edges to graph from node attributes

From testing it, it seems that this solution connects every pair of nodes in the graph, whether or not any of their attribute values are the same. Is there a simple solution I can use to connect students based only on their university_id?

Here is my code:

import matplotlib.pyplot as plt
import networkx as nx
import MySQLdb

# #############################################################################
# Retrieve the data from remote server.
myDB = MySQLdb.connect(host="*", port=3306, user="mysql", passwd="***")
cHandler = myDB.cursor()
cHandler.execute("USE research_project")
cHandler.execute("SELECT * FROM students")
results = cHandler.fetchall()

G = nx.Graph()
for items in results:
    # items[0] is a unique ID, items[1] = university_id, items[2] = full name
    G.add_node(items[0], attr_dict={'university_id': items[1], 'full_name': items[2]})

for node_r, attributes in G.nodes(data=True):
    key_set = set(attributes.keys())
    G.add_edges_from([(node_r, node) for node, attributes in G.nodes(data=True)
                      if key_set.intersection(set(attributes.keys()))
                      and node != node_r])

nx.draw(G)
plt.show()

from __future__ import print_function
import matplotlib.pyplot as plt
import networkx as nx
import MySQLdb


# #############################################################################
# Retrieve the data from remote server.
myDB = MySQLdb.connect(host="*", port=3306, user="mysql", passwd="***")
cHandler = myDB.cursor()
cHandler.execute("USE research_project")
cHandler.execute("SELECT * FROM students")
results = cHandler.fetchall()

G = nx.Graph()
for items in results:
    G.add_node(items[0], attr_dict={'university_id': items[1], 'full_name': items[2]})

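# Compare every pair of nodes; each item is a (node_id, attributes) tuple.
for node_r in G.nodes(data=True):
    for node in G.nodes(data=True):
        # The attributes sit under the 'attr_dict' key because of how add_node was called above.
        if node != node_r and node[1]['attr_dict']['university_id'] == node_r[1]['attr_dict']['university_id']:
            G.add_edge(node[0], node_r[0])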

nx.draw(G, with_labels=True)
plt.show()

I tested the above on small sets of data and it appears to work. I have a hunch that what was happening had to do with the way I was adding attributes to the nodes.
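
The hunch about how the attributes were added can be checked on a toy graph. The sketch below assumes NetworkX 2.x or later, where add_node takes attributes as keyword arguments, so attr_dict={...} ends up stored as one attribute named attr_dict. The sample rows are made up for illustration and are not from the students table.

```python
import networkx as nx

# Hypothetical sample rows: (unique ID, university_id, full name)
rows = [(1, 10, "Ann"), (2, 20, "Bob")]

G = nx.Graph()

# Style used in the post: the whole dict becomes one attribute called 'attr_dict'.
G.add_node(rows[0][0], attr_dict={'university_id': rows[0][1], 'full_name': rows[0][2]})

# Keyword style: each attribute is stored directly on the node.
G.add_node(rows[1][0], university_id=rows[1][1], full_name=rows[1][2])

nested_keys = set(G.nodes[1].keys())
flat_keys = set(G.nodes[2].keys())
print(nested_keys)
print(flat_keys)
```

Either way, the code in the question compares attribute keys rather than values. Every node is added with the same keys, so the intersection is never empty and every pair of nodes gets an edge.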

The caveat with the above solution is that it's tremendously slow at runtime, because the nested loops compare every pair of nodes. I will update my answer whenever I can come up with a faster solution.
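
One possible way to avoid comparing every pair is to group the nodes by university_id first and then connect only the members of each group. This is a rough sketch, not a tested replacement for the code above: connect_by_attribute is a hypothetical helper name, the sample students are made up, and it assumes the attributes are stored as keyword arguments (university_id=...) rather than inside attr_dict.

```python
from collections import defaultdict
from itertools import combinations

import networkx as nx


def connect_by_attribute(G, attr):
    """Add an edge between every pair of nodes that share the same value of attr."""
    groups = defaultdict(list)
    # Single pass over the nodes: bucket them by attribute value.
    for node, data in G.nodes(data=True):
        if attr in data:
            groups[data[attr]].append(node)
    # Connect nodes only within each bucket.
    for members in groups.values():
        G.add_edges_from(combinations(members, 2))
    return G


# Hypothetical sample data: (unique ID, university_id, full name)
students = [(1, 10, "Ann"), (2, 10, "Bob"), (3, 20, "Cy"), (4, 10, "Dee")]

G = nx.Graph()
for student_id, university_id, full_name in students:
    G.add_node(student_id, university_id=university_id, full_name=full_name)

connect_by_attribute(G, "university_id")
edges = sorted(tuple(sorted(edge)) for edge in G.edges())
print(edges)
```

The grouping pass touches each node once. The number of edges created is still quadratic in the size of each university, so a very large single university will remain expensive to draw.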
