
Plotting matplotlib subplots with functions

I am attempting to create a function that serves as a quick visual check for normality, and to automate it for a whole DataFrame. I want an n x 2 grid of subplots (one row per DataFrame column, two columns), with a histogram on the left and a probability plot on the right. I have written a function for each of these plots and they work fine: the ax argument I added lets each one draw into a specific subplot. However, when I call both functions from a final function that loops over the DataFrame's columns, only the first histogram is drawn and the rest of the subplots stay empty.

I'm not sure where I am going wrong. See the code for the functions below. Note that no errors are raised:

import numpy as np
import matplotlib.pyplot as plt
from scipy import stats

#Histogram for normality
def normal_dist_hist(data, ax):
    #Format data for plotting
    #Included ax for subplot coordinate
    if data.isnull().values.any() == True:
        data.dropna(inplace=True)
    if data.dtypes == 'float64':
        data.astype('int64')  # note: the result is not assigned, so this line has no effect
    #Plot distribution with Gaussian overlay
    mu, std = stats.norm.fit(data)
    ax.hist(data, bins=50, density=True, alpha=0.6, color='g')
    xmin, xmax = ax.get_xlim()
    x = np.linspace(xmin, xmax, 100)
    p = stats.norm.pdf(x, mu, std)
    ax.plot(x, p, 'k', linewidth=2)
    title = "Fit results: mu = %.2f,  std = %.2f" % (mu, std)
    ax.set_title(title)
    plt.show()

#Probability plot
def normal_test_QQplots(data, ax):
    #added ax argument for specifying subplot coordinate, 
    data.dropna(inplace=True)
    probplt = stats.probplot(data,dist='norm',fit=True,plot=ax)
    plt.show()

def normality_report(df):
    fig, axes = plt.subplots(nrows=len(df.columns), ncols=2,figsize=(12,50))
    ax_y = 0
    for col in df.columns[1:]:
        ax_x = 0
        normal_dist_hist(df[col], ax=axes[ax_y, ax_x])
        ax_x = 1
        normal_test_QQplots(df[col], ax=axes[ax_y, ax_x])
        ax_y += 1

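Remove the plt.show() calls from your functions normal_dist_hist(...) and normal_test_QQplots(...), and call plt.show() once at the end of normality_report(...). plt.show() displays the open figures, and in a Jupyter notebook's inline backend (or once you close a GUI window) the figure is then closed. The first call inside normal_dist_hist therefore shows the grid with only one histogram drawn, and everything plotted afterwards goes onto a figure that is never displayed again.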

def normal_dist_hist(data, ax):
    ...
    plt.show() # Remove this

#Probability plot
def normal_test_QQplots(data, ax):
    ...
    plt.show() # Remove this

def normality_report(df):
    ...
    for col in df.columns[1:]:
        ax_x = 0
        normal_dist_hist(df[col], ax=axes[ax_y, ax_x])
        ax_x = 1
        normal_test_QQplots(df[col], ax=axes[ax_y, ax_x])
        ax_y += 1
    plt.show() # Add it here.
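For reference, here is one way the three functions could look with the fix applied. This is a self-contained sketch, not the asker's exact code. Assumptions beyond the original: each helper works on a local copy (data.dropna() instead of dropna(inplace=True), so the caller's DataFrame is not modified), the no-op float-to-int cast is dropped, plt.subplots is called with squeeze=False so a one-column DataFrame still yields a 2-D axes array, the figure height is scaled to 4 inches per row instead of a fixed 50, and the show parameter and the returned figure are conveniences added here. Unlike the original loop, it plots every column: the original skipped the first column with df.columns[1:] while still creating a row for it, which leaves the last row empty.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import stats


def normal_dist_hist(data, ax):
    """Draw a histogram of `data` with a fitted normal PDF overlay on `ax`."""
    data = data.dropna()  # local copy; the caller's DataFrame is left untouched
    mu, std = stats.norm.fit(data)
    ax.hist(data, bins=50, density=True, alpha=0.6, color="g")
    xmin, xmax = ax.get_xlim()
    x = np.linspace(xmin, xmax, 100)
    ax.plot(x, stats.norm.pdf(x, mu, std), "k", linewidth=2)
    ax.set_title(f"Fit results: mu = {mu:.2f},  std = {std:.2f}")
    # No plt.show() here: the caller decides when the whole figure is displayed.


def normal_test_QQplots(data, ax):
    """Draw a normal probability (Q-Q) plot of `data` on `ax`."""
    # probplot draws the sample points and, with fit=True, the least-squares line.
    stats.probplot(data.dropna(), dist="norm", fit=True, plot=ax)


def normality_report(df: pd.DataFrame, show: bool = True):
    """One row per column: histogram on the left, probability plot on the right."""
    cols = df.columns
    # squeeze=False keeps `axes` 2-D even when df has a single column.
    fig, axes = plt.subplots(nrows=len(cols), ncols=2,
                             figsize=(12, 4 * len(cols)), squeeze=False)
    for row, col in enumerate(cols):
        normal_dist_hist(df[col], ax=axes[row, 0])
        normal_test_QQplots(df[col], ax=axes[row, 1])
    fig.tight_layout()
    if show:
        plt.show()  # called once, after every subplot has been drawn
    return fig
```

Call it as normality_report(df). If the first column is an identifier you want to skip, pass normality_report(df.iloc[:, 1:]) so the number of rows matches the columns actually plotted. In a script, show=False followed by fig.savefig(...) on the returned figure writes the report to disk instead of opening a window.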


 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM