
How to save image paths using PyTorch CNN

I have been working with a CNN in PyTorch, and I need to save each image's path along with its predicted probability for each class (here the classes are pass and fail). This is my code for saving the predictions to a DataFrame:

import numpy as np
import pandas as pd
import torch

prob_batches = []  # softmax outputs, one (batch_size, 2) array per batch
class_labels = []  # ground-truth labels, flattened across batches

model_ft.eval()
with torch.no_grad():  # inference only, so skip gradient tracking
    for inputs, labels in dataloaders['train']:
        inputs = inputs.to(device)
        output = model_ft(inputs)

        # Softmax over the class dimension (dim=1) gives per-class probabilities
        probabilities = torch.softmax(output, dim=1)
        prob_batches.append(probabilities.cpu().numpy())
        class_labels.extend(labels.tolist())

# Build the DataFrame once at the end (DataFrame.append was removed in pandas 2.0)
probs = np.concatenate(prob_batches)
preds_df = pd.DataFrame(probs, columns=['pass (0)', 'fail (1)'])
preds_df['prediction'] = probs.argmax(axis=1)  # index of the most probable class
preds_df['label'] = class_labels

preds_df.to_csv('./zoom17CNN_preds.csv')
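The softmax step turns each row of logits into class probabilities that sum to 1. For a `(batch_size, num_classes)` model output the class axis is `dim=1`, which is what `torch.nn.Softmax(dim=1)` or `torch.softmax(output, dim=1)` normalizes over. A NumPy sketch of the same row-wise computation (the function name `softmax_rows` is hypothetical) makes the axis choice explicit:

```python
import numpy as np


def softmax_rows(logits):
    """Row-wise softmax for a (batch_size, num_classes) array of logits."""
    # Subtracting each row's max leaves the result unchanged but avoids overflow in exp
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)  # normalize across classes, per image


logits = np.array([[2.0, 0.0], [0.0, 0.0]])
probs = softmax_rows(logits)
print(probs.sum(axis=1))  # each row sums to 1
```

Normalizing over `axis=0` instead would make each class column sum to 1 across the batch, which is not a per-image probability.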

How can I also save the image path for each file in the data loader? Thank you!

Thanks @akshayk07! I ended up iterating over my image directories and saving each image's name together with its predictions:

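import os

import pandas as pd
import torch
from PIL import Image
from torchvision import transforms

directory = "./get_preds/fail"

# Standard ImageNet preprocessing; build it once, outside the loop
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

ids, class0, class1 = [], [], []

model_ft.eval()
with torch.no_grad():  # inference only
    for filename in os.listdir(directory):
        school_id = filename[0:6]  # file names start with the six-character school ID
        ids.append(school_id)

        # convert('RGB') drops any alpha channel so Normalize sees exactly 3 channels
        png = Image.open(os.path.join(directory, filename)).convert('RGB')
        batch_t = torch.unsqueeze(transform(png), 0).to(device)  # add a batch dimension

        out = model_ft(batch_t)
        percentage = torch.nn.functional.softmax(out, dim=1)[0]
        class0.append(percentage[0].item())
        class1.append(percentage[1].item())

df = pd.DataFrame()
df['school_id'] = ids
df['pass (0)'] = class0
df['fail (1)'] = class1
df['label'] = 1  # every image in this directory belongs to the fail class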
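The loop above covers only `./get_preds/fail` and hard-codes the label. A minimal sketch of gathering both classes in one pass, assuming the pass images sit in a sibling folder named `pass` (that folder name is an assumption) and reusing the six-character school ID convention from the answer. The names `list_images_with_labels` and `CLASS_TO_LABEL` are hypothetical.

```python
import os

# Assumed folder-name -> label mapping, matching the "pass (0)" / "fail (1)" columns
CLASS_TO_LABEL = {"pass": 0, "fail": 1}


def list_images_with_labels(root, class_to_label=CLASS_TO_LABEL):
    """Return (path, school_id, label) tuples for every file under root/<class>/."""
    records = []
    for class_name, label in class_to_label.items():
        class_dir = os.path.join(root, class_name)
        for filename in sorted(os.listdir(class_dir)):
            path = os.path.join(class_dir, filename)
            if os.path.isfile(path):
                # First six characters of the file name are the school ID
                records.append((path, filename[:6], label))
    return records
```

Each record's path can then go through the same `Image.open` / `transform` / `model_ft` steps, and the three fields fill the path, `school_id`, and `label` columns, so the label comes from the folder rather than a constant.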
