简体   繁体   English

matplotlib:更改使用 seaborn.heatmap 绘制的 ndim 直方图的轴刻度

[英]matplotlib: change axis ticks of ndim histogram plotted with seaborn.heatmap

Motivation:动机:

I'm trying to visualize a dataset of many n-dimensional vectors (let's say i have 10k vectors with n=300 dimensions).我正在尝试可视化一个包含许多 n 维向量的数据集(假设我有 10k 个 n=300 维的向量)。 What i'd like to do is calculate a histogram for each of the n dimensions and plot it as a single line in a bins*n heatmap.我想要做的是为 n 维中的每一个计算直方图,并将其绘制为 bins*n 热图中的一条线。

So far i've got this:到目前为止,我有这个:

import numpy as np
import matplotlib
from matplotlib import pyplot as plt
%matplotlib inline
import seaborn as sns

# sample data:
vectors = np.random.randn(10000, 300) + np.random.randn(300)

def ndhist(vectors, bins=500):
    limits = (vectors.min(), vectors.max())
    hists = []
    dims = vectors.shape[1]
    for dim in range(dims):
        h, bins = np.histogram(vectors[:, dim], bins=bins, range=limits)
        hists.append(h)
    hists = np.array(hists)
    fig = plt.figure(figsize=(16, 9))
    sns.heatmap(hists)
    axes = fig.gca()
    axes.set(ylabel='dimensions', xlabel='values')
    print(dims)
    print(limits)

ndhist(vectors)

This generates the following output:这将生成以下输出:

300
(-6.538069472429366, 6.52159540162285)

坏轴刻度

Problem / Question:问题/问题:

How can i change the axes ticks?如何更改轴刻度?

  • for the y-axis i'd like to simply change this back to matplotlib's default, so it picks nice ticks like 0, 50, 100, ..., 250 (bonus points for 299 or 300 )对于 y 轴,我想简单地将其更改回 matplotlib 的默认值,因此它选择了0, 50, 100, ..., 250等不错的刻度( 299300奖励积分)
  • for the x-axis i'd like to convert the shown bin indices into the bin (left) boundaries, then, as above, i'd like to change this back to matplotlib's default selection of some "nice" ticks like -5, -2.5, 0, 2.5, 5 (bonus points for also including the actual limits -6.538, 6.522 )对于 x 轴,我想将显示的 bin 索引转换为 bin(左)边界,然后,如上所述,我想将其改回 matplotlib 的一些“不错”刻度的默认选择,例如-5, -2.5, 0, 2.5, 5 (还包括实际限制-6.538, 6.522奖励积分)

Own solution attempts:自己的解决方案尝试:

I've tried many things like the following already:我已经尝试过很多类似的事情:

def ndhist_axlabels(vectors, bins=500):
    limits = (vectors.min(), vectors.max())
    hists = []
    dims = vectors.shape[1]
    for dim in range(dims):
        h, bins = np.histogram(vectors[:, dim], bins=bins, range=limits)
        hists.append(h)
    hists = np.array(hists)
    fig = plt.figure(figsize=(16, 9))
    sns.heatmap(hists, yticklabels=False, xticklabels=False)
    axes = fig.gca()
    axes.set(ylabel='dimensions', xlabel='values')
    #plt.xticks(np.linspace(*limits, len(bins)), bins)
    plt.xticks(range(len(bins)), bins)
    axes.xaxis.set_major_locator(matplotlib.ticker.AutoLocator())
    plt.yticks(range(dims+1), range(dims+1))
    axes.yaxis.set_major_locator(matplotlib.ticker.AutoLocator())
    print(dims)
    print(limits)

ndhist_axlabels(vectors)

更糟糕的轴滴答声

As you can see however, the axes labels are pretty wrong.但是,正如您所看到的,轴标签非常错误。 My guess is that the extent or limits are somewhere stored in the original axis, but lost when switching back to the AutoLocator .我的猜测是范围或限制存储在原始轴中的某处,但在切换回AutoLocator时丢失。 Would greatly appreciate a nudge in the right direction.非常感谢在正确方向上的推动。

