简体   繁体   English

Plot 二维矩阵方程或 3d 使用 matplotlib

[英]Plot a matrix equation in 2d or 3d using matplotlib

I have an equation as followed:我有一个方程如下:

y = x^T * A * x + b^T * x + c

where x, b, c are vectors in n space and A is a nxn matrix.其中 x, b, c 是 n 空间中的向量, A 是 nxn 矩阵。

I can plot a linear equation in matplotlib, but not sure how a matrix equation can be (if possible) shown also in a 3d plot. I can plot a linear equation in matplotlib, but not sure how a matrix equation can be (if possible) shown also in a 3d plot.

I tried with following code, A is given matrix and w, c and b are column vectors.我尝试使用以下代码,A 是矩阵,w,c 和 b 是列向量。 X and Y are mesh and Z is the solution. X 和 Y 是网格,Z 是解。

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
# if using a Jupyter notebook, include:
%matplotlib inline

fig = plt.figure(figsize=(10,6))
ax1 = fig.add_subplot(111, projection='3d')

n = 50
i = -5.0
j = 5.0

A = np.random.randint(i, j, size=(n, n))
w = np.random.randint(i, j, size=(n, 1))
c = b = np.random.randint(i, j, size=(n, 1))

X,Y = np.meshgrid(n,n)
Z = w.T*A*w + b.T*w + c

mycmap = plt.get_cmap('gist_earth')
surf1 = ax1.plot_surface(X, A, Z, cmap=mycmap)
fig.colorbar(surf1, ax=ax1, shrink=0.5, aspect=10)

plt.show()

The resulting plot does not seem to be a satisfied plot.生成的 plot 似乎不是满意的 plot。

There are two problems in your code: 1) meshgrid was being used incorrectly (it needs two arrays, not two ints);您的代码中有两个问题:1) meshgrid使用不正确(它需要两个 arrays,而不是两个整数); 2) in the surface plot, you were using X, A, Z instead of X, Y, Z -- X, A, Z will work, and might make sense, but I'm guessing it wasn't your intention. 2) 在表面 plot 中,您使用的是X, A, Z而不是X, Y, Z -- X, A, Z会起作用,并且可能有意义,但我猜这不是您的意图。

Here's a working solution:这是一个有效的解决方案:

在此处输入图像描述

fig = plt.figure(figsize=(10,6))
ax1 = fig.add_subplot(111, projection='3d')

n = 10
i = -5.0
j = 5.0

A = np.random.randint(i, j, size=(n, n))
w = np.random.randint(i, j, size=(n, 1))
c = b = np.random.randint(i, j, size=(n, 1))

X,Y = np.meshgrid(np.arange(n),np.arange(n))
Z = w.T*A*w + b.T*w + c

mycmap = plt.get_cmap('gist_earth')
surf1 = ax1.plot_surface(X, Y, Z, cmap=mycmap)
fig.colorbar(surf1, ax=ax1, shrink=0.5, aspect=10)

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM