
Python. Multiple plots via for loops, fixing axis

I am new to Python. I have to plot some data at different time iterations. The plot is a 3D scatter plot. It has some problems I would like to fix; see the plot at three different time instances (first, middle and last):

[Images: the plot at the first, middle and last time instance]

  • As you can see, there is a box around each image, which is partly cut off by the title "graph title". I want to remove this box (I don't understand where it is coming from). Note that I want to keep the axis title.
  • In the middle and last images the numbers on the coordinate axes overlap. I just want each of the three axes to be fixed for every image.

How can I edit my code to do the above?

fig, ax = plt.subplots()

for n in range(10):
    #labels
    ax=plt.axes(projection='3d') 
    ax.set_title('graph title')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('z')
    ax.set_xlim(left=-10, right=20)
    ax.set_ylim(bottom=-10, top=20)
    ax.set_zlim(bottom=-10, top=20)

    #plotting
    x=data[n]
    ax.scatter(x[:,0],x[:,1],x[:,2])
    plt.savefig(f'fig_{n}.png')
    plt.cla() # needed to remove the plot because savefig doesn't clear it

The main issue is that you created multiple axes on the same figure without noticing.

You first create one with fig, ax = plt.subplots(), and then other ones in the for loop with ax=plt.axes(projection='3d'). This is where the box comes from: it is the frame of the axes drawn underneath. It is also why the tick labels on the coordinate axes overlap.
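A quick way to check this is to count the axes attached to the figure. The following is a minimal sketch, not part of the original code: the variable names ax2d, ax3d and n_axes are made up for illustration, and the Agg backend is chosen only so that it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend, no window needed
import matplotlib.pyplot as plt

# plt.subplots() creates a figure that already holds one ordinary 2D axes.
fig, ax2d = plt.subplots()

# plt.axes(projection='3d') adds a second, 3D axes to the current figure
# instead of converting the existing 2D one.
ax3d = plt.axes(projection="3d")

n_axes = len(fig.axes)
print("axes on the figure:", n_axes)
print("kinds:", ax2d.name, "and", ax3d.name)
```

Both axes occupy the same area of the figure, so the frame and ticks of the 2D one show through behind the 3D plot.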

Also, if you create the axes only once, there is no need to set the title, labels, etc. inside the for loop:

import matplotlib.pyplot as plt
import numpy as np

T = 4  # time
N = 34  # samples
D = 3  # x, y, z

data = np.random.rand(T, N, D)

fig, ax = plt.subplots(subplot_kw=dict(projection="3d"))
ax.set_title("graph title")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("z")
ax.set_xlim(left=0, right=1)
ax.set_ylim(bottom=0, top=1)
ax.set_zlim(bottom=0, top=1)


colors = iter(["blue", "green", "red", "yellow", "pink"])
for n, points in enumerate(data):
    x, y, z = points.T
    scat = ax.scatter(x, y, z, c=next(colors))
    fig.savefig(f'fig_{n}.png')
    scat.remove()  # ax.cla() clears too much (title etc.)
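To see why scat.remove() is used rather than ax.cla(), the two can be compared on a throwaway figure. This is only a sketch under assumptions: the random points array and the variable names are hypothetical, and the Agg backend is used so that it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend, no window needed
import matplotlib.pyplot as plt
import numpy as np

points = np.random.rand(34, 3)  # hypothetical data: 34 samples of (x, y, z)

fig, ax = plt.subplots(subplot_kw=dict(projection="3d"))
ax.set_title("graph title")

# Removing only the scatter artist leaves the title (and labels) untouched.
scat = ax.scatter(*points.T)
scat.remove()
title_after_remove = ax.get_title()
scatter_still_attached = scat in ax.collections

# Clearing the whole axes also resets the title.
ax.scatter(*points.T)
ax.cla()
title_after_cla = ax.get_title()

print(repr(title_after_remove), scatter_still_attached, repr(title_after_cla))
```

With ax.cla() the title, labels and limits would have to be set again in every iteration, which is what the loop above avoids.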
