Iterate through DataFrame categorical columns to create subplots

I am trying to create a grid of subplots for predetermined x and y data. The function should iterate through a pandas DataFrame, identify the categorical variables, and plot the x and y data with one line for each level of a given categorical variable. The number of plots equals the number of categorical variables, and the number of lines on each plot should equal the number of categories for that variable.

I initially tried grouping the DataFrame on a given categorical variable inside a for loop, but I have had mixed results. I think my issue is in how I assign the axes that the lines are drawn on.


def grouping_for_graphs(df,x_col, y_col,category,func):
    '''
    Function to group a dataframe by a given variable and
    aggregate it with the given function.
    '''
    X = df[x_col].name
    y = df[y_col].name
    category = df[category].name

    df_grouped = df.groupby([X, category])[y].apply(func)
    return df_grouped.reset_index()
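To see what the helper returns, here is a minimal sketch on hypothetical toy data (the column names `month`, `region` and `sales` are made up). It is the same idea written slightly more directly: `df[x_col].name` just gives back `x_col`, so the column names are used as-is, and `.agg(func)` stands in for `.apply(func)`; for a reducing function such as `"mean"` both yield one value per group.

```python
import pandas as pd


def grouping_for_graphs(df, x_col, y_col, category, func):
    """Aggregate y_col with func for every (x_col, category) pair."""
    grouped = df.groupby([x_col, category])[y_col].agg(func)
    # reset_index turns the (x_col, category) MultiIndex back into columns
    return grouped.reset_index()


# Hypothetical toy data: 'month' is x, 'sales' is y, 'region' is categorical.
df = pd.DataFrame({
    "month": [1, 1, 2, 2, 1, 2],
    "region": ["north", "south", "north", "south", "north", "south"],
    "sales": [10, 20, 30, 40, 50, 60],
})

out = grouping_for_graphs(df, "month", "sales", "region", "mean")
print(out)  # one row per (month, region) pair, columns: month, region, sales
```

The result is "long" format: the x column, the category column, then the aggregated y column, which is exactly what the plotting loop later filters on.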


# create a list of categorical variables to plot
cat_list = []
col_list = list(df.select_dtypes(include = ['object']).columns)

for col in col_list:
    if len(df[col].unique()) < 7:
        cat_list.append(col)
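The selection loop above can be written more compactly with `nunique()`. This variant is an alternative, not the asker's code: it treats every non-numeric column as a candidate, so it also picks up `category`-dtype and newer string-dtype columns that `select_dtypes(include=['object'])` would skip. The data and the cut-off `MAX_LEVELS` are hypothetical.

```python
import pandas as pd

# Hypothetical data: 'region' has few levels, 'user_id' has one level per row.
df = pd.DataFrame({
    "region": ["north", "south", "east", "north"],
    "user_id": ["a1", "b2", "c3", "d4"],
    "sales": [1.0, 2.0, 3.0, 4.0],
})

MAX_LEVELS = 4  # hypothetical cut-off; the question uses 7

# Keep non-numeric columns whose number of distinct values is below the cut-off.
cat_list = [
    col for col in df.columns
    if not pd.api.types.is_numeric_dtype(df[col]) and df[col].nunique() < MAX_LEVELS
]
print(cat_list)  # ['region']
```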


# create plots and axes
fig, axs = plt.subplots(2, 2, figsize=(30,24))
axs = axs.flatten()
# pick plot function
plot_func = plt.plot

# plot this
for ax, category in zip(axs, cat_list):
    df_grouped = grouping_for_graphs(df,x_col, y_col,category,agg_func)
    x_col = df_grouped.columns[0]
    y_col = df_grouped.columns[-1]
    category = str(list(df_grouped.columns.drop([x_lab, y_lab]))[0])
    for feature in list(df_grouped[category].unique()):
        X = df_grouped[df_grouped[category] == feature][x_col]
        y = df_grouped[df_grouped[category] == feature][y_col]
        ax.plot = plot_func(X,y)
        ax.set_xlabel(x_col)
        ax.set_ylabel(y_col)
        ax.set_title(feature)

Besides an error saying that ax.plot is a 'list' object and is not callable, all of the lines are drawn on the last subplot of the grid.
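Both symptoms can be reproduced in isolation. This minimal sketch (toy coordinates, headless `Agg` backend) shows that `plt.plot` always targets the current Axes, which right after `plt.subplots` is the last subplot created, and that assigning to `ax.plot` replaces the method with the list of `Line2D` objects that plotting returns.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

fig, axs = plt.subplots(2, 2)
axs = axs.flatten()

# Symptom 1: plt.plot draws on the *current* Axes. After plt.subplots that is
# the last Axes created, so every call lands on axs[-1].
for ax in axs:
    plt.plot([0, 1], [0, 1])
line_counts = [len(ax.lines) for ax in axs]
print(line_counts)  # [0, 0, 0, 4]

# Symptom 2: assigning to ax.plot stores the returned list of Line2D objects
# on the instance, which hides the Axes.plot method for that Axes.
last = axs[-1]
last.plot = last.plot([0, 1], [1, 0])
error_message = None
try:
    last.plot([0, 1], [0, 0])
except TypeError as exc:
    error_message = str(exc)
print(error_message)  # 'list' object is not callable

plt.close(fig)
```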

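Your plot_func is the problem. plt.plot always draws on the current Axes, which after plt.subplots is the last subplot, so every line lands there. On top of that, ax.plot = plot_func(X, y) replaces that Axes' plot method with the list of Line2D objects returned by plt.plot, which is why a later call fails with 'list' object is not callable. Remove plot_func and plot directly with ax.plot(X, y). The key change is marked with a comment: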

import matplotlib.pyplot as plt

fig, axs = plt.subplots(2, 2, figsize=(30, 24))
axs = axs.flatten()

for ax, category in zip(axs, cat_list):
    df_grouped = grouping_for_graphs(df, x_col, y_col, category, agg_func)
    # one line per level of this categorical column, all drawn on this ax
    for feature in df_grouped[category].unique():
        subset = df_grouped[df_grouped[category] == feature]
        ax.plot(subset[x_col], subset[y_col], label=feature)  # <--- Modified here
    ax.set_xlabel(x_col)
    ax.set_ylabel(y_col)
    ax.set_title(category)
    ax.legend()
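Putting the pieces together, here is a self-contained sketch that runs end to end. The DataFrame, the column names `month`, `sales`, `region` and `channel`, and the `"mean"` aggregation are hypothetical stand-ins for the asker's data. Besides plotting on each `ax`, it titles each subplot with its categorical column, labels each line with its level so a legend can tell them apart, and hides the leftover axes when there are fewer categorical columns than subplots.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line in a notebook
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def grouping_for_graphs(df, x_col, y_col, category, func):
    """Aggregate y_col for every (x_col, category) pair."""
    return df.groupby([x_col, category])[y_col].agg(func).reset_index()


# Hypothetical data: 'month' on x, 'sales' on y, two categorical columns.
rng = np.random.default_rng(0)
n = 120
df = pd.DataFrame({
    "month": rng.integers(1, 13, n),
    "region": rng.choice(["north", "south", "east"], n),
    "channel": rng.choice(["web", "store"], n),
    "sales": rng.normal(100, 15, n),
})
x_col, y_col, agg_func = "month", "sales", "mean"
cat_list = ["region", "channel"]

fig, axs = plt.subplots(2, 2, figsize=(12, 9))
axs = axs.flatten()

for ax, category in zip(axs, cat_list):
    df_grouped = grouping_for_graphs(df, x_col, y_col, category, agg_func)
    # One line per level of this categorical column, all on this ax.
    for level, sub in df_grouped.groupby(category):
        ax.plot(sub[x_col], sub[y_col], label=level)
    ax.set_xlabel(x_col)
    ax.set_ylabel(y_col)
    ax.set_title(category)
    ax.legend(title=category)

# Hide subplots left over when there are fewer categories than axes.
for ax in axs[len(cat_list):]:
    ax.set_visible(False)

fig.tight_layout()
```

In an interactive session, drop the `matplotlib.use("Agg")` line and finish with `plt.show()`, or write the figure out with `fig.savefig("subplots.png")` (a hypothetical file name). Iterating over `df_grouped.groupby(category)` is just an alternative to filtering once per `unique()` level; both give one line per level.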
