
matplotlib: change axis ticks of ndim histogram plotted with seaborn.heatmap

Motivation:

I'm trying to visualize a dataset of many n-dimensional vectors (say, 10k vectors with n=300 dimensions). I'd like to calculate a histogram for each of the n dimensions and plot each one as a single row of an n-by-bins heatmap.

So far I've got this:

import numpy as np
import matplotlib
from matplotlib import pyplot as plt
%matplotlib inline
import seaborn as sns

# sample data:
vectors = np.random.randn(10000, 300) + np.random.randn(300)

def ndhist(vectors, bins=500):
    limits = (vectors.min(), vectors.max())
    hists = []
    dims = vectors.shape[1]
    for dim in range(dims):
        h, bins = np.histogram(vectors[:, dim], bins=bins, range=limits)
        hists.append(h)
    hists = np.array(hists)
    fig = plt.figure(figsize=(16, 9))
    sns.heatmap(hists)
    axes = fig.gca()
    axes.set(ylabel='dimensions', xlabel='values')
    print(dims)
    print(limits)

ndhist(vectors)

This generates the following output:

300
(-6.538069472429366, 6.52159540162285)

[Image: bad axis ticks]

Problem / Question:

How can I change the axis ticks?

  • for the y-axis I'd like to simply change this back to matplotlib's default, so it picks nice ticks like 0, 50, 100, ..., 250 (bonus points for 299 or 300)
  • for the x-axis I'd like to convert the shown bin indices into the bin (left) boundaries, then, as above, I'd like to change this back to matplotlib's default selection of some "nice" ticks like -5, -2.5, 0, 2.5, 5 (bonus points for also including the actual limits -6.538, 6.522)

Own solution attempts:

I've tried many things like the following already:

def ndhist_axlabels(vectors, bins=500):
    limits = (vectors.min(), vectors.max())
    hists = []
    dims = vectors.shape[1]
    for dim in range(dims):
        h, bins = np.histogram(vectors[:, dim], bins=bins, range=limits)
        hists.append(h)
    hists = np.array(hists)
    fig = plt.figure(figsize=(16, 9))
    sns.heatmap(hists, yticklabels=False, xticklabels=False)
    axes = fig.gca()
    axes.set(ylabel='dimensions', xlabel='values')
    #plt.xticks(np.linspace(*limits, len(bins)), bins)
    plt.xticks(range(len(bins)), bins)
    axes.xaxis.set_major_locator(matplotlib.ticker.AutoLocator())
    plt.yticks(range(dims+1), range(dims+1))
    axes.yaxis.set_major_locator(matplotlib.ticker.AutoLocator())
    print(dims)
    print(limits)

ndhist_axlabels(vectors)

[Image: even worse axis ticks]

As you can see, however, the axis labels are pretty wrong. My guess is that the extent or limits are stored somewhere in the original axis, but are lost when switching back to the AutoLocator. I would greatly appreciate a nudge in the right direction.

Maybe you're overthinking this. To plot image data, one can use imshow and get the ticking and formatting for free.

import numpy as np
from matplotlib import pyplot as plt

# sample data:
vectors = np.random.randn(10000, 300) + np.random.randn(300)

def ndhist(vectors, bins=500):
    limits = (vectors.min(), vectors.max())
    hists = []
    dims = vectors.shape[1]

    for dim in range(dims):
        h, _ = np.histogram(vectors[:, dim], bins=bins, range=limits)
        hists.append(h)
    hists = np.array(hists)

    fig, ax = plt.subplots(figsize=(16, 9))

    extent = [limits[0], limits[-1], hists.shape[0]-0.5, -0.5]  
    im = ax.imshow(hists, extent=extent, aspect="auto")
    fig.colorbar(im)

    ax.set(ylabel='dimensions', xlabel='values')

ndhist(vectors)
plt.show()
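
To see why this gives sensible ticks without extra work, here is a minimal sketch (a small made-up array, hypothetical variable names, and a non-interactive Agg backend are all assumptions) checking that `extent` puts the axes into data coordinates, so the default tick locator works on values and row indices rather than on column positions:

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs anywhere
from matplotlib import pyplot as plt

# toy stand-in for `hists`: 3 dimensions (rows) x 4 bins (columns)
toy_hists = np.arange(12).reshape(3, 4)
toy_limits = (-6.5, 6.5)  # made-up value range covered by the bins

fig, ax = plt.subplots()
# [left, right, bottom, top]; the half-pixel offsets centre row i on y = i
extent = [toy_limits[0], toy_limits[1], toy_hists.shape[0] - 0.5, -0.5]
im = ax.imshow(toy_hists, extent=extent, aspect="auto")

xlim = ax.get_xlim()  # x-axis now spans the value range
ylim = ax.get_ylim()  # y-axis runs top to bottom, so row 0 is at the top
print(xlim, ylim)
```

Because the axis limits are the real value range, no manual conversion between bin indices and bin boundaries is needed.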


If you read the docs, you will notice that the xticklabels / yticklabels arguments are overloaded: if you provide an integer instead of a list of labels, it will interpret the argument as xtickevery / ytickevery and place a tick label only at every n-th position. So in your case, seaborn.heatmap(hists, yticklabels=50) fixes your y-axis problem.


Regarding your xtick labels, I would simply provide them explicitly:

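xtickevery = 50
# `bins` is the edge array returned by np.histogram; drop the last edge so there is one label per column
xticklabels = ['{:.1f}'.format(b) if ii%xtickevery == 0 else '' for ii, b in enumerate(bins[:-1])]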
sns.heatmap(hists, yticklabels=50, xticklabels=xticklabels)
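
The snippet above assumes that `bins` already holds the edge array returned by `np.histogram`. A self-contained sketch of the label construction (the helper name `sparse_edge_labels` and the sample range are made up for illustration), using only the left edges so that there is one label per heatmap column:

```python
import numpy as np

def sparse_edge_labels(edges, every=50, fmt="{:.1f}"):
    """Return one label per bin: the left edge every `every` bins, '' otherwise."""
    left_edges = edges[:-1]  # np.histogram returns n_bins + 1 edges
    return [fmt.format(e) if i % every == 0 else "" for i, e in enumerate(left_edges)]

# made-up data with a fixed range so the edges are easy to reason about
_, edges = np.histogram(np.zeros(10), bins=500, range=(-5.0, 5.0))
labels = sparse_edge_labels(edges, every=50)
print(len(labels), [l for l in labels if l])
```

The resulting list could then be passed as `xticklabels=labels` to `sns.heatmap`.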


I finally came up with a version that works for me for now: it uses AutoLocator to pick the tick values and a simple linear mapping to place them on the heatmap.

import time

def ndhist(vectors, bins=1000, title=None):
    t = time.time()
    limits = (vectors.min(), vectors.max())
    hists = []
    dims = vectors.shape[1]
    for dim in range(dims):
        h, bs = np.histogram(vectors[:, dim], bins=bins, range=limits)
        hists.append(h)
    hists = np.array(hists)

    fig = plt.figure(figsize=(16, 12))
    sns.heatmap(
        hists,
        yticklabels=50,
        xticklabels=False
    )

    axes = fig.gca()
    axes.set(
        ylabel=f'dimensions ({dims} total)',
        xlabel=f'values (min: {limits[0]:.4g}, max: {limits[1]:.4g}, {bins} bins)',
        title=title,
    )

    def val_to_idx(val):
        # calc (linearly interpolated) index loc for given val
        return bins*(val - limits[0])/(limits[1] - limits[0])
    xlabels = [round(l, 3) for l in limits] + [
        v for v in matplotlib.ticker.AutoLocator().tick_values(*limits)[1:-1]
    ]
    # drop auto-gen labels that might be too close to limits
    d = (xlabels[4] - xlabels[3])/3
    if (xlabels[1] - xlabels[-1]) < d:
        del xlabels[-1]
    if (xlabels[2] - xlabels[0]) < d:
        del xlabels[2]
    xticks = [val_to_idx(val) for val in xlabels]
    axes.set_xticks(xticks)
    axes.set_xticklabels([f'{l:.4g}' for l in xlabels])

    plt.show()
    print(f'histogram generated in {time.time() - t:.2f}s')

ndhist(np.random.randn(100000, 300), bins=1000, title='randn')
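
The core of this version is the linear mapping from data values to heatmap column coordinates. Isolated as a small sketch (the name `value_to_column` is illustrative, and here the candidate ticks are simply filtered to the data range rather than sliced with `[1:-1]`, which is an assumption about what is wanted):

```python
from matplotlib import ticker

def value_to_column(val, limits, n_bins):
    """Linearly map a data value to heatmap x-coordinates, which run from 0 to n_bins."""
    lo, hi = limits
    return n_bins * (val - lo) / (hi - lo)

limits = (-6.538, 6.522)  # example limits quoted in the question
n_bins = 1000

# "nice" candidate values; tick_values may return points outside the range
candidates = ticker.AutoLocator().tick_values(*limits)
inside = [float(v) for v in candidates if limits[0] <= v <= limits[1]]
positions = [value_to_column(v, limits, n_bins) for v in inside]
print(list(zip(inside, positions)))
```

The `positions` would go to `axes.set_xticks` and the formatted `inside` values to `axes.set_xticklabels`, as in the function above.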

[Image: hist]

Thanks to Paul for his answer giving me the idea.

If there's an easier or more elegant solution, I'd still be interested though.
