
Returning matplotlib plot/figure from a function and saving it later

As the title states, I want to return a plt or figure (I'm still not sure what the difference between the two is) from a function using matplotlib. The main idea is to be able to save the plt/figure later.

import seaborn as sns
from matplotlib import pyplot as plt
def graph(df, id):
        # size of the graph
        xlims = (-180, 180)
        ylims = (-180, 180)

        # dictate the colors of the scatter plot based on the grouping of hot or cold
        color_dict = {'COLD': 'blue',
                      'HOT': 'red'}
        title_name = f"{id}"
        ax = sns.scatterplot(data=df, hue='GRP', x='X_GRID', y='Y_GRID',
                             legend=False, palette=color_dict)
        ax.set_title(title_name)
        ax.set(xlim=xlims)
        ax.set(ylim=ylims)
        if show_grid:
            # pass in the prev graph so I can overlay grid
            ax = self.__get_grid( ax)
        circle1 = plt.Circle(xy=(0, 0), radius=150, color='black', fill=False, zorder=3)
        ax.add_patch(circle1)
        ax.set_aspect('equal')
        plt.axis('off')
        plt.savefig(title_name + '_in_ftn.png')
        fig = plt.figure()
        plt.clf()
        return (fig, title_name + '.png')

plots = []    
# dfs is just a tuple of df, id for example purposes
for df, id in dfs:
    plots.append(graph(df, id))

for plot, file_name in plots:
    plot.savefig(file_name)
    plot.clf()

When I call plot.savefig(file_name), a file is written, but the saved image is blank. Am I not returning the right object? If not, what should I return so that I can save it later?

I kind of have it working, but not really. I am currently saving two figures for testing purposes. For some reason, when I use fig = plt.figure() and save the figure outside the function, the title shown in the figure and the filename are different (even though they should be the same, since the only difference is the .png extension).

However, when I save it inside the function, the title of the figure and the filename are the same.

Your code has multiple issues, which I'll try to discuss here:

Your confusion around plt

First of all, there is no such thing as "a plt". plt is the name you give to the matplotlib.pyplot module when you import it with the line from matplotlib import pyplot as plt (which is equivalent to import matplotlib.pyplot as plt ). You are simply binding the module to an abbreviation that is easy to type. If you had written import matplotlib.pyplot instead, you would have to write matplotlib.pyplot.axis('off') instead of plt.axis('off') .
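To make the point concrete, here is a minimal sketch showing that the different import spellings all refer to one and the same module object. The variable name same_module is only for illustration.

```python
import matplotlib.pyplot            # binds the top-level name "matplotlib"
import matplotlib.pyplot as plt     # binds the alias "plt" to the pyplot module
from matplotlib import pyplot       # binds the name "pyplot" to the same module

# All three names point at the same module object, so "plt" is not a plot;
# it is just a shorter way to reach the functions in matplotlib.pyplot.
same_module = (plt is matplotlib.pyplot) and (plt is pyplot)
print(same_module)
```

What a function can usefully return is therefore a Figure (or Axes) object, not plt itself.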

Mix of procedural and object oriented approach

You are using a mix of the procedural and the object-oriented approach to matplotlib. Either you call methods on the axes object ( ax ), or you call pyplot functions that implicitly act on the current figure and axes. For example, you could create an axes object and then call ax.plot(...) , or instead use plt.plot(...) , which implicitly creates the figure and axes if none exist yet. In your case, you mainly use the object-oriented approach on the axes object that is returned by the seaborn function, so to stay consistent you should use ax.axis('off') instead of plt.axis('off') .
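A short side-by-side sketch of the two styles may help. It uses plain matplotlib with made-up data, and it selects the non-interactive "Agg" backend only so that it runs without a display; both of those are assumptions for the example, not part of the question.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so no window is needed
import matplotlib.pyplot as plt

# Procedural (implicit) style: pyplot tracks the "current" figure and axes.
plt.figure()
plt.plot([0, 1, 2], [0, 1, 4])
plt.title("implicit")
plt.axis("off")
implicit_fig = plt.gcf()          # ask pyplot which figure is current

# Object-oriented (explicit) style: keep references and call methods on them.
fig, ax = plt.subplots()
ax.plot([0, 1, 2], [0, 1, 4])
ax.set_title("explicit")
ax.axis("off")
explicit_fig = ax.get_figure()    # the figure that owns this axes object

plt.close("all")
```

With the explicit style there is never any doubt about which figure a call affects, which matters as soon as a function creates more than one figure.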

You create a new blank figure

When you call the seaborn function sns.scatterplot , it draws on the current matplotlib axes, implicitly creating a figure and axes object if none exist yet. You catch that axes object in the variable ax . You then use plt.savefig to save your image inside the function, which works because it implicitly saves the current figure, the one that holds your axes. However, you then create a new figure by calling fig = plt.figure() , which is of course blank, and return that one. What you should do instead is get the figure that the axes object you are working with belongs to. You can get it by calling fig = plt.gcf() (which stands for "get current figure"), which would be the procedural approach, or, better, by calling fig = ax.get_figure() .
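The difference between these calls can be checked directly. The sketch below uses plain matplotlib with made-up data and the "Agg" backend (both assumptions for the example). Its last step also suggests a likely reason why the title and the filename drifted apart in the question: the blank figure becomes the current one, so the next implicit plotting call draws on it.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so no window is needed
import matplotlib.pyplot as plt

fig_drawn, ax = plt.subplots()
ax.scatter([0, 1], [0, 1])

owner = ax.get_figure()     # the figure that holds the scatter plot
current = plt.gcf()         # the current figure, which is the same one here
blank = plt.figure()        # a brand-new figure without any axes

owner_matches = owner is fig_drawn and current is fig_drawn
blank_axes_before = len(blank.axes)

# The new figure is now the current one, so implicit calls end up on it.
plt.plot([0, 1], [1, 0])
blank_axes_after = len(blank.axes)

print(owner_matches, blank_axes_before, blank_axes_after)
plt.close("all")
```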

What you should do instead is something like this:

import seaborn as sns
from matplotlib import pyplot as plt

def graph(df, id):
    # size of the graph
    xlims = (-180, 180)
    ylims = (-180, 180)

    # dictate the colors of the scatter plot based on the grouping of hot or cold
    color_dict = {'COLD': 'blue',
                  'HOT': 'red'}
    title_name = f"{id}"
    ax = sns.scatterplot(data=df, hue='GRP', x='X_GRID', y='Y_GRID',
                         legend=False, palette=color_dict)
    ax.set_title(title_name)
    ax.set(xlim=xlims)
    ax.set(ylim=ylims)
    # show_grid and self.__get_grid come from the asker's surrounding code (not shown)
    if show_grid:
        # pass in the prev graph so I can overlay grid
        ax = self.__get_grid(ax)
    circle1 = plt.Circle(xy=(0, 0), radius=150, color='black', fill=False, zorder=3)
    ax.add_patch(circle1)
    ax.set_aspect('equal')
    # use the axes method instead of plt.axis('off')
    ax.axis('off')
    # get the figure that holds these axes instead of creating a new blank one
    fig = ax.get_figure()
    fig.savefig(title_name + '_in_ftn.png')
    return (fig, title_name + '.png')
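One thing to watch when calling such a function in a loop: if no new figure is created per call, implicit plotting keeps drawing on the same current axes. A possible way around that is to create the figure explicitly inside the function and draw on its axes. The sketch below shows the pattern with plain matplotlib and made-up data so that it runs on its own; the names points, datasets, out_dir and saved are hypothetical. With seaborn, the same idea should work by passing ax=ax to sns.scatterplot.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so no window is needed
import matplotlib.pyplot as plt


def graph(points, name):
    """Draw one scatter plot on its own figure and return it for later saving."""
    fig, ax = plt.subplots()              # a fresh figure and axes per call
    xs, ys = zip(*points)
    ax.scatter(xs, ys)
    ax.set_title(name)
    ax.set(xlim=(-180, 180), ylim=(-180, 180))
    ax.set_aspect("equal")
    ax.axis("off")
    return fig, f"{name}.png"


# Hypothetical stand-in for the (df, id) pairs from the question.
datasets = [
    ([(0, 0), (50, 50)], "first"),
    ([(-20, 30), (10, -90)], "second"),
]
plots = [graph(points, name) for points, name in datasets]

# Save later, outside the function, then release each figure.
out_dir = tempfile.mkdtemp()
saved = []
titles = []
for fig, file_name in plots:
    path = os.path.join(out_dir, file_name)
    titles.append(fig.axes[0].get_title())
    fig.savefig(path)
    plt.close(fig)                        # free the figure once it is on disk
    saved.append(path)
```

Closing each figure with plt.close(fig) after saving, rather than clearing it with clf(), keeps pyplot from holding on to figures that are no longer needed.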
