繁体   English   中英

尝试打印犬种的类名,但它不断说列表索引超出范围

[英]Trying to print class names for dog breed but it keeps saying list index out of range

我正在使用resnet模型对狗的品种进行分类,但是当我尝试打印带有狗的品种标签的图像时,它表示列表索引超出范围。 这是我的代码:

import torchvision.models as models
import torch.nn as nn


model_transfer = models.resnet18(pretrained=True)

if use_cuda:
    model_transfer = model_transfer.cuda()

model_transfer.fc.out_features = 133

然后,我训练模型并获得超过70%的犬种准确性。

然后这是我的代码来对狗进行分类并打印狗的品种:

data_transfer = {'train': 
 datasets.ImageFolder('/data/dog_images/train',transform=transforms.Compose([transforms.RandomResizedCrop(224),transforms.ToTensor()]))}
class_names[0]
class_names = [item[4:].replace("_", " ") for item in data_transfer['train'].classes]

def predict_breed_transfer(img_path):

    image = Image.open(img_path)

    # large images will slow down processing


    in_transform = transforms.Compose([
                        transforms.CenterCrop(224),
                        transforms.ToTensor(),
                        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])])

    # discard the transparent, alpha channel (that's the :3) and add the batch dimension
    image = in_transform(image)[:3,:,:].unsqueeze(0)

    image = image

    output = model_transfer(image)
    pred = torch.argmax(output)

    return class_names[pred]
    predict_breed_transfer('images/Labrador_retriever_06455.jpg')

代码总是会由于某种原因而预测狗是错误的,然后当我尝试打印出图像和标签时:

import matplotlib.pyplot as plt
def run_app(img_path):
    img = Image.open(img_path)
    dog = dog_detector(img_path)
    if not dog: 
        print('hello, human!')
        plt.imshow(img)
        print('You look like a ... ')
        print(predict_breed_transfer(img_path))
    if dog: 
        print('hello, dog!')
        print('Your predicted breed is ....')
        print(predict_breed_transfer(img_path))
        plt.imshow(img)
    else: 
        print('Niether human nor dog')

并运行一个for循环,在某些狗图像上调用它,它将打印出某些品种,然后说列表索引超出范围,并且不显示任何图像。

class_names的长度为133。当我打印出resnet模型时,输出只有133个节点,有人知道为什么它说列表索引超出范围或为什么它如此不准确。

`IndexError                                Traceback (most recent 
call last)
<ipython-input-26-473a9ba884b5> in <module>()
      5 ## suggested code, below
      6 for file in np.hstack((human_files[:3], dog_files[:3])):
----> 7     run_app(file)
      8 
 <ipython-input-25-1d44200e44cc> in run_app(img_path)
      10         plt.show(img)
      11         print('You look like a ... ')
 ---> 12         print(predict_breed_transfer(img_path))
      13     if dog:
      14         print('hello, dog!')

 <ipython-input-20-a51fb205659e> in predict_breed_transfer(img_path)
      26     pred = torch.argmax(output)
      27 
 ---> 28     return class_names[pred]
      29 
predict_breed_transfer('images/Labrador_retriever_06455.jpg')
      30 

IndexError: list index out of range`

这是完整的错误

我想您有几个可以使用13个字符解决的问题。

首先,我建议@Alekhya Vemavarapu提出的建议-使用调试器运行代码以隔离每一行并检查输出。 这是使用pytorch动态图的最大好处之一

其次,最可能引起问题的argmax是您使用不正确的argmax语句。 您没有指定执行argmax的尺寸,因此PyTorch会自动展平图像并在全长矢量上执行操作。 因此,您得到一个介于0MB_Size x num_classes -1之间的数字 有关此方法,请参见官方文档

因此,由于您具有完全连接的层,因此我假设您的输出具有形状(MB_Size, num_classes) 如果是这样,则需要将代码更改为以下行:

pred = torch.argmax(output,dim=1)

就是这样。 否则,只需选择Logit的尺寸。

您要考虑的第三件事是训练配置可能对推断造成的影响以及其他影响。 例如,某些框架中的辍学可能需要将推论的输出乘以1/(1-p) (或者因为在训练时可以完成,所以不这样做),由于批量大小不同,可能会取消批量标准化,因此上。 此外,为减少内存消耗,不应计算任何梯度。 幸运的是,PyTorch开发人员非常周到,为此提供了torch.no_grad()model.eval()

我强烈建议对此进行练习,可能会用几个字母更改代码:

output = model_transfer.eval()(image)

完成了!

编辑
这是错误使用PyTorch框架,不阅读文档和不调试代码的简单用例。 以下代码是完全不正确的:

model_transfer.fc.out_features = 133

该行实际上并不创建新的完全连接的层。 它只是改变了该张量的属性。 在您的控制台中尝试:

import torch
a = torch.nn.Linear(1,2)
a.out_features = 3
print(a.bias.data.shape, a.weight.data.shape)

输出:

torch.Size([2]) torch.Size([2, 1])

这表示权重和偏差向量的实际矩阵保持其原始尺寸。
进行迁移学习的正确方法是保留主干(通常是卷积层,直到这些类型的模型中的卷积层完全连接)并用您的头部覆盖磁头(在这种情况下为FC层)。 如果原始模型中仅存在一层完全连接的层,则无需更改模型的前向通过,您就可以进行了。 由于此答案已经足够长,因此只需访问PyTorch文档中的“ 转移学习”教程即可了解如何完成。

祝你好运。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM