
How to make a line plot from a dataframe with multiple categorical columns in matplotlib

I want to make weekly line charts for different categories, where one category is the country and the other is the year. Initially, I drafted line plots using seaborn, but it is not very handy for things like setting labels, the legend, the color palette and so on. Is there an easy way to reshape this data with multiple categorical variables and render line charts? In my initial attempt I tried seaborn.relplot, but it is not easy to tune its parameters, and the resulting plot is hard to customize. Can anyone point me to an efficient way to reshape a dataframe with multiple categorical columns and render a clear line chart? Any thoughts?

Reproducible data and my attempt:

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

url = 'https://gist.githubusercontent.com/adamFlyn/cb0553e009933574ac7ec3109ffb5140/raw/a277bc00dc08e526a7d5b7ead5425905f7206bfa/export.csv'
dff = pd.read_csv(url, parse_dates=['weekly'])
dff.drop('Unnamed: 0', axis=1, inplace=True)

df2_bf = dff.groupby(['destination', 'weekly'])['FCF_Beef'].sum().unstack()
df2_bf = df2_bf.fillna(0)
mm = df2_bf.T
mm.columns.name = None
mm = mm[~(mm.isna().sum(1)/mm.shape[1]).gt(0.9)].fillna(0)

#Total sum per column: 
mm.loc['Total',:]= mm.sum(axis=0)
mm1 = mm.T
mm1 = mm1.nlargest(6, columns=['Total'])
mm1.drop('Total', axis=1, inplace=True)
mm2 = mm1.T
mm2.reset_index(inplace=True)
mm2['weekly'] = pd.to_datetime(mm2['weekly'])

mm2['year'] = mm2['weekly'].dt.year
mm2['week'] = mm2['weekly'].dt.isocalendar().week
df = mm2.melt(id_vars=['weekly','week','year'], var_name='country')

df_ = df.groupby(['country', 'year', 'week'], as_index=False)['value'].sum()
sns.relplot(data=df_, x='week', y='value', hue='year', row='country', kind='line', height=6, aspect=2, facet_kws={'sharey': False, 'sharex': False}, sizes=(20, 10))

Current plot

This is one of the current plots that I made with seaborn.relplot.

The structure of the plot is okay for me, but with seaborn.relplot it is hard to tune parameters, and it is not as flexible as matplotlib. I also realized that the way I aggregate my data is not very efficient. I think there might be a shortcut to make the above code snippet more efficient, like:

plt_data = []
for i in dff.loc[:, ['FCF_Beef','FCF_Beef']]:
    ...

But doing it this way, I ran into a couple of issues producing the right plot. Can anyone show me how to make this simple and efficient, so that I get the expected line chart with matplotlib? Does anyone know a better way of doing this? Thanks.
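One possible shortcut for the aggregation, as a minimal sketch: it assumes dff has the destination, weekly (parsed as dates) and FCF_Beef columns loaded as above, and the function name top_destinations_weekly is hypothetical. It ranks destinations by their total directly with groupby(...).sum().nlargest(...) instead of the transpose and 'Total' row round trip. The isna() filter in the original code has no effect, because fillna(0) has already removed every NaN by then. Two behavioral differences: this version does not insert zero rows for weeks in which a destination has no shipments (so lines connect across such gaps instead of dropping to zero), and it uses Monday-based %W week numbers rather than ISO weeks.

```python
import pandas as pd


def top_destinations_weekly(dff: pd.DataFrame, n: int = 6,
                            value_col: str = "FCF_Beef") -> pd.DataFrame:
    """Long-format table (country, year, week, value) for the n largest destinations.

    Hypothetical helper: assumes dff has 'destination', 'weekly' (datetime64)
    and value_col columns.
    """
    # Total per destination; keep the names of the n largest.
    top = dff.groupby("destination")[value_col].sum().nlargest(n).index
    sub = dff[dff["destination"].isin(top)]

    # Calendar year plus Monday-based week number (%W, 0-53), so late-December
    # dates stay at the end of their own year instead of jumping to ISO week 1.
    sub = sub.assign(year=sub["weekly"].dt.year,
                     week=sub["weekly"].dt.strftime("%W").astype(int))

    # Sum per country, year and week, then use the column names from the question.
    return (sub.groupby(["destination", "year", "week"], as_index=False)[value_col]
               .sum()
               .rename(columns={"destination": "country", value_col: "value"}))
```

With dff from the question, df_ = top_destinations_weekly(dff) gives a table with the same country, year, week and value columns as df_ above, ready for relplot or a matplotlib loop.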

Desired output

In my desired plot, I need to iterate over the list of countries, with one subplot per country. In each subplot, the x-axis shows the 52 weeks of the year and the y-axis shows the weekly export amount for each year. Here is the draft plot that I made with seaborn.relplot.

Note that I don't like the output from seaborn.relplot, so I am wondering how I can make the above attempt more efficient with matplotlib. Any ideas?

  • As requested by the OP, the following is an iterative way to plot the data.
  • The following example plots each year for a given 'destination' in a single figure.
  • This is similar to the answer to this question.
import pandas as pd
import matplotlib.pyplot as plt

# load the data
url = 'https://gist.githubusercontent.com/adamFlyn/cb0553e009933574ac7ec3109ffb5140/raw/a277bc00dc08e526a7d5b7ead5425905f7206bfa/export.csv'
df = pd.read_csv(url, parse_dates=['weekly'], usecols=range(1, 6))

# groupby destination and iterate through for plotting
for g, d in df.groupby('destination'):  # a plain column name (not a one-element list) keeps g a scalar

    # create the figure
    fig, ax = plt.subplots(figsize=(7, 4))
    
    # add lines for specific years
    for year in d.weekly.dt.year.unique():
        data = d[d.weekly.dt.year == year].copy()  # select the data from d, by year
        data['week'] = data.weekly.dt.isocalendar().week  # create a week column
        data.sort_values('weekly', inplace=True)
        print(data.head())  # inspect the selected rows; in Jupyter, display(data.head()) renders a table
        data.plot(x='week', y='FCF_Beef', ax=ax, label=year)
    
    plt.show()
  • Single sample plot


  • If we look at the tail of one of the dataframes, data.weekly.dt.isocalendar().week puts the last day of the year in week 1, so a line is drawn back to that last data point, which is placed at week 1.
  • This comes from datetime.datetime(2018, 12, 31).isocalendar(), which assigns that date to ISO week 1 of 2019; this is the expected behavior of the datetime module, as per this closed pandas bug.
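A quick way to see the difference using only the standard library (the helper name week_numbers is just for illustration): isocalendar() puts 2018-12-31 in week 1 of ISO year 2019, while the Monday-based %W directive, used in the strftime workaround of this answer, numbers it as week 53 of 2018.

```python
from datetime import date


def week_numbers(d: date) -> tuple:
    """Return (ISO year, ISO week, Monday-based %W week) for a date."""
    iso_year, iso_week, _ = d.isocalendar()
    # %W: weeks start on Monday; days before the first Monday of the year are week 0.
    monday_week = int(d.strftime("%W"))
    return iso_year, iso_week, monday_week


for d in (date(2018, 12, 24), date(2018, 12, 31)):
    print(d, week_numbers(d))
# 2018-12-24 (2018, 52, 52)
# 2018-12-31 (2019, 1, 53)
```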


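  • Removing the last row with .iloc[:-1, :] is a workaround, for example by replacing the data.plot(...) call in the loop with:
data.iloc[:-1, :].plot(x='week', y='FCF_Beef', ax=ax, label=year)
  • Alternatively, replace data['week'] = data.weekly.dt.isocalendar().week with data['week'] = data.weekly.dt.strftime('%W').astype('int'). The %W directive counts weeks starting on Monday, and days before the first Monday of the year fall in week 0, so the last days of December stay at the right end of the x-axis.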
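The .iloc[:-1, :] workaround assumes that exactly the last row of each year spills into the next ISO year. A slightly more general variant is sketched below as a hypothetical helper (drop_iso_year_spillover is not a pandas function): it drops every row whose ISO year differs from its calendar year, which also covers early-January dates that ISO assigns to the previous year's last week. It assumes the 'weekly' column is datetime64, as in the answer, and note that it discards those rows rather than reassigning them.

```python
import pandas as pd


def drop_iso_year_spillover(data: pd.DataFrame, date_col: str = "weekly") -> pd.DataFrame:
    """Keep only rows whose ISO week belongs to the same calendar year (hypothetical helper)."""
    dates = data[date_col]
    # isocalendar() returns nullable UInt32 columns; cast so the comparison is a plain bool Series.
    iso_year = dates.dt.isocalendar().year.astype("int64")
    return data.loc[iso_year == dates.dt.year.astype("int64")]
```

Inside the loop it would be used as drop_iso_year_spillover(data).plot(x='week', y='FCF_Beef', ax=ax, label=year).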


Updated with all the code from the OP

import pandas as pd
import matplotlib.pyplot as plt

# load the data
url = 'https://gist.githubusercontent.com/adamFlyn/cb0553e009933574ac7ec3109ffb5140/raw/a277bc00dc08e526a7d5b7ead5425905f7206bfa/export.csv'
dff = pd.read_csv(url, parse_dates=['weekly'], usecols=range(1, 6))

df2_bf = dff.groupby(['destination', 'weekly'])['FCF_Beef'].sum().unstack()
df2_bf = df2_bf.fillna(0)
mm = df2_bf.T
mm.columns.name = None
mm = mm[~(mm.isna().sum(1)/mm.shape[1]).gt(0.9)].fillna(0)

#Total sum per column: 
mm.loc['Total',:]= mm.sum(axis=0)
mm1 = mm.T
mm1 = mm1.nlargest(6, columns=['Total'])
mm1.drop('Total', axis=1, inplace=True)
mm2 = mm1.T
mm2.reset_index(inplace=True)
mm2['weekly'] = pd.to_datetime(mm2['weekly'])

mm2['year'] = mm2['weekly'].dt.year
mm2['week'] = mm2['weekly'].dt.strftime('%W').astype('int')
df = mm2.melt(id_vars=['weekly','week','year'], var_name='country')

# groupby country and iterate through for plotting
# (a plain column name, not a one-element list, so g is a string and works as the title)
for g, d in df.groupby('country'):

    # create the figure
    fig, ax = plt.subplots(figsize=(7, 4))
    
    # add lines for specific years
    for year in d.weekly.dt.year.unique():
        data = d[d.weekly.dt.year == year].copy()  # select the data from d, by year
        data.sort_values('weekly', inplace=True)
        print(data.head())  # inspect the selected rows; in Jupyter, display(data.head()) renders a table
        data.plot(x='week', y='value', ax=ax, label=year, title=g)
    
    plt.show()
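The loops above create one figure per country. For the layout described in the question (one subplot per country in a single figure, weeks on the x-axis, one line per year), here is a minimal sketch with plain matplotlib. It assumes a long-format DataFrame with country, year, week and value columns, like df_ in the question; the function name plot_weekly_by_country and its parameters are hypothetical. Because every Axes is created explicitly, titles, labels, colors and legends can be adjusted directly on each ax afterwards.

```python
import matplotlib.pyplot as plt
import pandas as pd


def plot_weekly_by_country(df: pd.DataFrame, row: str = "country", hue: str = "year",
                           x: str = "week", y: str = "value",
                           width: float = 12, height_per_row: float = 3):
    """One subplot per `row` value, one line per `hue` value (hypothetical helper)."""
    countries = list(df[row].unique())
    # squeeze=False keeps `axes` 2-D even when there is only one country.
    fig, axes = plt.subplots(nrows=len(countries), ncols=1, squeeze=False,
                             figsize=(width, height_per_row * len(countries)))
    for ax, country in zip(axes[:, 0], countries):
        sub = df[df[row] == country]
        # One line per year, sorted by week so each line runs left to right.
        for year, grp in sub.groupby(hue):
            grp = grp.sort_values(x)
            ax.plot(grp[x], grp[y], label=str(year))
        ax.set_title(country)
        ax.set_xlabel(x)
        ax.set_ylabel(y)
        ax.set_xlim(0, 53)  # assumes week numbers between 0 and 53
        ax.legend(title=hue)
    fig.tight_layout()
    return fig, axes
```

Usage with the question's data: fig, axes = plot_weekly_by_country(df_) followed by plt.show().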
