简体   繁体   English

如何使用 PyTorch CNN 保存图像路径

[英]How to save image paths using PyTorch CNN

I have been working with a CNN in PyTorch and I need to save the image path and its associated predicted probabilities for each class (in this case the classes are pass or fail).我一直在 PyTorch 中使用 CNN,我需要为每个 class 保存图像路径及其相关的预测概率(在这种情况下,类是通过或失败)。 This is my code to save the preds to a data frame:这是我将 preds 保存到数据框的代码:

preds_df = pd.DataFrame()   
class_labels = []


model_ft.eval()
for i, (inputs, labels) in enumerate(dataloaders['train']):
    inputs = inputs.to(device)
    labels = labels.to(device)
    class_labels.append(labels.tolist())
    output = model_ft(inputs)

    sm = torch.nn.Softmax()
    probabilities = sm(output) 
    arr = probabilities.data.cpu().numpy()
    df = pd.DataFrame(arr)

    preds_df = preds_df.append(df)


preds_df['prediction'] = preds_df.idxmax(axis=1)
class_list = [item for sublist in class_labels for item in sublist]
preds_df['label'] = class_list
preds_df.columns = ['pass (0)', 'fail (1)', 'prediction', 'label']


preds_df.to_csv('./zoom17CNN_preds.csv')

How can I save the image path as well for each file in the data loader?如何为数据加载器中的每个文件保存图像路径? Thank you!谢谢!

Thanks @akshayk07: I ended up iterating over my image directories and saving the current image name and preds that way:谢谢@akshayk07:我最终遍历了我的图像目录并以这种方式保存了当前图像名称和预测:

directory = "./get_preds/fail"

for filename in os.listdir(directory):
    school_id = filename[0:6]
    ids.append(school_id)
    to_open = "./get_preds/fail/" + filename
    png = Image.open(to_open)

    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),         
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    img_t = transform(png)
    batch_t = torch.unsqueeze(img_t, 0)
    model_ft.eval()
    out = model_ft(batch_t)

    _, index = torch.max(out, 1)
    percentage = torch.nn.functional.softmax(out, dim=1)[0]
    class0.append(percentage[0].tolist())
    class1.append(percentage[1].tolist())

df['school_id'] = ids
df['pass (0)'] = class0
df['fail (1)'] = class1
df['label'] = 1 

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM