Python, Matplotlib: Stack multiple heatmaps on top of each other along z-axis in 3D

I have two heatmaps that I want to display on top of each other in a 3D view. The heatmaps are plotted with imshow(), which correctly sets each element of the data as a colored square in the plot. However, when plotting the same heatmaps with plot_surface() each corner now instead represents each element. Note that the first figure has 15x15 squares and ticks at the middle of each square, and that the second figure only has 14x14 squares and ticks in each square's corner. Since I'm working with discrete data and am interested in each (x,y) combination, the second representation doesn't make sense.

How can I make it so that the heatmaps in 3D are displayed in the same way as the 2D heatmaps? That is, how can I plot a 3D plot that sets x and y ticks in the middle of each square, and correctly plots 15x15 elements? (Note that it's fine that the colors in the heatmaps currently differ from the 2D to 3D case)

Code

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D


# Dummy data
X = range(2, 16+1)
Y = range(2, 16+1)
xs, ys = np.meshgrid(X, Y)

zs1 = np.random.rand(15,15)
zs2 = np.random.rand(15,15)

# Imshow 2D plot
_, (ax1, ax2) = plt.subplots(2,1)
plot = ax1.imshow(np.flip(zs1, 0), cmap=plt.cm.RdYlGn, interpolation='none', extent=[1.5, 16.5, 1.5, 16.5])
plot = ax2.imshow(np.flip(zs2, 0), cmap=plt.cm.RdYlGn, interpolation='none', extent=[1.5, 16.5, 1.5, 16.5])
plt.draw()

# Surface 3D plot
fig = plt.figure()
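# Axes3D(fig) is no longer attached to the figure automatically in recent
# matplotlib releases, so create the 3D axes through the figure instead.
ax3 = fig.add_subplot(projection='3d')
plot = ax3.plot_surface(xs, ys, zs1, rstride=1, cstride=1,
                    antialiased=False, linewidth=0, cmap=plt.cm.RdYlGn)
plot = ax3.plot_surface(xs, ys, zs2 + 1000, rstride=1, cstride=1,
                    antialiased=False, linewidth=0, cmap=plt.cm.RdYlGn)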


plt.show()

2D heatmaps

The 2D heatmaps. Note that there are 15x15 elements and ticks in the middle of each square.

3D heatmaps

The 3D heatmaps. Note that there are only 14x14 elements and ticks in each square's corner. I want these to be displayed in the same way as the 2D heatmaps.

I think you have understood the core problem: plot_surface is meant to plot surfaces, not tilted heatmaps. That is why you had to offset the second surface by 1000: with values in the intervals [0, 1] and [1000, 1001], the z-range becomes so large that both surfaces look flat.

Because plot_surface is meant for, well, surfaces, it interprets your samples as point estimates and then interpolates between neighbouring points to estimate the average surface height between them. Hence a 15x15 array of point estimates results in 14x14 faces, and none of the colors match although you are applying the same colormap. If this behaviour does not seem logical to you, I would recommend Alvy Ray Smith's famous/infamous essay "A Pixel Is Not A Little Square, A Pixel Is Not A Little Square, A Pixel Is Not A Little Square! (And a Voxel is Not a Little Cube)" for further reading.

Having understood why plot_surface handles the data the way it does, it becomes clear that one solution would be to upsample your data:
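The 15x15 versus 14x14 arithmetic can be reproduced with plain NumPy. This is only a sketch of the idea, assuming rstride=cstride=1 and assuming each face is coloured by the mean height of its four corner samples (my reading of the behaviour, not something the documentation guarantees). The names zs and face_values are made up for the illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
zs = rng.random((15, 15))  # 15x15 point samples, as in the question

# One face per 2x2 block of neighbouring samples; its value is taken here
# as the mean of the four corners (an assumption about how faces are coloured).
face_values = (zs[:-1, :-1] + zs[1:, :-1] + zs[:-1, 1:] + zs[1:, 1:]) / 4

print(zs.shape)           # (15, 15)
print(face_values.shape)  # (14, 14)
```

Since every face value is an average of four samples, it generally matches none of the original 225 values, which is consistent with the colours differing from the imshow output.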

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Dummy data
zs1 = np.random.rand(15,15)
zs2 = np.random.rand(15,15)

# Imshow 2D plot
fig, (ax1, ax2) = plt.subplots(2,1)
plot = ax1.imshow(np.flip(zs1, 0), cmap=plt.cm.RdYlGn, interpolation='none', extent=[1.5, 16.5, 1.5, 16.5], vmin=0, vmax=1)
plot = ax2.imshow(np.flip(zs2, 0), cmap=plt.cm.RdYlGn, interpolation='none', extent=[1.5, 16.5, 1.5, 16.5], vmin=0, vmax=1)
plt.draw()

# Surface 3D plot
upsample_by = 20
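# Span the same range as the imshow extent so that each block of repeated
# samples is centred on its integer coordinate (2, 3, ..., 16).
X = np.linspace(1.5, 16.5, 15*upsample_by)
Y = np.linspace(1.5, 16.5, 15*upsample_by)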
xs, ys = np.meshgrid(X, Y)
zs1 = np.repeat(np.repeat(zs1, upsample_by, axis=0), upsample_by, axis=1)
zs2 = np.repeat(np.repeat(zs2, upsample_by, axis=0), upsample_by, axis=1)

fig3 = plt.figure()
ax3 = fig3.add_subplot(projection='3d')  # Axes3D(fig3) is not auto-added in recent matplotlib
plot = ax3.plot_surface(xs, ys, zs1, rstride=1, cstride=1,
                        antialiased=False, linewidth=0, cmap=plt.cm.RdYlGn, vmin=0, vmax=1)
plot = ax3.plot_surface(xs, ys, zs2 + 1000, rstride=1, cstride=1,
                        antialiased=False, linewidth=0, cmap=plt.cm.RdYlGn, vmin=1000, vmax=1001)


plt.show()
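To see what the two np.repeat calls above do, here is a tiny illustration on a 2x2 array (the names small and up are only for this example). Each value is copied into a block of identical samples, so the interpolation between neighbouring points happens almost entirely inside regions of constant value.

```python
import numpy as np

small = np.array([[1, 2],
                  [3, 4]])

# Repeat every row twice, then every column twice (nearest-neighbour upsampling).
up = np.repeat(np.repeat(small, 2, axis=0), 2, axis=1)
print(up)
# [[1 1 2 2]
#  [1 1 2 2]
#  [3 3 4 4]
#  [3 3 4 4]]
```

Only the thin strips of faces between two blocks still mix different values, which is why a larger upsample_by makes the result look closer to the imshow output at the cost of many more polygons.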

I don't recommend doing this: the plot_surface approach is pretty hacky to begin with, and my tweak only makes it worse. Personally, I would draw each individual pixel as a little square in 3D, appropriately colored. This matplotlib tutorial demonstrates how to add 2-D patches to axes with 3-D projections.

Based on this, we can write a little function that draws heatmaps in 3D at defined heights:

#!/usr/bin/env python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import mpl_toolkits.mplot3d.art3d as art3d


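def tilted_heatmap_in_3d(arr, z, cmap=plt.cm.RdYlGn, ax=None):
    """Draw arr as a flat heatmap at height z; values are assumed to lie in [0, 1]."""
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(projection='3d')

    # As with imshow, the row index runs along y and the column index along x.
    for ii, row in enumerate(arr):
        for jj, value in enumerate(row):
            r = Rectangle((jj-0.5, ii-0.5), 1, 1, color=cmap(value))
            ax.add_patch(r)
            art3d.pathpatch_2d_to_3d(r, z=z, zdir="z")

    ax.set_xlim(-1, jj+1)
    ax.set_ylim(-1, ii+1)
    ax.set_zlim(0, 2*z)
    ax.get_figure().canvas.draw()
    return ax  # lets the caller pass the same axes in again for another heatmap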


if __name__ == '__main__':

    tilted_heatmap_in_3d(np.random.rand(15, 15), z=5)
    plt.show()
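Since the question asks for several heatmaps stacked along z, here is a sketch of the same idea extended to more than one array. It is a variation rather than the code above: it collects the squares of each heatmap into one Poly3DCollection instead of adding individual patches. The function name stacked_heatmaps and the parameters x0 and y0 (coordinates of the first cell centre) are made up for this example, all arrays are assumed to share one shape, and the values are assumed to lie in [0, 1].

```python
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection


def stacked_heatmaps(arrays, heights, x0=2, y0=2, cmap=plt.cm.RdYlGn, ax=None):
    """Draw each 2D array as a flat heatmap at the matching entry of heights.

    Cell (row, col) is centred on x = x0 + col, y = y0 + row, as with imshow.
    """
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(projection='3d')

    nrows, ncols = np.asarray(arrays[0]).shape  # assumes equal shapes
    for arr, z in zip(arrays, heights):
        arr = np.asarray(arr)
        quads, colors = [], []
        for row in range(nrows):
            for col in range(ncols):
                x, y = x0 + col - 0.5, y0 + row - 0.5
                # Corners of one unit square lying flat at height z.
                quads.append([(x, y, z), (x + 1, y, z),
                              (x + 1, y + 1, z), (x, y + 1, z)])
                colors.append(cmap(arr[row, col]))  # assumes values in [0, 1]
        ax.add_collection3d(Poly3DCollection(quads, facecolors=colors))

    # Set the limits explicitly rather than relying on autoscaling.
    ax.set_xlim(x0 - 0.5, x0 + ncols - 0.5)
    ax.set_ylim(y0 - 0.5, y0 + nrows - 0.5)
    ax.set_zlim(min(heights) - 0.5, max(heights) + 0.5)
    return ax
```

For the data in the question, something like stacked_heatmaps([zs1, zs2], heights=[0, 1]) followed by plt.show() should give two 15x15 heatmaps, with the integer ticks 2 to 16 falling on the centres of the squares.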
