简体   繁体   English

使用CSV文件读写NumPy数组的字典

[英]Reading and writing dictionary of NumPy arrays with CSV files

I am using Python 2.7 and networkx to make a spring layout plot of my network. 我正在使用Python 2.7和networkx制作我的网络的弹簧布局图。 In order to compare different setups, I want to store the positions calculated and used by networkx into a file (my choice at the moment: csv) and read it everytime I am making a new plot. 为了比较不同的设置,我想将networkx计算和使用的位置存储到一个文件中(我现在选择:csv)并在每次创建新图时将其读取。 Sounds pretty simple, my code looks like this: 听起来很简单,我的代码看起来像这样:

pos_spring = nx.spring_layout(H, pos=fixed_positions, fixed = fixed_nodes, k = 4, weight='passengers')

This line calculates the positions used later for plotting, that I want to store. 此行计算以后用于绘图的位置,我想要存储。 The dictionary (pos_spring) looks like this: 字典(pos_spring)看起来像这样:

{1536: array([ 0.53892015,  0.984306  ]), 
1025: array([ 0.12096853,  0.82587976]), 
1030: array([ 0.20388712,  0.7046137 ]),

Writing file: 写文件:

w = csv.writer(open("Mexico_spring_layout_positions_2.csv", "w"))
for key, val in pos_spring.items():
    w.writerow([key, val])

The content of the file looks like this: 该文件的内容如下所示:

1536,[ 0.51060853  0.80129841]
1025,[ 0.47442269  0.99838177]
1030,[ 0.02952256  0.45073233]

Reading file: 阅读文件:

with open('Mexico_spring_layout_positions_2.csv', mode='r') as infile:
    reader = csv.reader(infile)
    pos_spring = dict((rows[0],rows[1]) for rows in reader)

The content of pos_spring now looks like this: pos_spring的内容现在看起来像这样:

{'2652': '[ 0.78480322  0.103894  ]', 
'1260': '[ 0.8834103   0.82542163]', 
'2969': '[ 0.33044548  0.31282113]',

Somehow this data looks different from the original dictionary, that was stored inside the csv file. 不知何故,这些数据看起来与原始字典不同,它存储在csv文件中。 What needs to be changed when writing and/or reading the data to fix this issue? 在编写和/或读取数据以解决此问题时需要更改哪些内容? Thanks in advance. 提前致谢。

Kind regards, Frank 亲切的问候,弗兰克

You can't store NumPy arrays in CSV files and maintain data types. 您不能将NumPy阵列存储在CSV文件中并维护数据类型。 Remember, CSV files can only store text. 请记住,CSV文件只能存储文本。 What you are seeing is a text representation of your NumPy array. 你看到的是你的NumPy数组的文本表示。

Instead, you can unpack your NumPy array as you write to your csv file: 相反,您可以在写入csv文件时解压缩NumPy数组:

import csv

d = {1536: np.array([ 0.53892015,  0.984306  ]), 
     1025: np.array([ 0.12096853,  0.82587976]), 
     1030: np.array([ 0.20388712,  0.7046137 ])}

fp = r'C:\temp\out.csv'

with open(fp, 'w', newline='') as fout:
    w = csv.writer(fout)
    for key, val in d.items():
        w.writerow([key, *val])

Then convert back to NumPy when you read back. 然后在您回读时转换回NumPy。 For this step you can use a dictionary comprehension: 对于此步骤,您可以使用字典理解:

with open(fp, 'r') as fin:
    r = csv.reader(fin)
    res = {int(k): np.array(list(map(float, v))) for k, *v in r}

print(res)

{1536: array([ 0.53892015,  0.984306  ]),
 1025: array([ 0.12096853,  0.82587976]),
 1030: array([ 0.20388712,  0.7046137 ])}

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM