
Clustering multiple categorical columns to make a time series line plot in matplotlib

I am interested in how the COVID pandemic is affecting meat processing plants across the country. I retrieved NYT COVID data at the county level and statistical data from the food agency. Here I am exploring how COVID cases are surging in counties where major food processing plants are located, because more sick employees in a plant might hurt the business. In my first attempt, I made moving-average time series plots of COVID new cases vs. their 7-day rolling mean by date.
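
For readers less familiar with the idea, here is a minimal sketch of "new cases vs. 7-day rolling mean" on invented numbers. It assumes, as in the NYT county files, that cases is a cumulative count, so daily new cases are recovered with .diff(); treating the first day as 0 new cases is an assumption of this sketch, not something taken from the gist.

```python
import pandas as pd

# Hypothetical cumulative case counts for one county over 10 days (invented data)
dates = pd.date_range("2020-04-01", periods=10, freq="D")
cases = pd.Series([0, 2, 5, 9, 14, 20, 27, 35, 44, 54], index=dates, name="cases")

# Daily new cases = day-over-day difference of the cumulative count
# (assumption: the first day has 0 new cases, since there is no previous day)
new_cases = cases.diff().fillna(0)

# Trailing 7-row mean; the first 6 values are NaN because the window is incomplete
rol7 = new_cases.rolling(7).mean()

print(pd.DataFrame({"cases": cases, "new_cases": new_cases, "rol7": rol7}))
```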

However, I think it would be more useful if I could replace the plotted series with num-emp and new_cases for each county in the for loop. To do this, I think it would be better to cluster the counties by company and split them into multiple graphs, so the lines don't overlap and become too difficult to read. I am not sure how to achieve this from my current attempt. Can anyone suggest possible ways of doing this in matplotlib? Any ideas?

My current attempt:

Here is the reproducible data in this gist that I used in my experiment:

import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import seaborn as sns
from datetime import timedelta, datetime

df = pd.read_csv("https://gist.githubusercontent.com/jerry-shad/7eb2dd4ac75034fcb50ff5549f2e5e21/raw/477c07446a8715f043c9b1ba703a03b2f913bdbf/covid_tsdf.csv")
df.drop(['Unnamed: 0', 'fips', 'non-fed-slaughter', 'fed-slaughter', 'total-slaughter', 'mcd-asl'], axis=1, inplace=True)
for ct in df['county_state'].unique():
    dd = df[df['county_state'] == ct].groupby(['county_state', 'date', 'est'])[['cases','new_cases']].sum().unstack().reset_index()
    dd.columns= ['county_state','date', 'cases', 'new_cases']
    dd['date'] = pd.to_datetime(dd['date'])
    dd['rol7'] = dd[['date','new_cases']].rolling(7).mean()
    fig = plt.figure(figsize=(8,6),dpi=144)
    ax = fig.add_subplot(111)
    colors = sns.color_palette()
    ax2 = ax.twinx()
    ax = sns.lineplot('date', 'rol7', data=dd, color=colors[1], ax=ax)
    ax2 = sns.lineplot('date', 'cases', data=dd, color=colors[0], ax=ax2)
    ax.set_xlim(dd.date.min(), dd.date.max())
    fig.legend(['rolling7','cases'],loc="upper left", bbox_to_anchor=(0.01, 0.95), bbox_transform=ax.transAxes)
    ax.grid(axis='both', lw=0.5)
    locator = mdates.AutoDateLocator()
    ax.xaxis.set_major_locator(locator)
    fig.autofmt_xdate(rotation=45)
    ax.set(title=f'US covid tracking in meat processing plants by county - Linear scale')
    plt.show()

Here is my current output:

[image: current output]

But this output doesn't make it easy to see how food processing companies are affected by COVID through infected employees. To make it more visually accessible, I think we can replace the two plotted series with num-emp and the newly infected cases (new_cases) and draw each county we need in the loop. At that point, it would be better to cluster the counties by company characteristics, etc., and split them into multiple graphs so the lines don't overlap and become difficult to read. I want an EDA that provides this sort of information visually. Can anyone suggest possible ways of doing this with matplotlib? Any thoughts? Thanks!
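
One common way to "cluster by company and expand into multiple graphs" is a small-multiples layout: one subplot per packer, with one line per county inside it. This is only a sketch. It assumes long-format data with date, county_state, packer and new_cases columns, as in the gist; the names "Packer A" and "County 1" and the random counts are placeholders, not the real dataset.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical long-format data; names and Poisson counts are invented placeholders
rng = np.random.default_rng(0)
dates = pd.date_range("2020-04-01", periods=60, freq="D")
plants = [("Packer A", "County 1"), ("Packer A", "County 2"),
          ("Packer B", "County 3"), ("Packer C", "County 4")]
df = pd.concat(
    [pd.DataFrame({"date": dates, "county_state": county, "packer": packer,
                   "new_cases": rng.poisson(20, size=len(dates))})
     for packer, county in plants],
    ignore_index=True,
)

# One row per packer/county/date, then a 7-day mean computed within each county
daily = (df.groupby(["packer", "county_state", "date"])["new_cases"].sum()
           .reset_index())
daily["rol7"] = (daily.groupby(["packer", "county_state"])["new_cases"]
                   .transform(lambda s: s.rolling(7).mean()))

packers = sorted(daily["packer"].unique())
# One panel per company; shared axes keep the panels comparable
fig, axes = plt.subplots(len(packers), 1, figsize=(8, 2.5 * len(packers)),
                         sharex=True, sharey=True, squeeze=False)
for ax, packer in zip(axes[:, 0], packers):
    sub = daily[daily["packer"] == packer]
    for county, grp in sub.groupby("county_state"):
        ax.plot(grp["date"], grp["rol7"], label=county)  # one line per county
    ax.set_title(packer)
    ax.set_ylabel("new cases (7-day mean)")
    ax.legend(fontsize="small", loc="upper left")
fig.autofmt_xdate(rotation=45)
fig.tight_layout()
# fig.savefig("by_packer.png") or plt.show() in an interactive session
```

The same loop works for any other grouping column (for example a company-size bucket) by swapping "packer" for that column.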

  • There were a couple of issues; I've added inline notes in the code.
  • The main issue was in the .groupby
    • The data is already selected by 'county_state' in the loop, so there's no need to group by it.
    • Reset only the est and packer levels (reset_index(level=[1, 2])) and keep date in the index for rolling.
    • .unstack() was creating multi-level column names.
  • Set ci=None (errorbar=None in seaborn 0.12 and later) so seaborn doesn't draw a confidence band when a date has several rows.
  • It doesn't make sense to use 'num-emp' as a metric; it's constant across time.
    • If you want to see the plot anyway, add 'num-emp' to the columns aggregated after .groupby and swap 'cases' for 'num-emp' in the plotting call.
  • I think the best way to see the impact of COVID on a given company is to find a dataset with revenue.
  • Because food processing plants are considered critical infrastructure, there probably won't be much change in their head count; anyone who is sick is probably on sick leave rather than terminated.
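
To see why the .groupby change matters, here is a toy frame with invented values comparing the question's .unstack() pipeline with the answer's reset_index(level=[1, 2]). With more than one establishment in a county, the first produces six multi-level columns, so the question's four-name dd.columns = [...] assignment no longer fits.

```python
import pandas as pd

# Tiny invented frame: one county, two establishments, two dates
df = pd.DataFrame({
    "county_state": ["X"] * 4,
    "date": pd.to_datetime(["2020-04-01", "2020-04-01", "2020-04-02", "2020-04-02"]),
    "est": ["M1", "M2", "M1", "M2"],
    "packer": ["Packer A", "Packer B", "Packer A", "Packer B"],
    "cases": [1, 3, 2, 5],
    "new_cases": [1, 3, 1, 2],
})

# Question's approach: unstack() moves 'est' into the columns -> MultiIndex columns
wide = (df.groupby(["county_state", "date", "est"])[["cases", "new_cases"]]
          .sum().unstack().reset_index())
print(wide.columns.tolist())

# Answer's approach: move only est and packer back to columns; date stays the index
long = (df.groupby(["date", "est", "packer"])[["cases", "new_cases"]]
          .sum().reset_index(level=[1, 2]))
print(long.index.name, long.columns.tolist())
```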
import os

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

url = 'https://gist.githubusercontent.com/jerry-shad/7eb2dd4ac75034fcb50ff5549f2e5e21/raw/477c07446a8715f043c9b1ba703a03b2f913bdbf/covid_tsdf.csv'

# load the data and parse the dates
df = pd.read_csv(url, parse_dates=['date'])

# drop unneeded columns
df.drop(['Unnamed: 0', 'fips', 'non-fed-slaughter', 'fed-slaughter', 'total-slaughter', 'mcd-asl'], axis=1, inplace=True)

# make sure the output directory used by plt.savefig in the loop exists
os.makedirs('images', exist_ok=True)

for ct in df['county_state'].unique():
    
    # groupby has been updated: no need for county because they're all the same, given the loop; keep date in the index for rolling
    dd = df[df['county_state'] == ct].groupby(['date', 'est', 'packer'])[['cases','new_cases']].sum().reset_index(level=[1, 2])
    dd['rol7'] = dd['new_cases'].rolling(7).mean()  # 7-row trailing mean of new cases

    colors = sns.color_palette()
    
    fig, ax = plt.subplots(figsize=(8, 6), dpi=144)
    ax2 = ax.twinx()
    
    # date is in the index; seaborn >= 0.12 requires x/y as keywords and uses errorbar=None (ci=None on older versions)
    sns.lineplot(data=dd, x=dd.index, y='rol7', errorbar=None, color=colors[1], ax=ax)
    sns.lineplot(data=dd, x=dd.index, y='cases', errorbar=None, color=colors[0], ax=ax2)
    
    ax.set_xlim(dd.index.min(), dd.index.max())  # date is in the index
    fig.legend(['rolling7','cases'], loc="upper left", bbox_to_anchor=(0.01, 0.95), bbox_transform=ax.transAxes)
    
    # set y labels
    ax.set_ylabel('7-day Rolling Mean')
    ax2.set_ylabel('Current Number of Cases')
    
    ax.grid(axis='both', lw=0.5)
    locator = mdates.AutoDateLocator()
    ax.xaxis.set_major_locator(locator)
    fig.autofmt_xdate(rotation=45)
    
    # create a dict for packer and est
    vals = dict(dd[['packer', 'est']].reset_index(drop=True).drop_duplicates().values.tolist())
    
    # create a custom string from vals, for the title
    insert = ', '.join([f'{k}: {v}' for k, v in vals.items()])

#     ax.set(title=f'US covid tracking in meat processing plants for {ct} \nPacker: {", ".join(dd.packer.unique())}\nEstablishments: {", ".join(dd.est.unique())}')

    # alternate title based on comment request
    ax.set(title=f'US covid tracking in meat processing plants for {ct} \n{insert}')
    
    plt.savefig(f'images/{ct}.png')  # save files by ct name to images directory
    plt.show()
    plt.close()

[image: example output of the updated code]
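
One caveat with dd['new_cases'].rolling(7): the window counts rows, not days. It covers exactly seven days only when every date appears once; a missing date stretches it, and several establishment rows per date shrink it. Keeping date in the index allows a time-based window such as '7D' instead. The sketch below uses invented numbers with one missing day to compare the two.

```python
import pandas as pd

# Invented daily new-case counts for one county; 2020-04-05 is missing
idx = pd.to_datetime(["2020-04-01", "2020-04-02", "2020-04-03", "2020-04-04",
                      "2020-04-06", "2020-04-07", "2020-04-08", "2020-04-09"])
new_cases = pd.Series([7, 7, 7, 7, 14, 14, 14, 14], index=idx, name="new_cases")

# Count-based window: the last 7 rows, which here reaches back to 2020-04-02
by_rows = new_cases.rolling(7).mean()

# Time-based window: rows in (t - 7 days, t]; needs a sorted DatetimeIndex.
# For offset windows min_periods defaults to 1, so early values are partial means.
by_days = new_cases.rolling("7D").mean()

print(pd.DataFrame({"new_cases": new_cases, "by_rows": by_rows, "by_days": by_days}))
```

If a county has several establishment rows per date, collapse them to one row per date before rolling; whether to sum or take the first value depends on whether the counts are per plant or repeated county totals, which would have to be checked in the gist.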
