
How to plot 3D axis-origin figure using python

I want to plot a 3-D surface with the axis in the middle of the figure.

I use the following code to plot the figure:

%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(-1, 1, 10)
y = np.linspace(-1, 1, 10)

X, Y = np.meshgrid(x, y)

Z = np.array([[2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
        3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
       [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
        3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
       [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
        3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
       [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
        3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
       [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
        3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
       [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
        3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
       [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
        3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
       [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
        3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
       [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
        3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
       [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
        3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291]])


fig = plt.figure()
ax = plt.axes(projection='3d')
surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, linewidth=0,
                cmap='viridis', edgecolor='none', antialiased=False)
ax.set_xlim(-1.01, 1.01)

fig.colorbar(surf, shrink=0.5, aspect=5)

# Plot the axis in the middle of the figure
xline=((min(X[:,0]),max(X[:,0])),(0,0),(0,0))
ax.plot(xline[0],xline[1],xline[2],'grey')
yline=((0,0),(min(Y[:,1]),max(Y[:,1])),(0,0))
ax.plot(yline[0],yline[1],yline[2],'grey')
zline=((0,0),(0,0),(min(Z[:,2]),max(Z[:,2])))
ax.plot(zline[0],zline[1],zline[2],'grey')

ax.view_init(30,220) # Camera angle

ax.set_title('surface');

With the code above, I get a figure like this:

[Figure: output of the code above]

What I really want is a 3-D figure with the axes through the origin, like the following:

[Figure: the desired plot, with the axes through the origin]

How can I remove the margin and put the axes in the middle of the graph?

Here is a solution using plotly.

[Figure: 3D plot made with plotly, with the axis arrows in the center]

Short explanation

The full code is below; here are the most important remarks:

  • As far as I know, there is no way to move the actual axis lines to the center. Instead, I built each axis arrow from three parts: a vector (a line), a 3D cone at the end of the vector, and the axis label, added afterwards as an annotation. I placed the arrows literally "in the middle of the figure", but their position and appearance can easily be changed by modifying get_arrow().
  • The "box" around the figure is also slightly hackish: I changed the tickvals and range parameters of layout.scene.xaxis, layout.scene.yaxis and layout.scene.zaxis so that each axis has only two tick values (its min and max). Drawing the grid then shows just the outline of the box. If you also want the normal grid, it would have to be drawn with vectors too (like the arrows).
  • I was not sure whether you want the color scale; it can be added by setting showscale=True for the go.Surface.
  • Another option you might consider is to move the data instead of the axes (you would still need to draw the arrows). This may make more sense in some use cases.
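A minimal sketch of that last option, assuming "moving the data" means translating each coordinate array so that the midpoint of its value range lands at the origin. `center_data` is a hypothetical helper, not part of the answer's code; after the shift, the axis arrows can simply be drawn through (0, 0, 0).

```python
import numpy as np


def center_data(*arrays):
    """Shift each array so the midpoint of its value range lies at 0.

    Hypothetical helper: also returns the subtracted offsets, which you
    would need if you want tick labels in the original coordinates.
    """
    shifted, offsets = [], []
    for a in arrays:
        mid = (a.max() + a.min()) / 2
        shifted.append(a - mid)
        offsets.append(mid)
    return shifted, offsets


x = np.linspace(-1, 1, 10)
y = np.linspace(-1, 1, 10)
X, Y = np.meshgrid(x, y)
row = np.array([2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
                3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291])
Z = np.tile(row, (10, 1))  # the question's Z: the same row repeated 10 times

(Xc, Yc, Zc), offsets = center_data(X, Y, Z)
# The centre of the data is now (0, 0, 0), so axes drawn there sit in the middle
print(offsets)
```

Only Z actually moves here, because X and Y already span -1 to 1 symmetrically.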

Code

import numpy as np
import plotly.graph_objects as go


def get_data():
    x = np.linspace(-1, 1, 10)
    y = np.linspace(-1, 1, 10)

    X, Y = np.meshgrid(x, y)

    Z = np.array([[2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
            3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
        [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
            3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
        [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
            3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
        [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
            3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
        [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
            3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
        [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
            3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
        [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
            3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
        [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
            3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
        [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
            3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
        [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
            3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291]])
    return X, Y, Z


# One-color arrows & arrowheads
colorscale = [
    [0, "rgb(84,48,5)"],
    [1, "rgb(84,48,5)"],
]

X, Y, Z = get_data()

data = {
    key: {
        "min": v.min(),
        "max": v.max(),
        "mid": (v.max() + v.min()) / 2,
        "range": v.max() - v.min(),
    }
    for (key, v) in dict(x=X, y=Y, z=Z).items()
}


def get_arrow(axisname="x"):

    # Create arrow body
    body = go.Scatter3d(
        marker=dict(size=1, color=colorscale[0][1]),
        line=dict(color=colorscale[0][1], width=3),
        showlegend=False,  # hide the legend
    )

    head = go.Cone(
        sizeref=0.1,
        autocolorscale=None,
        colorscale=colorscale,
        showscale=False,  # disable additional colorscale for arrowheads
        hovertext=axisname,
    )
    for ax, direction in zip(("x", "y", "z"), ("u", "v", "w")):
        if ax == axisname:
            body[ax] = data[ax]["min"], data[ax]["max"]
            head[ax] = [data[ax]["max"]]
            head[direction] = [1]
        else:
            body[ax] = data[ax]["mid"], data[ax]["mid"]
            head[ax] = [data[ax]["mid"]]
            head[direction] = [0]

    return [body, head]


def add_axis_arrows(fig):
    for ax in ("x", "y", "z"):
        for item in get_arrow(ax):
            fig.add_trace(item)


def get_annotation_for_ax(ax):
    d = dict(showarrow=False, text=ax, xanchor="left", font=dict(color="#1f1f1f"))
    for ax_ in ("x", "y", "z"):
        if ax_ == ax:
            d[ax_] = data[ax]["max"] - data[ax]["range"] * 0.05
        else:
            d[ax_] = data[ax_]["mid"]

    if ax in {"x", "y"}:
        d["xshift"] = 15

    return d


def get_axis_names():
    return [get_annotation_for_ax(ax) for ax in ("x", "y", "z")]


def get_scene_axis(axisname="x"):

    return dict(
        title="",  # remove axis label (x,y,z)
        showbackground=False,
        visible=True,
        showticklabels=False,  # hide numeric values of axes
        showgrid=True,  # Show box around plot
        gridcolor="grey",  # Box color
        tickvals=[data[axisname]["min"], data[axisname]["max"]],  # Set box limits
        range=[
            data[axisname]["min"],
            data[axisname]["max"],
        ],  # Prevent extra lines around box
    )


fig = go.Figure(
    data=[
        go.Surface(
            z=Z,
            x=X,
            y=Y,
            opacity=0.6,
            showscale=False,  # Set to True to show colorscale
        ),
    ],
    layout=dict(
        title="surface",
        autosize=True,
        width=700,
        height=500,
        margin=dict(l=20, r=20, b=25, t=25),
        scene=dict(
            xaxis=get_scene_axis("x"),
            yaxis=get_scene_axis("y"),
            zaxis=get_scene_axis("z"),
            annotations=get_axis_names(),
        ),
    ),
)

add_axis_arrows(fig)
fig.show()
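The geometry inside get_arrow() can be checked without plotly. The sketch below, using the hypothetical helper names `summarize` and `arrow_geometry`, reproduces the same coordinate logic on plain dictionaries: along the chosen axis the line runs from min to max and the cone sits at max with direction component 1, while the other two coordinates are pinned to the midpoint with direction component 0. This is the logic you would change to move the arrows elsewhere.

```python
def summarize(values):
    """min/max/mid/range of a sequence, like the `data` dict in the answer."""
    lo, hi = min(values), max(values)
    return {"min": lo, "max": hi, "mid": (lo + hi) / 2, "range": hi - lo}


def arrow_geometry(data, axisname):
    """Return (line endpoints, cone position, cone direction) for one axis arrow."""
    line, cone_pos, cone_dir = {}, {}, {}
    for ax, direction in zip(("x", "y", "z"), ("u", "v", "w")):
        if ax == axisname:
            line[ax] = (data[ax]["min"], data[ax]["max"])  # spans the whole axis
            cone_pos[ax] = data[ax]["max"]                 # arrowhead at the far end
            cone_dir[direction] = 1                        # points along +axis
        else:
            line[ax] = (data[ax]["mid"], data[ax]["mid"])  # pinned at the midpoint
            cone_pos[ax] = data[ax]["mid"]
            cone_dir[direction] = 0
    return line, cone_pos, cone_dir


data = {
    "x": summarize([-1.0, 1.0]),
    "y": summarize([-1.0, 1.0]),
    "z": summarize([0.78569291, 3.86669336]),  # min and max of the question's Z
}
line, pos, direction = arrow_geometry(data, "z")
print(line, pos, direction)
```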

I don't know of a proper way, only a workaround.

There is a detailed and very good description of the 2D case here, with these options:

  • ax.axhline(y=0, color='k')
  • using spines: ax.spines['left'].set_position('zero')
  • or using the seaborn package seaborn.despine(ax=ax, offset=0)
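For reference, the spines option looks roughly like this in 2D. This is a minimal sketch: the curve and the output file name are arbitrary choices, and the Agg backend is only there so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-1, 1, 100)
fig, ax = plt.subplots()
ax.plot(x, x**3 - 0.5 * x)  # arbitrary example curve

# Move the left and bottom spines so they cross at the data origin (0, 0)
ax.spines["left"].set_position("zero")
ax.spines["bottom"].set_position("zero")
# Hide the remaining two spines so only the crossing axes are left
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)

fig.savefig("axes_through_origin_2d.png")
```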

However, I'm afraid they don't work as easily in 3D.

I am only aware of a workaround in which the (outer) axes are turned off with ax.set_axis_off() and arrows are drawn through the origin with ax.quiver().

So you add

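# Start points of the three arrows: (-1, 0, 0), (0, -1, 0) and (0, 0, 0)
x, y, z = np.array([[-1, 0, 0], [0, -1, 0], [0, 0, 0]])
# Arrow vectors: (2, 0, 0), (0, 2, 0) and (0, 0, 5)
u, v, w = np.array([[2, 0, 0], [0, 2, 0], [0, 0, 5]])
ax.quiver(x, y, z, u, v, w, arrow_length_ratio=0.1, color="black")
ax.set_axis_off()  # hide the outer axes, panes and ticks
plt.show()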

to your code and you'll get this picture:
[Figure: the resulting surface with arrows as axes]
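A self-contained variant of this workaround, sketched under one assumption: instead of the hard-coded lengths above (the z arrow of length 5 overshoots the maximum of Z, about 3.87), each arrow spans its axis's data range, extended to include 0 so that it always passes through the origin. `axis_arrows` is a hypothetical helper; the Agg backend and the output file name are only there so the sketch runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line in Jupyter
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers "3d" on old matplotlib)


def axis_arrows(X, Y, Z):
    """Start points and vectors (one row per arrow) for three arrows through the origin.

    Each arrow runs from min(data_min, 0) to max(data_max, 0) along its own axis.
    """
    starts = np.zeros((3, 3))
    vectors = np.zeros((3, 3))
    for i, values in enumerate((X, Y, Z)):
        lo = min(values.min(), 0.0)
        hi = max(values.max(), 0.0)
        starts[i, i] = lo
        vectors[i, i] = hi - lo
    return starts, vectors


x = np.linspace(-1, 1, 10)
X, Y = np.meshgrid(x, x)
row = np.array([2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
                3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291])
Z = np.tile(row, (10, 1))  # the question's Z: the same row repeated 10 times

fig = plt.figure()
ax = fig.add_subplot(projection="3d")
ax.plot_surface(X, Y, Z, cmap="viridis", linewidth=0, antialiased=False)

starts, vectors = axis_arrows(X, Y, Z)
# quiver takes the x, y, z start coordinates first, then the u, v, w components
ax.quiver(*starts.T, *vectors.T, arrow_length_ratio=0.1, color="black")
ax.set_axis_off()  # hide the outer box, panes and ticks
ax.view_init(30, 220)
fig.savefig("axes_through_origin_3d.png")
```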

The technical post webpages of this site follow the CC BY-SA 4.0 license. If you need to reprint, please indicate the site URL or the original address. Any question please contact: yoyou2525@163.com.

 
© 2020-2024 STACKOOM.COM