Maybe you're overthinking this.也许你想多了。 To plot image data, one can use imshow and get the ticking and formatting for free.要绘制图像数据,可以使用imshow并免费获得imshow和格式。

import numpy as np
from matplotlib import pyplot as plt

# sample data:
vectors = np.random.randn(10000, 300) + np.random.randn(300)

def ndhist(vectors, bins=500):
    limits = (vectors.min(), vectors.max())
    hists = []
    dims = vectors.shape[1]

    for dim in range(dims):
        h, _ = np.histogram(vectors[:, dim], bins=bins, range=limits)
        hists.append(h)
    hists = np.array(hists)

    fig, ax = plt.subplots(figsize=(16, 9))

    extent = [limits[0], limits[-1], hists.shape[0]-0.5, -0.5]  
    im = ax.imshow(hists, extent=extent, aspect="auto")
    fig.colorbar(im)

    ax.set(ylabel='dimensions', xlabel='values')

ndhist(vectors)
plt.show()

在此处输入图片说明

If you read the docs , you will notice that the xticklabels / yticklabels arguments are overloaded, such that if you provide an integer instead of a string, it will interpret the argument as xtickevery / ytickevery and place ticks only at the corresponding locations.如果您阅读文档,您会注意到xticklabels / yticklabels参数已重载,因此如果您提供整数而不是字符串,它会将参数解释为xtickevery / ytickevery并仅在相应位置放置刻度。 So in your case, seaborn.heatmap(hists, yticklabels=50) fixes your y-axis problem.所以在你的情况下, seaborn.heatmap(hists, yticklabels=50)修复了你的 y 轴问题。

在此处输入图片说明

Regarding your xtick labels, I would simply provide them explictly:关于您的 xtick 标签,我只想明确地提供它们:

xtickevery = 50 
xticklabels = ['{:.1f}'.format(b) if ii%xtickevery == 0 else '' for ii, b in enumerate(bins)]
sns.heatmap(hists, yticklabels=50, xticklabels=xticklabels)

在此处输入图片说明

Finally came up with a version that works for me for now and uses AutoLocator based on some simple linear mapping...终于想出了一个现在适合我的版本,并使用基于一些简单线性映射的AutoLocator ......

def ndhist(vectors, bins=1000, title=None):
    t = time.time()
    limits = (vectors.min(), vectors.max())
    hists = []
    dims = vectors.shape[1]
    for dim in range(dims):
        h, bs = np.histogram(vectors[:, dim], bins=bins, range=limits)
        hists.append(h)
    hists = np.array(hists)

    fig = plt.figure(figsize=(16, 12))
    sns.heatmap(
        hists,
        yticklabels=50,
        xticklabels=False
    )

    axes = fig.gca()
    axes.set(
        ylabel=f'dimensions ({dims} total)',
        xlabel=f'values (min: {limits[0]:.4g}, max: {limits[1]:.4g}, {bins} bins)',
        title=title,
    )

    def val_to_idx(val):
        # calc (linearly interpolated) index loc for given val
        return bins*(val - limits[0])/(limits[1] - limits[0])
    xlabels = [round(l, 3) for l in limits] + [
        v for v in matplotlib.ticker.AutoLocator().tick_values(*limits)[1:-1]
    ]
    # drop auto-gen labels that might be too close to limits
    d = (xlabels[4] - xlabels[3])/3
    if (xlabels[1] - xlabels[-1]) < d:
        del xlabels[-1]
    if (xlabels[2] - xlabels[0]) < d:
        del xlabels[2]
    xticks = [val_to_idx(val) for val in xlabels]
    axes.set_xticks(xticks)
    axes.set_xticklabels([f'{l:.4g}' for l in xlabels])

    plt.show()
    print(f'histogram generated in {time.time() - t:.2f}s')

ndhist(np.random.randn(100000, 300), bins=1000, title='randn')

历史

Thanks to Paul for his answer giving me the idea.感谢保罗的回答给了我这个想法。

If there's an easier or more elegant solution, i'd still be interested though.如果有更简单或更优雅的解决方案,我仍然会感兴趣。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM