繁体   English   中英

如何将代码分离为 pytorch cnn 的训练、验证和测试函数?

[英]How to seperate code into train, val and test functions for pytorch cnn?

我正在使用 pytorch 训练一个 cnn 并创建了一个训练循环。 当我执行优化和试验超参数调整时,我想将我的训练、验证和测试分成不同的函数。 我需要能够记录每个 function 的准确性和损失,以便生成 plot 个图表。 为此,我想创建一个返回精度的 function。

我对编码很陌生,想知道关于这个问题的最佳方法 go。 我觉得我的代码现在有点乱。 我需要能够在我的训练 function 中输入各种超参数进行实验。有人可以提供任何建议吗? 以下是我到目前为止能做的:

def train_model(model, optimizer, data_loader,  num_epochs, criterion=criterion):
  total_epochs = notebook.tqdm(range(num_epochs))

  for epoch in total_epochs:
    model.train()

    train_correct = 0.0
    train_running_loss=0.0
    train_total=0.0

    for i, (img, label) in enumerate(data_loader['train']):
      #uploading images and labels to GPU
      img = img.to(device)
      label = label.to(device)

      #training model
      outputs = model(img)

      #computing losss
      loss = criterion(outputs, label)

      #propagating the loss backwards
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      train_running_loss += loss.item()
    
      _, predicted = outputs.max(1)
      train_total += label.size(0)
      train_correct += predicted.eq(label).sum().item()
      
    train_loss=train_running_loss/len(data_loader['train'])
    train_accu=100.*correct/total

    print('Train Loss: %.3f | Train Accuracy: %.3f'%(train_loss,train_accu))

我还尝试制作一个函数来记录准确性:

def accuracy(outputs, labels):
    _, preds = torch.max(outputs, dim = 1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))

首先,请注意:

  • 除非您有某些特定的动机,否则应该在与训练集不同的数据集上执行验证(和测试),因此您应该使用不同的DataLoader 由于每个时期都有一个额外的 for 循环,因此计算时间会增加。
  • 在验证/测试之前始终调用model.eval()

也就是说,验证 function 的签名与train_model的签名非常相似

# criterion is passed if you want to register the validation loss too
def validate_model(model, eval_loader, criterion):
   ...

然后,在train_model中,在每个纪元之后,您可以调用 function validate_model并将返回的指标存储在一些数据结构( listtensor等)中,稍后将用于绘图。

在训练结束时,您可以使用相同的validate_model function 进行测试。

您可以使用Accuracy的准确性,而不是自己编写准确性

最后,如果你觉得需要升级,可以使用PyTorch LightningFastAI等深度学习训练框架。 还可以查看一些超参数调整库,例如Ray Tune

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM