簡體   English   中英

如何在 Pytorch 中保存/加載具有多個損失的模型檢查點?

[英]How to save/load a model checkpoint with several losses in Pytorch?

使用 Ubuntu 20.04、Pytorch 1.10.1。

我正在嘗試使用變壓器架構和多嵌入來解決音樂生成任務,以處理具有多種特征的令牌。

在每次訓練迭代中,我必須計算每個令牌特征的損失並將其存儲在一個向量中,然后我想我應該在一個檢查點中存儲一個包含所有它們(或類似的東西)的向量,而不是我的現在做這可以節省總損失。 我想知道如何將所有損失存儲在檢查點中(能夠在加載時繼續訓練),或者根本不需要它。

時代循環:

for epoch in range(0, epochs):
    
    print('Epoch: ', epoch)
    
    loss = trfrmr.train(epoch+1, model, train_loader, train_loss_func, opt, lr_scheduler, num_iters=-1)
    loss_train.append(loss)
    
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': opt.state_dict(),
        'loss': loss,
        }, "model_pop909_checkpoint.pth")

訓練循環:

for batch_num, batch in enumerate(dataloader):
    time_before = time.time()

    opt.zero_grad()

    x = batch[0].to(get_device())
    tgt = batch[1].to(get_device())

    # x is the input sequence (N,T,Z), that should be input into the transformer forward function as (T,N,Z)
    y = model(x.permute(1, 0, 2))

    # tgt is the real output sequence, of shape (N,T,Z), T is sequence length, N batch size, Z the different token types
    # y are the output logits, is a list of Z tensors of shape (T,N,C*) where C is the vocabulary size, and will vary depending on the token type (pitch, velocity etc...)
    losses = []
    for j in range(LEN_VOCAB):
        aux_loss = loss.forward(y[j].permute(1, 2, 0),
                                        tgt[..., j])  # shapes (N,C,T) and (N,T), see Pytorch cross-entropy for details
        losses.append(aux_loss)

    losses_sum = sum(losses)  # here we sum, but we could also have mean for instance

    losses_sum.backward()
    opt.step()

    if lr_scheduler is not None:
         lr_scheduler.step()

    lr = opt.param_groups[0]['lr']
                
    loss_hist.append(losses_sum)
    if batch_num == num_iters:
       break

提前致謝。

幾小時后編輯:我的具體問題的解決方案

問題是當再次加載模型時,我沒有正確執行(不加載優化器參數,而只加載模型參數)。 現在在我的代碼中,在循環開始時,我執行以下操作:

if loaded:
    print('Loading model and optimizer...')
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    opt.load_state_dict(checkpoint['optimizer_state_dict'])
    print('Loaded succesfully!')

我還加載了時代:

epoch = 0
if loaded:
    print('Loading epoch value...')
    epoch = checkpoint['epoch'] 
    print('Loaded succesfully!')

據我從您的代碼中可以看出,您的損失函數沒有自定義的可學習參數; 每次模型迭代時都會重新計算它。 因此,除了保留它的歷史之外,沒有必要保存它的價值; 不需要從檢查站繼續訓練。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM