繁体   English   中英

如何使用python绘制3D轴原点图

[英]How to plot 3D axis-origin figure using python

我想绘制一个轴在图形中间的 3-D 表面。

我使用以下代码绘制图形:

%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(-1, 1, 10)
y = np.linspace(-1, 1, 10)

X, Y = np.meshgrid(x, y)

Z = np.array([[2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
        3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
       [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
        3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
       [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
        3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
       [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
        3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
       [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
        3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
       [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
        3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
       [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
        3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
       [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
        3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
       [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
        3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
       [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
        3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291]])


fig = plt.figure()
ax = plt.axes(projection='3d')
surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, linewidth=0,
                cmap='viridis', edgecolor='none', antialiased=False)
ax.set_xlim(-1.01, 1.01)

fig.colorbar(surf, shrink=0.5, aspect=5)

# Plot the axis in the middle of the figure
xline=((min(X[:,0]),max(X[:,0])),(0,0),(0,0))
ax.plot(xline[0],xline[1],xline[2],'grey')
yline=((0,0),(min(Y[:,1]),max(Y[:,1])),(0,0))
ax.plot(yline[0],yline[1],yline[2],'grey')
zline=((0,0),(0,0),(min(Z[:,2]),max(Z[:,2])))
ax.plot(zline[0],zline[1],zline[2],'grey')

ax.view_init(30,220) # Camera angel

ax.set_title('surface');

通过使用上面的代码,我得到了这样的图

在此处输入图片说明

我真正想要的是绘制一个 3-D 轴原点图,如下所示:

在此处输入图片说明

如何消除边距并将轴放在图形的中间?

这是使用plotly的解决方案。

使用 plotly 的 3d 图,轴箭头位于中心

简短说明

代码如下所示,但在这里我想给出最重要的评论

  • 没有办法(?)将实际的轴箭头移动到中心。 这里所做的是,我从三个部分创建了轴箭头:向量和向量末尾的3d 锥体,然后将轴标签添加为annotation 我把它的字面意思是“在图形的中间”,但是通过修改get_arrow()可以很容易地改变位置和箭头的外观。
  • 该图周围的“框”也有点tickvals :我已经更改了layout.scene.xaxislayout.scene.yaxislayout.scene.zaxistickvalsrange参数。 只有两个值,以便绘制网格将显示这样的框。 如果您还想显示普通网格,也应该使用矢量来完成(如箭头)。
  • 我不确定您是否要包括色标,但是可以通过更改showscale=Truego.Surface
  • 您可能要考虑的另一个选项是移动数据而不是轴(尽管您仍然需要绘制箭头)。 在某些用例中,这可能更有意义。

代码

import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go


def get_data():
    x = np.linspace(-1, 1, 10)
    y = np.linspace(-1, 1, 10)

    X, Y = np.meshgrid(x, y)

    Z = np.array([[2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
            3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
        [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
            3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
        [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
            3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
        [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
            3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
        [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
            3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
        [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
            3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
        [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
            3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
        [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
            3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
        [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
            3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
        [2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
            3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291]])
    return X, Y, Z


# One-color arrows & arrowheads
colorscale = [
    [0, "rgb(84,48,5)"],
    [1, "rgb(84,48,5)"],
]

X, Y, Z = get_data()

data = {
    key: {
        "min": v.min(),
        "max": v.max(),
        "mid": (v.max() + v.min()) / 2,
        "range": v.max() - v.min(),
    }
    for (key, v) in dict(x=X, y=Y, z=Z).items()
}


def get_arrow(axisname="x"):

    # Create arrow body
    body = go.Scatter3d(
        marker=dict(size=1, color=colorscale[0][1]),
        line=dict(color=colorscale[0][1], width=3),
        showlegend=False,  # hide the legend
    )

    head = go.Cone(
        sizeref=0.1,
        autocolorscale=None,
        colorscale=colorscale,
        showscale=False,  # disable additional colorscale for arrowheads
        hovertext=axisname,
    )
    for ax, direction in zip(("x", "y", "z"), ("u", "v", "w")):
        if ax == axisname:
            body[ax] = data[ax]["min"], data[ax]["max"]
            head[ax] = [data[ax]["max"]]
            head[direction] = [1]
        else:
            body[ax] = data[ax]["mid"], data[ax]["mid"]
            head[ax] = [data[ax]["mid"]]
            head[direction] = [0]

    return [body, head]


def add_axis_arrows(fig):
    for ax in ("x", "y", "z"):
        for item in get_arrow(ax):
            fig.add_trace(item)


def get_annotation_for_ax(ax):
    d = dict(showarrow=False, text=ax, xanchor="left", font=dict(color="#1f1f1f"))
    for ax_ in ("x", "y", "z"):
        if ax_ == ax:
            d[ax_] = data[ax]["max"] - data[ax]["range"] * 0.05
        else:
            d[ax_] = data[ax_]["mid"]

    if ax in {"x", "y"}:
        d["xshift"] = 15

    return d


def get_axis_names():
    return [get_annotation_for_ax(ax) for ax in ("x", "y", "z")]


def get_scene_axis(axisname="x"):

    return dict(
        title="",  # remove axis label (x,y,z)
        showbackground=False,
        visible=True,
        showticklabels=False,  # hide numeric values of axes
        showgrid=True,  # Show box around plot
        gridcolor="grey",  # Box color
        tickvals=[data[axisname]["min"], data[axisname]["max"]],  # Set box limits
        range=[
            data[axisname]["min"],
            data[axisname]["max"],
        ],  # Prevent extra lines around box
    )


fig = go.Figure(
    data=[
        go.Surface(
            z=Z,
            x=X,
            y=Y,
            opacity=0.6,
            showscale=False,  # Set to True to show colorscale
        ),
    ],
    layout=dict(
        title="surface",
        autosize=True,
        width=700,
        height=500,
        margin=dict(l=20, r=20, b=25, t=25),
        scene=dict(
            xaxis=get_scene_axis("x"),
            yaxis=get_scene_axis("y"),
            zaxis=get_scene_axis("z"),
            annotations=get_axis_names(),
        ),
    ),
)

add_axis_arrows(fig)
fig.show()

我不知道正确的方法,但只有一种解决方法。

没有为2D情况详细和很好的说明这里与选项

  • ax.axhline(y=0, color='k')
  • 使用样条: ax.spines['left'].set_position('zero')
  • 或使用 seaborn 包seaborn.despine(ax=ax, offset=0)

但是,恐怕它们在 3D 中不会这么容易工作。

我只知道这种解决方法,其中(外)轴关闭( ax.set_axis_off() )并从原点绘制箭头( ax.quiver() )。

所以你加

x, y, z = np.array([[-1,0,0],[0,-1,0],[0,0,0]])
u, v, w = np.array([[2,0,0],[0,2,0],[0,0,5]])
ax.quiver(x,y,z,u,v,w,arrow_length_ratio=0.1, color="black")
ax.set_axis_off()
plt.show()

到你的代码,你会得到这张照片:
这张照片

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM