
How to speed up seaborn heatmaps

I'm struggling with slow seaborn FacetGrid heatmaps. I have expanded my data set from a previous question, and thanks to @Diziet Asahi for providing a solution to the FacetGrid issue there.

Now I have a 19x20 grid of panels (19 labels by 20 methods), with 625 points (25x25) to be mapped in each panel. It takes forever to get output for even one layer (little1), and my real data has thousands of such layers.

My code is along these lines:

import pandas as pd
import numpy as np
import itertools
import seaborn as sns
from matplotlib.colors import ListedColormap

print("seaborn version {}".format(sns.__version__))
# R expand.grid() function in Python
# https://stackoverflow.com/a/12131385/1135316
def expandgrid(*itrs):
   product = list(itertools.product(*itrs))
   return {'Var{}'.format(i+1):[x[i] for x in product] for i in range(len(itrs))}

ltt= ['little1']

methods=["m" + str(i) for i in range(1,21)]
labels=["l" + str(i) for i in range(1,20)]

times = range(0,100,4)
data = pd.DataFrame(expandgrid(ltt,methods,labels, times, times))
data.columns = ['ltt','method','labels','dtsi','rtsi']
data['nw_score'] = np.random.choice([0,1],data.shape[0])

The resulting data frame looks like this:

Out[36]: 
            ltt method labels  dtsi  rtsi  nw_score
0       little1     m1     l1     0     0         1
1       little1     m1     l1     0     4         0
2       little1     m1     l1     0     8         0
3       little1     m1     l1     0    12         1
4       little1     m1     l1     0    16         0
        ...    ...    ...   ...   ...       ...
237495  little1    m20    l19    96    80         0
237496  little1    m20    l19    96    84         1
237497  little1    m20    l19    96    88         0
237498  little1    m20    l19    96    92         0
237499  little1    m20    l19    96    96         1

[237500 rows x 6 columns]
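The row count is simply the size of the Cartesian product: 1 layer x 20 methods x 19 labels x 25 x 25 time pairs = 237,500 rows, i.e. 380 panels of 625 cells each. As an aside, the itertools-based expandgrid helper can be replaced by pandas' own product constructor. A minimal sketch, assuming a pandas version recent enough for MultiIndex.to_frame(index=False):

```python
import numpy as np
import pandas as pd

# Same grid axes as in the question: 1 layer, 20 methods, 19 labels, 25 time steps.
ltt = ["little1"]
methods = ["m" + str(i) for i in range(1, 21)]
labels = ["l" + str(i) for i in range(1, 20)]
times = range(0, 100, 4)

# Cartesian product built by pandas instead of itertools + list comprehensions.
# The last level varies fastest, matching itertools.product ordering.
index = pd.MultiIndex.from_product(
    [ltt, methods, labels, times, times],
    names=["ltt", "method", "labels", "dtsi", "rtsi"],
)
data = index.to_frame(index=False)
data["nw_score"] = np.random.choice([0, 1], len(data))

print(len(data))                                   # 237500 rows
print(data.groupby(["method", "labels"]).ngroups)  # 380 panels
```

Column order and row order match the frame printed above, since the last level varies fastest in both constructions.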

Defining the facet function and plotting:

labels_fill = {0:'red',1:'blue'}

del methods
del labels

def facet(data,color):
    data = data.pivot(index="dtsi", columns='rtsi', values='nw_score')
    g = sns.heatmap(data, cmap=ListedColormap(['red', 'blue']), cbar=False,annot=True)

for lt in data.ltt.unique():
    with sns.plotting_context(font_scale=5.5):
        # height= replaces the size= argument deprecated in seaborn 0.9
        g = sns.FacetGrid(data[data.ltt==lt], row="labels", col="method", height=2, aspect=1, margin_titles=False)
        g = g.map_dataframe(facet)
        g.add_legend()
        g.set_titles(template="")

        for ax,method in zip(g.axes[0,:],data.method.unique()):
            ax.set_title(method, fontweight='bold', fontsize=12)
        for ax,label in zip(g.axes[:,0],data.labels.unique()):
            ax.set_ylabel(label, fontweight='bold', fontsize=12, rotation=0, ha='right', va='center')
        g.fig.suptitle(lt, fontweight='bold', fontsize=12)
        g.fig.tight_layout()
        g.fig.subplots_adjust(top=0.8) # make some room for the title

        g.savefig(lt+'.png', dpi=300)
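One likely contributor to the slowness, not discussed in the thread: annot=True makes seaborn write a label into every cell, and each label is a separate matplotlib Text artist that has to be laid out and rendered. A minimal check on a single 25x25 panel, assuming a recent seaborn/matplotlib and the headless Agg backend:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; no window is opened
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib.colors import ListedColormap

# One 25 x 25 panel of random 0/1 scores, like a single facet in the question.
panel = np.random.choice([0, 1], size=(25, 25))
cmap = ListedColormap(["red", "blue"])

fig, (ax_annot, ax_plain) = plt.subplots(1, 2)
sns.heatmap(panel, cmap=cmap, cbar=False, annot=True, ax=ax_annot)
sns.heatmap(panel, cmap=cmap, cbar=False, annot=False, ax=ax_plain)

n_annot = len(ax_annot.texts)  # one Text artist per cell
n_plain = len(ax_plain.texts)  # no cell labels at all
print(n_annot, n_plain)
plt.close(fig)
```

Across all 380 panels that comes to 237,500 text labels, which are likely unreadable at 2-inch panel size anyway, so dropping annot=True is a cheap first thing to try.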
    


I stopped the code after a while: the grid panels were being filled in one by one, and generating this heatmap is unbearably slow.

Is there a better way to speed up the process?
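Both the facet function in the question and the matplotlib version in the answer start by pivoting one panel's long-format rows into a 25x25 matrix (rows dtsi, columns rtsi). A minimal single-panel sketch of that step, with random scores standing in for the real data and imshow used for drawing as the answer does:

```python
import itertools
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.colors import ListedColormap

times = range(0, 100, 4)  # 25 time steps, as in the question
# Long format: one row per (dtsi, rtsi) cell of a single panel.
panel = pd.DataFrame(list(itertools.product(times, times)), columns=["dtsi", "rtsi"])
panel["nw_score"] = np.random.choice([0, 1], len(panel))

# Wide format: rows = dtsi, columns = rtsi, values = nw_score.
grid = panel.pivot(index="dtsi", columns="rtsi", values="nw_score")
print(grid.shape)  # (25, 25)

fig, ax = plt.subplots(figsize=(2, 2))
# A single AxesImage holds all 625 cells of the panel.
ax.imshow(grid, cmap=ListedColormap(["red", "blue"]))
plt.close(fig)
```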

Seaborn is slow here. If you use matplotlib directly instead of seaborn, you get down to about half a minute per figure. That is still long, but given that you are producing a roughly 12000x12000-pixel figure, it is to be expected.
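The ~12000x12000 estimate follows from the figsize that myfacetgrid in the code below passes to plt.subplots, (2*len(cols)+1, 2*len(rows)+1) inches, combined with dpi=300 in savefig:

```python
# Figure size used by myfacetgrid: (2*len(cols)+1, 2*len(rows)+1) inches.
n_rows, n_cols = 19, 20   # labels l1..l19 (rows), methods m1..m20 (columns)
dpi = 300                 # as passed to fig.savefig(..., dpi=300)

width_in = 2 * n_cols + 1   # 41 inches
height_in = 2 * n_rows + 1  # 39 inches
width_px, height_px = width_in * dpi, height_in * dpi
print(width_px, height_px)   # 12300 11700
print(width_px * height_px)  # 143910000, i.e. about 144 megapixels
```

Because the pixel count grows with the square of dpi, saving at dpi=100 would give 4100x3900 pixels, one ninth as many.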

import time
import pandas as pd
import numpy as np
import itertools
import seaborn as sns
from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt

print("seaborn version {}".format(sns.__version__))
# R expand.grid() function in Python
# https://stackoverflow.com/a/12131385/1135316
def expandgrid(*itrs):
   product = list(itertools.product(*itrs))
   return {'Var{}'.format(i+1):[x[i] for x in product] for i in range(len(itrs))}

ltt= ['little1']

methods=["m" + str(i) for i in range(1,21)]
#methods=['method 1', 'method 2', 'method 3', 'method 4']
#labels = ['label1','label2']
labels=["l" + str(i) for i in range(1,20)]

times = range(0,100,4)
data = pd.DataFrame(expandgrid(ltt,methods,labels, times, times))
data.columns = ['ltt','method','labels','dtsi','rtsi']
#data['nw_score'] = np.random.sample(data.shape[0])
data['nw_score'] = np.random.choice([0,1],data.shape[0])

labels_fill = {0:'red',1:'blue'}

del methods
del labels


cmap=ListedColormap(['red', 'blue'])

def facet(data, ax):
    data = data.pivot(index="dtsi", columns='rtsi', values='nw_score')
    ax.imshow(data, cmap=cmap)

def myfacetgrid(data, row, col, figure=None):
    rows = np.unique(data[row].values)  
    cols = np.unique(data[col].values)

    fig, axs = plt.subplots(len(rows), len(cols), 
                            figsize=(2*len(cols)+1, 2*len(rows)+1))


    for i, r in enumerate(rows):
        row_data = data[data[row] == r]
        for j, c in enumerate(cols):
            this_data = row_data[row_data[col] == c]
            facet(this_data, axs[i,j])
    return fig, axs


for lt in data.ltt.unique():

    with sns.plotting_context(font_scale=5.5):
        t = time.time()
        fig, axs = myfacetgrid(data[data.ltt==lt], row="labels", col="method")
        print(time.time()-t)
        for ax,method in zip(axs[0,:],data.method.unique()):
            ax.set_title(method, fontweight='bold', fontsize=12)
        for ax,label in zip(axs[:,0],data.labels.unique()):
            ax.set_ylabel(label, fontweight='bold', fontsize=12, rotation=0, ha='right', va='center')
        print(time.time()-t)
        fig.suptitle(lt, fontweight='bold', fontsize=12)
        fig.tight_layout()
        fig.subplots_adjust(top=0.8) # make some room for the title
        print(time.time()-t)
        fig.savefig(lt+'.png', dpi=300)
        print(time.time()-t)

Here the time splits into roughly 6 seconds for creating the facet grid, 7 seconds for optimizing the layout via tight_layout (consider leaving it out!), and 15 seconds for drawing and saving the figure.
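To see where the time goes on your own machine, the phases can be timed separately. A minimal sketch under these assumptions: the helper time_phases is hypothetical, random panels stand in for the pivoted scores, the PNG is rendered into memory rather than to disk, and each phase is measured on its own (the answer's prints are cumulative since the start):

```python
import io
import time

import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap


def time_phases(n_rows, n_cols, use_tight_layout=True, dpi=100):
    """Hypothetical helper: seconds spent creating, laying out and saving a panel grid."""
    cmap = ListedColormap(["red", "blue"])
    timings = {}

    t0 = time.perf_counter()
    fig, axs = plt.subplots(n_rows, n_cols,
                            figsize=(2 * n_cols + 1, 2 * n_rows + 1),
                            squeeze=False)  # always a 2-D array of axes
    for ax in axs.flat:
        # Random 25 x 25 panel standing in for one pivoted facet.
        ax.imshow(np.random.choice([0, 1], size=(25, 25)), cmap=cmap)
    timings["create"] = time.perf_counter() - t0

    t1 = time.perf_counter()
    if use_tight_layout:
        fig.tight_layout()
    timings["layout"] = time.perf_counter() - t1

    t2 = time.perf_counter()
    buf = io.BytesIO()  # render in memory instead of writing a file
    fig.savefig(buf, format="png", dpi=dpi)
    timings["draw"] = time.perf_counter() - t2

    plt.close(fig)
    return timings


# Small grid so the example finishes quickly; use 19 x 20 and dpi=300 to mimic the answer.
for tight in (True, False):
    result = time_phases(3, 4, use_tight_layout=tight)
    print(tight, {k: round(v, 3) for k, v in result.items()})
```

With use_tight_layout=False the layout phase is essentially free; spacing can then be set once with fig.subplots_adjust, as the answer already does for the title.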

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 