matplotlib FuncAnimation - make sure legend and plot lines have the same colors?

Consider a pandas DataFrame with multiple columns, one per country, and multiple rows, one per date. The cells hold per-country data that varies over time. This is the CSV:

https://pastebin.com/bJbDz7ei

I want to make a dynamic plot (animation) in Jupyter that shows how the data evolves in time. Out of all countries in the world, I only want to show the top 10 countries at any given time. So the countries shown in the graph may change from time to time (because the top 10 is evolving).
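Selecting the top N countries for a single date can also be written with pandas' Series.nlargest, which returns the same countries as the sort_values(ascending=False).head(N) chain in the code below (ties and missing values may be handled slightly differently). A minimal sketch on a small hypothetical frame standing in for the CSV; top_n is an illustrative helper name:

```python
import pandas as pd

# Hypothetical stand-in for rel_big: rows are dates, columns are countries.
df = pd.DataFrame(
    {"A": [1, 5, 9], "B": [4, 3, 2], "C": [2, 6, 1], "D": [3, 1, 8]},
    index=["2020-03-01", "2020-03-02", "2020-03-03"],
)

def top_n(frame, date, n):
    """Return the names of the n columns with the largest values on `date`, largest first."""
    return frame.loc[date].nlargest(n).index.tolist()

# Membership is recomputed for every row, so the top-N set can change over time.
for date in df.index:
    print(date, top_n(df, date, 2))
```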

I also want the colors to stay consistent. Only 10 countries are shown at any time, and some countries appear and disappear almost continuously, but each country's color should stay the same from the start of the animation to the end.

This is the code I have (EDIT: you can now copy/paste it into Jupyter and it works out of the box, so the bug I'm describing is easy to reproduce):

import pandas as pd
import requests
import os
from matplotlib import pyplot as plt
import matplotlib.animation as ani

rel_big_file = 'rel_big.csv'
rel_big_url = 'https://pastebin.com/raw/bJbDz7ei'

if not os.path.exists(rel_big_file):
    r = requests.get(rel_big_url)
    with open(rel_big_file, 'wb') as f:
        f.write(r.content)

rel_big = pd.read_csv(rel_big_file, index_col='Date')

# history of top N countries
champs = []
# frame draw function
def animate_graph(i):
    N = 10
    # get current values for each country
    last_index = rel_big.index[i]
    # which countries are top N in last_index?
    topN = rel_big.loc[last_index].sort_values(ascending=False).head(N).index.tolist()
    # if country not already in champs, add it
    for c in topN:
        if c not in champs:
            champs.append(c)
    # pull a standard color map from matplotlib
    cmap = plt.get_cmap("tab20")
    # draw legend
    plt.legend(topN)
    # make a temporary dataframe with only top N countries
    rel_plot = rel_big[topN].copy(deep=True)
    # plot temporary dataframe
    p = plt.plot(rel_plot[:i].index, rel_plot[:i].values)
    # set color for each country based on index in champs
    for i in range(0, N):
        p[i].set_color(cmap(champs.index(topN[i]) % 20))

%matplotlib notebook
fig = plt.figure(figsize=(10, 6))
plt.xticks(rotation=45, ha="right", rotation_mode="anchor")
# x ticks get too crowded, limit their number
plt.gca().xaxis.set_major_locator(plt.MaxNLocator(nbins=10))
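# frames=len(rel_big.index) stops the animation at the last row instead of running past it
animator = ani.FuncAnimation(fig, animate_graph, frames=len(rel_big.index), interval=333)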
plt.show()

It does the job, somewhat. I store the top countries in the champs list and assign each country a color based on its index in champs. However, only the plotted lines end up with the correct colors this way.

The colors in the legend, on the other hand, are assigned rigidly: the first entry in the legend always gets the same color, the second entry always gets another fixed color, and so on. As countries move up and down in the legend, each country's legend color keeps changing throughout the animation.

The colors of the plotted lines obey the index in champs. The colors of countries in the legend are based on the order within the legend. This is not what I want.

How do I assign the color for each country in the legend in a way that matches the plot lines?
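A likely cause, judging from the code: plt.legend(topN) is called with labels only. In that form Matplotlib pairs the labels with the artists already on the Axes in the order they were added, not by name, and because the Axes are never cleared, the first artists are always the lines drawn in the earliest frame. A minimal standalone sketch of that pairing rule (the three lines and their colors are made up):

```python
from matplotlib.colors import to_hex
from matplotlib.figure import Figure

fig = Figure()
ax = fig.subplots()

# Three unlabelled lines, drawn in this order, each with its own color.
for color in ["red", "green", "blue"]:
    ax.plot([0, 1], [0, 1], color=color)

# Labels only: they are matched to the existing lines by drawing order,
# so "C" gets the first line's color, "A" the second's, "B" the third's.
leg = ax.legend(["C", "A", "B"])
labels = [t.get_text() for t in leg.get_texts()]
legend_colors = [to_hex(line.get_color()) for line in leg.get_lines()]
print(list(zip(labels, legend_colors)))
```

Reordering topN therefore reorders the names but not the swatches, which matches the behavior described above.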

Here is my solution:

I removed your color-assignment code and replaced it with a working version:

First, I gave every country its own color in a dictionary (the colors are picked at random, so two countries can occasionally get the same one):

# initializing fixed color to all countries
colorsCountries = {}
for country in rel_big.columns:
    colorsCountries[country] = random.choice(list(mcd.CSS4_COLORS.keys()))
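Because random.choice draws independently for each country, two countries can receive the same color, and the mapping changes on every run. A sketch under the assumption that distinct, reproducible colors are wanted: a seeded random.Random and sample (drawing without replacement) give distinct CSS4 names whenever there are no more countries than names. assign_colors is a hypothetical helper, and matplotlib.colors.CSS4_COLORS is the public counterpart of the dictionary used above:

```python
import random

import matplotlib.colors as mcolors

def assign_colors(countries, seed=0):
    """Map each country to a CSS4 color name, reproducibly for a given seed."""
    names = list(mcolors.CSS4_COLORS)
    rng = random.Random(seed)  # private, seeded generator: same mapping every run
    countries = list(countries)
    if len(countries) <= len(names):
        # Sampling without replacement: no two countries share a color.
        picked = rng.sample(names, len(countries))
    else:
        # More countries than names: shuffle once and cycle, so repeats are unavoidable.
        rng.shuffle(names)
        picked = [names[k % len(names)] for k in range(len(countries))]
    return dict(zip(countries, picked))

colors = assign_colors(["France", "Germany", "Italy", "Spain"])
print(colors)
```

With the question's data this would be used as colorsCountries = assign_colors(rel_big.columns). Very light names such as "white" can still be drawn, so filtering them out of names may be worthwhile.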

Then I replaced this:

# plot temporary dataframe
p = plt.plot(rel_plot[:i].index, rel_plot[:i].values)

with this:

# plot temporary dataframe
# one plot call per country, each with that country's fixed color
for keyIndex in rel_plot.columns:
    plt.plot(rel_plot[:i].index, rel_plot[:i][keyIndex].values, color=colorsCountries[keyIndex])

Then I added code that updates the legend so each entry takes its country's color:

leg = plt.legend(topN)
for line, text in zip(leg.get_lines(), leg.get_texts()):
    line.set_color(colorsCountries[text.get_text()])
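Another way to keep the legend in sync, sketched here as an alternative rather than as part of this answer: give each line a label= when plotting and call legend() with no arguments. Matplotlib then builds each entry from the labelled line itself, so the swatch has that line's color by construction. The names and colors below are placeholders:

```python
from matplotlib.colors import to_hex
from matplotlib.figure import Figure

# Hypothetical fixed color per country (colorsCountries in the answer).
color_of = {"A": "tab:red", "B": "tab:green", "C": "tab:blue"}

fig = Figure()
ax = fig.subplots()

# Plot in an order unrelated to the dict; each line carries its own label.
for name in ["C", "A", "B"]:
    ax.plot([0, 1], [0, 1], color=color_of[name], label=name)

# No label list: entries come from the lines' own labels, in plotting order.
leg = ax.legend()
for line, text in zip(leg.get_lines(), leg.get_texts()):
    print(text.get_text(), to_hex(line.get_color()))
```

In the animation this relies on clearing the axes each frame; otherwise legend() also collects the labelled lines left over from earlier frames.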

Don't forget to add the imports:

import matplotlib.colors as mcd  # public module; also provides CSS4_COLORS
import random

Here is the complete suggested solution:

import pandas as pd
import requests
import os
from matplotlib import pyplot as plt
import matplotlib.animation as ani
import matplotlib.colors as mcd  # public module; also provides CSS4_COLORS
import random

rel_big_file = 'rel_big.csv'
rel_big_url = 'https://pastebin.com/raw/bJbDz7ei'

if not os.path.exists(rel_big_file):
    r = requests.get(rel_big_url)
    with open(rel_big_file, 'wb') as f:
        f.write(r.content)

rel_big = pd.read_csv(rel_big_file, index_col='Date')

# history of top N countries
champs = []
# initializing fixed color to all countries
colorsCountries = {}
for country in rel_big.columns:
    colorsCountries[country] = random.choice(list(mcd.CSS4_COLORS.keys()))
# frame draw function
def animate_graph(i):
    N = 10
    # get current values for each country
    last_index = rel_big.index[i]
    # which countries are top N in last_index?
    topN = rel_big.loc[last_index].sort_values(ascending=False).head(N).index.tolist()
    # if country not already in champs, add it
    for c in topN:
        if c not in champs:
            champs.append(c)
    # pull a standard color map from matplotlib
    cmap = plt.get_cmap("tab20")
    # draw legend
    plt.legend(topN)
    # make a temporary dataframe with only top N countries
    rel_plot = rel_big[topN].copy(deep=True)
    # plot temporary dataframe
    #### Removed Code
    #p = plt.plot(rel_plot[:i].index, rel_plot[:i].values)
    #### Removed Code
    # one plot call per country, each with that country's fixed color
    for keyIndex in rel_plot.columns:
        plt.plot(rel_plot[:i].index, rel_plot[:i][keyIndex].values, color=colorsCountries[keyIndex])
    # set color for each country based on index in champs
    #### Removed Code
    #for i in range(0, N):
        #p[i].set_color(cmap(champs.index(topN[i]) % 20))
    #### Removed Code
    leg = plt.legend(topN)
    for line, text in zip(leg.get_lines(), leg.get_texts()):
        line.set_color(colorsCountries[text.get_text()])

%matplotlib notebook
fig = plt.figure(figsize=(10, 6))
plt.xticks(rotation=45, ha="right", rotation_mode="anchor")
# x ticks get too crowded, limit their number
plt.gca().xaxis.set_major_locator(plt.MaxNLocator(nbins=10))
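# frames=len(rel_big.index) stops the animation at the last row instead of running past it
animator = ani.FuncAnimation(fig, animate_graph, frames=len(rel_big.index), interval=333)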
plt.show()
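The solution above never clears the axes, so every frame stacks ten new lines on top of the previous ones. Below is a sketch of a frame function that clears the axes, redraws only the current top N with label= set, and lets legend() take each color from its line. It uses a randomly generated frame in place of rel_big and a bare Figure instead of pyplot so it runs without a display; update and make_animation are illustrative names, and matplotlib.colormaps assumes Matplotlib 3.5 or later:

```python
import matplotlib
import numpy as np
import pandas as pd
from matplotlib.animation import FuncAnimation
from matplotlib.colors import to_hex
from matplotlib.figure import Figure

N = 3  # size of the top-N list (10 in the question)

# Hypothetical stand-in for rel_big: 20 dates x 6 countries of random walks.
rng = np.random.default_rng(0)
df = pd.DataFrame(
    rng.normal(size=(20, 6)).cumsum(axis=0),
    index=pd.date_range("2020-03-01", periods=20, freq="D"),
    columns=[f"Country{k}" for k in range(6)],
)

# One fixed color per country, decided once before the animation starts.
cmap = matplotlib.colormaps["tab20"]
colors = {c: to_hex(cmap(k % cmap.N)) for k, c in enumerate(df.columns)}

fig = Figure(figsize=(10, 6))
ax = fig.subplots()

def update(i):
    """Draw frame i: history up to and including row i for that row's top-N countries."""
    ax.clear()  # drop the previous frame's lines so legend() sees only current ones
    top = df.iloc[i].nlargest(N).index
    history = df.iloc[: i + 1]
    for country in top:
        # label= ties each legend entry to its own line, and so to its own color
        ax.plot(history.index, history[country], color=colors[country], label=country)
    ax.tick_params(axis="x", labelrotation=45)  # clear() also resets tick settings
    ax.legend(loc="upper left")
    return ax.lines

def make_animation():
    # frames=len(df) ends the animation at the last row.
    return FuncAnimation(fig, update, frames=len(df), interval=333)
```

In a notebook the figure would instead come from fig, ax = plt.subplots(figsize=(10, 6)) so the backend can display it. Since clear() wipes the tick locator and rotation, those settings belong inside update.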

The answer by ZINE Mahmoud works well. I changed only one thing: I wanted the same color assignment on every run, so instead of picking colors at random I assigned them by column position:

colorsCountries = {}
colorPalette = list(mcd.CSS4_COLORS.keys())
# k is the column position, so the same country always gets the same color
for k, country in enumerate(rel_big.columns):
    colorsCountries[country] = colorPalette[k % len(colorPalette)]
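The CSS4 names are stored in alphabetical order, and several of the first ones (aliceblue, antiquewhite, azure) are nearly white and hard to see on a white background. A deterministic alternative, assuming Matplotlib 3.5 or later for matplotlib.colormaps, takes the colors by column position from a qualitative colormap such as tab20, the one the question's code already uses. fixed_colors is a hypothetical helper name:

```python
import matplotlib
import matplotlib.colors as mcolors

def fixed_colors(countries, cmap_name="tab20"):
    """Deterministically map country -> hex color by its position in `countries`."""
    cmap = matplotlib.colormaps[cmap_name]
    # cmap.N is the number of distinct colors (20 for tab20); wrap around beyond that.
    return {c: mcolors.to_hex(cmap(k % cmap.N)) for k, c in enumerate(countries)}

colors = fixed_colors(["France", "Germany", "Italy"])
print(colors)
```

With the question's data: colorsCountries = fixed_colors(rel_big.columns). Beyond 20 countries the colors repeat, so two countries in the same top 10 can still share a color; a longer hand-picked list reduces that risk.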

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please credit the site URL or the original address. For any questions, contact: yoyou2525@163.com.

 