[英]Add legend to a matplotlib scatter plot where colors are dynamic
我使用 Matplolib 和 Pandas dataframe 創建了 Scatter plot,現在我想為其添加一個圖例。 這是我的代碼:
colors = ['red' if x >= 150 and x < 200 else
'green' if x >= 200 and x < 400 else
'purple' if x >= 400 and x < 600 else
'yellow' if x >= 600 else 'teal' for x in myData.R]
ax1.scatter(myData.X, myData.Y, s=20, c=colors, marker='_', label='Test')
ax1.legend(loc='upper left', frameon=False)
這里發生的情況是,根據myData.R
的值,散點圖 plot 中點的顏色會發生變化。 所以,由於顏色是“動態的”,我在創建圖例時遇到了很多麻煩。 實際代碼只會創建一個帶有單個 label 的圖例,稱為“測試”,附近沒有任何顏色。
以下是數據示例:
X Y R
0 1 945 1236.334519
0 1 950 212.809352
0 1 950 290.663847
0 1 961 158.156856
我試過這個,但我不明白的是:
如何動態為圖例設置標簽? 例如, 'red' if x >= 150
,因此在圖例上應該有一個紅色方塊,其附近>150 。 但由於我沒有手動添加任何 label,我很難理解這一點。
在嘗試了以下之后,我只得到了一個帶有單個 label 'Classes' 的圖例:
`legend1 = ax1.legend(*scatter.legend_elements(), loc="左下", title="類")
ax1.add_artist(legend1)`
任何形式的建議表示贊賞!
可以加速的部分代碼是使用普通的 Python 循環創建字符串列表。 Pandas 使用 numpy 的過濾非常有效。 繪制散點圖主要取決於點的數量,當所有點一次繪制或分五部分繪制時,點數不會改變。
在循環中使用 matplotlib 的 scatter 的一些示例代碼:
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
N = 500
myData = pd.DataFrame({'X': np.round(np.random.uniform(-1000, 1000, N), -2),
'Y': np.random.uniform(-800, 800, N)})
myData['R'] = np.sqrt(myData.X ** 2 + myData.Y ** 2)
fig, ax1 = plt.subplots()
bounds = [150, 200, 400, 600]
colors = ['teal', 'red', 'green', 'purple', 'gold']
for b0, b1, col in zip([None]+bounds, bounds+[None], colors):
if b0 is None:
filter = (myData.R < b1)
label = f'$ R < {b1} $'
elif b1 is None:
filter = (myData.R >= b0)
label = f'${b0} ≤ R $'
else:
filter = (myData.R >= b0) & (myData.R < b1)
label = f'${b0} ≤ R < {b1}$'
ax1.scatter(myData.X[filter], myData.Y[filter], s=20, c=col, marker='_', label=label)
ax1.legend()
plt.show()
或者,可以使用 pandas 的cut
來創建類別,而 seaborn 的功能(例如其hue
參數)可以進行着色並自動創建圖例。
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
N = 500
myData = pd.DataFrame({'X': np.round( np.random.uniform(-1000, 1000, N),-2), 'Y': np.random.uniform(-800, 800, N)})
myData['R'] = np.sqrt(myData.X ** 2 + myData.Y ** 2)
fig, ax1 = plt.subplots()
bounds = [150, 200, 400, 600]
colors = ['teal', 'red', 'green', 'purple', 'gold']
hues = pd.cut(myData.R, [0]+bounds+[2000], right=False)
sns.scatterplot(myData.X, myData.Y, hue=hues, hue_order=hues.cat.categories, palette=colors, s=20, marker='_', ax=ax1)
plt.show()
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.