
How do I save a figure from PyTorch using Matplotlib?

I'm using:

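import matplotlib.pyplot as plt
import torch
from torchvision import datasets

# `transform` is defined elsewhere (not shown)
trainset = datasets.MNIST('saved/train', download=True,
                          train=True, transform=transform)

valset = datasets.MNIST('saved/test', download=True,
                        train=False, transform=transform)

trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=64, shuffle=True)
valloader = torch.utils.data.DataLoader(valset, batch_size=64, shuffle=True)


dataiter = iter(trainloader)
# use the built-in next(); the iterator's .next() method was removed in recent PyTorch
images, labels = next(dataiter)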

fig = plt.figure()
print(images.shape)
print(labels.shape)

plt.plot(images[0].numpy().squeeze())

fig.savefig('figs/first.png')

However, this does not save the first image. The saved figure looks like this: (image not available)

So what am I doing wrong?

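Use matplotlib.pyplot.imshow instead of matplotlib.pyplot.plot. plot draws line graphs (given a 2-D array, it draws one line per column), whereas imshow renders a 2-D array as an image: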

plt.imshow(images[0].numpy().squeeze(), cmap='gray')  # grayscale colormap suits MNIST digits
fig.savefig('first_fig.png')
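To see why the original plot call produced lines instead of a digit, here is a minimal sketch. It uses a random 28x28 NumPy array as a hypothetical stand-in for images[0].numpy().squeeze(), so it runs without downloading MNIST, and it selects the non-interactive Agg backend so it works headless.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the script runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical stand-in for images[0].numpy().squeeze(): one 28x28 grayscale digit
digit = np.random.rand(28, 28)

fig, (ax_plot, ax_img) = plt.subplots(1, 2)

# plot() treats each column of a 2-D array as a separate series -> 28 Line2D objects
lines = ax_plot.plot(digit)

# imshow() maps every array element to a pixel -> a single AxesImage
img = ax_img.imshow(digit, cmap="gray")

print(len(lines))          # 28
print(len(ax_img.images))  # 1
fig.savefig("plot_vs_imshow.png")
plt.close(fig)
```

The left panel is what the question's code saved; the right panel is the digit as an image.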

This saves the whole Matplotlib figure, including axes, ticks and margins. To save only the image itself, you can use matplotlib.pyplot.imsave:

plt.imsave('first_imsave.png', images[0].numpy().squeeze(), cmap='gray')
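The difference between the two approaches shows up in the output size. The sketch below again uses a random array as a hypothetical stand-in for the MNIST digit, writes it both ways, and reads the files back with plt.imread: imsave produces one pixel per array element, while savefig produces an image the size of the whole figure.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np

digit = np.random.rand(28, 28)  # hypothetical stand-in for images[0].numpy().squeeze()

# savefig: the whole figure (axes, ticks, margins) at figure size times dpi
fig = plt.figure()
plt.imshow(digit, cmap="gray")
fig.savefig("first_fig.png")
plt.close(fig)

# imsave: only the array, one output pixel per element
plt.imsave("first_imsave.png", digit, cmap="gray")

fig_shape = plt.imread("first_fig.png").shape
imsave_shape = plt.imread("first_imsave.png").shape
print(fig_shape[:2])     # depends on figure size and dpi, much larger than 28x28
print(imsave_shape[:2])  # (28, 28)
```

If you need a larger file from imsave, resize the array first; its dpi argument only sets file metadata and does not change the pixel dimensions.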

Or use torchvision.utils.save_image:

from torchvision import utils

utils.save_image(images[0], 'first_utils.png')

The technical post webpages of this site follow the CC BY-SA 4.0 license. If you need to reprint, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM