
How to save all the generated images in a folder in PyTorch

I am trying to use data augmentation with PyTorch. I want to save all the generated images in a folder (target_dir), numbering them differently based on the batch index.

Here is my code. I am using epoch=100 and batch_size=128.

import os



for batch_idx in range(BATCH_SIZE):
    torchvision.utils.save_image(img_grid_fake, f"C:/UserspythonProjectgenerated_image/Fake_image%{batch_idx}d.png", global_step=step)

But I only end up with the last 128 generated images; the images generated earlier are overwritten when the next epoch runs.
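The overwrite can be reproduced without PyTorch. The sketch below is a simplified stand-in for the question's loop (EPOCHS and BATCHES_PER_EPOCH are small hypothetical values, and a dict plays the role of the folder): because the filename depends only on batch_idx, every epoch produces the same set of names, so each epoch replaces the files of the previous one.

```python
# Simplified stand-in for the question's naming scheme (hypothetical sizes).
EPOCHS = 3
BATCHES_PER_EPOCH = 4

written = {}  # filename -> epoch that last wrote it (stands in for the folder)
for epoch in range(EPOCHS):
    for batch_idx in range(BATCHES_PER_EPOCH):
        name = f"Fake_image{batch_idx}.png"  # name ignores the epoch
        written[name] = epoch  # a later epoch overwrites the earlier "file"

print(len(written))           # 4 distinct files, not 3 * 4 = 12
print(set(written.values()))  # {2}: only the last epoch's images survive

# Side note: mixing %-style and f-string syntax, as in the question,
# leaves a literal '%' and 'd' in the filename.
print(f"Fake_image%{7}d.png")  # Fake_image%7d.png
```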

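You need to save the images as f"Fake_image-{epoch}-{batch_idx}.png" so that both epoch and batch_idx are used in the filename. If the name depends only on batch_idx, every epoch reuses the same names and overwrites the files written by the previous epoch.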
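A small sketch of that naming rule, using only the standard library. The helper fake_image_path is hypothetical, and the zero-padding widths (3 digits for the epoch, 4 for the batch index) are an illustrative choice, not something the answer requires: they just make alphabetical order in a file browser match training order.

```python
import os


def fake_image_path(target_dir: str, epoch: int, batch_idx: int) -> str:
    """Build one filename per (epoch, batch_idx) pair (hypothetical helper).

    Zero-padding is an illustrative choice so that plain alphabetical
    sorting matches training order.
    """
    return os.path.join(target_dir, f"Fake_image-{epoch:03d}-{batch_idx:04d}.png")


# Names in training order: epoch-major, then batch index.
paths = [fake_image_path("generated", e, b) for e in range(3) for b in range(12)]

print(len(set(paths)))         # 36: no two (epoch, batch_idx) pairs share a name
print(sorted(paths) == paths)  # True: alphabetical order == training order
```

Without padding the names are still unique, but "Fake_image-0-10.png" sorts before "Fake_image-0-2.png", which can be confusing when browsing the folder.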

import os
import torchvision
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

target_dir = r"C:/Users/PycharmProjects/pythonProject/generated/generated_image/"
os.makedirs(target_dir, exist_ok=True)  # create the folder if it does not exist yet

EPOCHS = 10
BATCH_SIZE = 64
GRID_SIZE = 9 # 9 images in each grid
NUM_ROWS = 3 # passed as make_grid's nrow (images per row); sqrt(GRID_SIZE) gives a square grid

# if you want all the images in a batch to make the image-grid,
# set GRID_SIZE = BATCH_SIZE

# YourFakeImageDataset is a placeholder for your own Dataset class, assumed to
# accept a transform argument. Transforms belong to the Dataset: DataLoader
# has no transform argument.
train_dataset = YourFakeImageDataset(transform=ToTensor())
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

for epoch in range(EPOCHS):
    for batch_idx, (X, y) in enumerate(train_dataloader):
        # assume X is the batch of fake images returned by the dataloader
        # and y is some target value for X, also returned by the dataloader

        # ... do something with your images here
        # B, C, H, W = X.shape
        img_grid_fake = torchvision.utils.make_grid(X[:GRID_SIZE, ...], nrow=NUM_ROWS)
        # epoch + batch_idx in the name gives every batch its own file,
        # so later epochs no longer overwrite earlier ones
        filepath = os.path.join(target_dir, f"Fake_image-{epoch}-{batch_idx}.png")
        torchvision.utils.save_image(img_grid_fake, filepath)
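The question's code already passes a step value. An alternative to the epoch-batch pair is one running counter across all epochs, step = epoch * num_batches + batch_idx. The sketch below assumes that formula and uses NUM_BATCHES as a hypothetical stand-in for len(train_dataloader); it writes empty placeholder files into a temporary folder where the real loop would call torchvision.utils.save_image, then checks that every batch left its own file behind.

```python
import os
import tempfile

EPOCHS = 3
NUM_BATCHES = 5  # hypothetical stand-in for len(train_dataloader)

with tempfile.TemporaryDirectory() as target_dir:
    os.makedirs(target_dir, exist_ok=True)  # no-op here; needed for a real path
    for epoch in range(EPOCHS):
        for batch_idx in range(NUM_BATCHES):
            # One counter that keeps increasing across epochs (assumed scheme).
            step = epoch * NUM_BATCHES + batch_idx
            filepath = os.path.join(target_dir, f"Fake_image-{step:06d}.png")
            # Placeholder write; the real loop would call
            # torchvision.utils.save_image(img_grid_fake, filepath) here.
            with open(filepath, "wb") as f:
                f.write(b"")
    saved = sorted(os.listdir(target_dir))

print(len(saved))           # 15 files: one per batch, none overwritten
print(saved[0], saved[-1])  # Fake_image-000000.png Fake_image-000014.png
```

The counter only stays unique if NUM_BATCHES matches the real number of batches per epoch; with drop_last or a changing dataset size, the epoch-batch naming is the safer choice.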

NOTE: I can't give a complete answer, because your question leaves out several details (some of which others have already asked about in the comments).

If you are making a fake-image grid, how are you doing that? With torchvision.utils.make_grid()?
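If the grid is built with torchvision.utils.make_grid, note that its nrow argument is the number of images per row, not the number of rows. The standard-library sketch below reproduces, to my understanding, the layout arithmetic make_grid uses for a batch of more than one image (default padding=2); the helper name make_grid_shape is hypothetical and not part of torchvision. It shows why setting GRID_SIZE = BATCH_SIZE = 64 while keeping nrow=3 gives a tall, narrow image.

```python
import math


def make_grid_shape(n_images, height, width, nrow=8, padding=2):
    """Return (rows, cols, grid_height, grid_width) for n_images > 1.

    Assumption: mirrors torchvision.utils.make_grid's layout, where nrow is
    the number of images per row and padding surrounds every tile.
    """
    cols = min(nrow, n_images)
    rows = math.ceil(n_images / cols)
    grid_h = rows * (height + padding) + padding
    grid_w = cols * (width + padding) + padding
    return rows, cols, grid_h, grid_w


print(make_grid_shape(9, 64, 64, nrow=3))   # (3, 3, 200, 200): square 3 x 3 grid
print(make_grid_shape(64, 64, 64, nrow=3))  # (22, 3, 1454, 200): tall, narrow grid
```

For a roughly square grid of a whole batch, nrow=8 (the default) suits 64 images.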

