繁体   English   中英

如何使用 PyTorch CNN 保存图像路径

[英]How to save image paths using PyTorch CNN

我一直在 PyTorch 中使用 CNN,我需要为每个 class 保存图像路径及其相关的预测概率(在这种情况下,类是通过或失败)。 这是我将 preds 保存到数据框的代码:

preds_df = pd.DataFrame()   
class_labels = []


model_ft.eval()
for i, (inputs, labels) in enumerate(dataloaders['train']):
    inputs = inputs.to(device)
    labels = labels.to(device)
    class_labels.append(labels.tolist())
    output = model_ft(inputs)

    sm = torch.nn.Softmax()
    probabilities = sm(output) 
    arr = probabilities.data.cpu().numpy()
    df = pd.DataFrame(arr)

    preds_df = preds_df.append(df)


preds_df['prediction'] = preds_df.idxmax(axis=1)
class_list = [item for sublist in class_labels for item in sublist]
preds_df['label'] = class_list
preds_df.columns = ['pass (0)', 'fail (1)', 'prediction', 'label']


preds_df.to_csv('./zoom17CNN_preds.csv')

如何为数据加载器中的每个文件保存图像路径? 谢谢!

谢谢@akshayk07:我最终遍历了我的图像目录并以这种方式保存了当前图像名称和预测:

directory = "./get_preds/fail"

for filename in os.listdir(directory):
    school_id = filename[0:6]
    ids.append(school_id)
    to_open = "./get_preds/fail/" + filename
    png = Image.open(to_open)

    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),         
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    img_t = transform(png)
    batch_t = torch.unsqueeze(img_t, 0)
    model_ft.eval()
    out = model_ft(batch_t)

    _, index = torch.max(out, 1)
    percentage = torch.nn.functional.softmax(out, dim=1)[0]
    class0.append(percentage[0].tolist())
    class1.append(percentage[1].tolist())

df['school_id'] = ids
df['pass (0)'] = class0
df['fail (1)'] = class1
df['label'] = 1 

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM