[英]How to iterate a list of list for a scatter plot and create a legend of unique elements
背景:
我有一个包含x
和y
值的list_of_x_and_y_list
,如下所示:
[[(44800, 14888), (132000, 12500), (40554, 12900)], [(None, 193788), (101653, 78880), (3866, 160000)]]
我有另一个data_name_list
["data_a","data_b"]
以便
"data_a" = [(44800, 14888), (132000, 12500), (40554, 12900)]
"data_b" = [(None, 193788), (101653, 78880), (3866, 160000)]
list_of_x_and_y_list
的len
/ 或data_name_list
len
> 20。
题:
如何为data_name_list
中的每个项目(颜色相同)创建散点图?
我尝试过的:
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax = plt.axes(facecolor='#FFFFFF')
prop_cycle = plt.rcParams['axes.prop_cycle']
colors = prop_cycle.by_key()['color']
print(list_of_x_and_y_list)
for x_and_y_list, data_name, color in zip(list_of_x_and_y_list, data_name_list, colors):
for x_and_y in x_and_y_list,:
print(x_and_y)
x, y = x_and_y
ax.scatter(x, y, label=data_name, color=color) # "label=data_name" creates
# a huge list as a legend!
# :(
plt.title('Matplot scatter plot')
plt.legend(loc=2)
file_name = "3kstc.png"
fig.savefig(file_name, dpi=fig.dpi)
print("Generated: {}".format(file_name))
问题:
传说似乎是一个很长的清单,我不知道如何纠正:
相关研究:
您得到一个长重复列表作为图例的原因是因为您将每个点作为一个单独的系列提供,因为matplotlib
不会根据标签自动对您的数据进行分组。
快速解决方法是迭代列表并将每个系列的 x 值和 y 值压缩为两个元组,以便x
元组包含所有 x 值, y
元组包含 y 值。
然后您可以将这些元组与标签一起提供给plt.plot
方法。
我觉得名字list_of_x_and_y_list
是不必要的冗长和复杂,所以在我的代码中我使用了较短的名字。
import matplotlib.pyplot as plt
data_series = [[(44800, 14888), (132000, 12500), (40554, 12900)],
[(None, 193788), (101653, 78880), (3866, 160000)]]
data_names = ["data_a","data_b"]
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax = plt.axes(facecolor='#FFFFFF')
prop_cycle = plt.rcParams['axes.prop_cycle']
colors = prop_cycle.by_key()['color']
for data, data_name, color in zip(data_series, data_names, colors):
x,y = zip(*data)
ax.scatter(x, y, label=data_name, color=color)
plt.title('Matplot scatter plot')
plt.legend(loc=1)
要让每个 data_name 只获取一个条目,您应该只将 data_name 添加一次作为标签。 其余的调用应该使用label=None
。 使用当前代码可以实现的最简单的方法是在循环结束时将 data_name 设置为None
:
from matplotlib import pyplot as plt
from random import randint
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.set_facecolor('#FFFFFF')
# create some random data, suppose the sublists have different lengths
list_of_x_and_y_list = [[(randint(1000, 4000), randint(2000, 5000)) for col in range(randint(2, 10))]
for row in range(10)]
data_name_list = list('abcdefghij')
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
for x_and_y_list, data_name, color in zip(list_of_x_and_y_list, data_name_list, colors):
for x_and_y in x_and_y_list :
x, y = x_and_y
ax.scatter(x, y, label=data_name, color=color)
data_name = None
plt.legend(loc=2)
plt.show()
有些事情可以简化,使代码“更pythonic”,例如:
for x_and_y in x_and_y_list :
x, y = x_and_y
可以写成:
for x, y in x_and_y_list:
另一个问题是,对于每个点调用scatter
的大量数据可能会相当慢。 属于同一列表的所有 x 和 y 可以绘制在一起。 例如使用列表理解:
for x_and_y_list, data_name, color in zip(list_of_x_and_y_list, data_name_list, colors):
xs = [x for x, y in x_and_y_list]
ys = [y for x, y in x_and_y_list]
ax.scatter(xs, ys, label=data_name, color=color)
scatter
甚至可以获得每个点的颜色列表,但是一次性绘制所有点,不允许每个data_name
标签。
很多时候, numpy用于存储数值数据。 这有一些优点,例如用于快速计算的矢量化。 使用 numpy 代码如下所示:
import numpy as np
for x_and_y_list, data_name, color in zip(list_of_x_and_y_list, data_name_list, colors):
xys = np.array(x_and_y_list)
ax.scatter(xys[:,0], xys[:,1], label=data_name, color=color)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.