How to save/load a model checkpoint with several losses in PyTorch?
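Using Ubuntu 20.04 and PyTorch 1.10.1.
I'm trying to solve a music generation task with a Transformer architecture and multiple embeddings, so that it can handle tokens that have several features.
In each training iteration I compute a loss for each token feature and store the results in a vector. I think I should then save a vector with all of them (or something similar) in the checkpoint, instead of what I do now, which is saving only the total loss. How do I store all the losses in the checkpoint (so that training can continue after loading)? Or is that not needed at all?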
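If you want to keep the per-feature losses for logging, one option is to convert each one to a Python float and put the list in the checkpoint dictionary next to the model and optimizer states. A minimal sketch, assuming three token features, a toy `nn.Linear` model, made-up loss values, and an in-memory buffer instead of a `.pth` file; the keys `feature_losses` and `loss_history` are hypothetical names:

```python
import io
import torch
import torch.nn as nn

# Hypothetical stand-ins for the real transformer and its optimizer
model = nn.Linear(4, 2)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

# Per-feature losses from one iteration (e.g. pitch, velocity, duration).
# .item() turns each 0-dim tensor into a plain float with no autograd graph.
raw_losses = [torch.tensor(2.31), torch.tensor(1.05), torch.tensor(0.72)]
feature_losses = [t.item() for t in raw_losses]

# One list of per-feature losses per epoch (a single epoch here)
loss_history = [feature_losses]

buffer = io.BytesIO()  # in real code: a path such as "model_pop909_checkpoint.pth"
torch.save({
    'epoch': 0,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': opt.state_dict(),
    'feature_losses': feature_losses,  # hypothetical key: latest per-feature values
    'loss_history': loss_history,      # hypothetical key: full history for plots
}, buffer)

buffer.seek(0)
checkpoint = torch.load(buffer)
print(checkpoint['feature_losses'])
```

These values are only useful for logging or plotting; resuming training relies on the model and optimizer states.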
Epoch loop:
```python
for epoch in range(0, epochs):
    print('Epoch: ', epoch)
    loss = trfrmr.train(epoch+1, model, train_loader, train_loss_func, opt, lr_scheduler, num_iters=-1)
    loss_train.append(loss)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': opt.state_dict(),
        'loss': loss,
    }, "model_pop909_checkpoint.pth")
```
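The loop above passes `lr_scheduler` to `train`, but the checkpoint stores only the model and optimizer states. If a scheduler is used, saving its `state_dict()` as well lets the learning-rate schedule continue where it stopped. A minimal sketch, assuming a `StepLR` scheduler, a toy model, placeholder epoch losses, an in-memory buffer instead of the `.pth` file, and a hypothetical `scheduler_state_dict` key:

```python
import io
import torch
import torch.nn as nn

def build():
    """Toy model, optimizer and scheduler (stand-ins for the real ones)."""
    model = nn.Linear(4, 2)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    sched = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.5)
    return model, opt, sched

model, opt, lr_scheduler = build()
loss_train = []
for epoch in range(3):
    # ... one epoch of training would run here ...
    opt.step()
    lr_scheduler.step()                   # halves the lr after each epoch
    loss_train.append(1.0 / (epoch + 1))  # placeholder epoch loss

buffer = io.BytesIO()  # in real code: "model_pop909_checkpoint.pth"
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': opt.state_dict(),
    'scheduler_state_dict': lr_scheduler.state_dict(),
    'loss_train': loss_train,
}, buffer)

# Later: rebuild fresh objects and restore all three states
buffer.seek(0)
checkpoint = torch.load(buffer)
model2, opt2, sched2 = build()
model2.load_state_dict(checkpoint['model_state_dict'])
opt2.load_state_dict(checkpoint['optimizer_state_dict'])
sched2.load_state_dict(checkpoint['scheduler_state_dict'])
print(opt2.param_groups[0]['lr'])  # lr after three halvings of 0.1
```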
Training loop:
```python
for batch_num, batch in enumerate(dataloader):
    time_before = time.time()
    opt.zero_grad()
    x = batch[0].to(get_device())
    tgt = batch[1].to(get_device())
    # x is the input sequence (N,T,Z); the transformer's forward function expects it as (T,N,Z)
    y = model(x.permute(1, 0, 2))
    # tgt is the ground-truth output sequence of shape (N,T,Z): T is the sequence length, N the batch size, Z the number of token types
    # y holds the output logits: a list of Z tensors of shape (T,N,C*), where C* is the vocabulary size, which varies by token type (pitch, velocity, etc.)
    losses = []
    for j in range(LEN_VOCAB):
        aux_loss = loss(y[j].permute(1, 2, 0),
                        tgt[..., j])  # shapes (N,C,T) and (N,T); see PyTorch's cross-entropy docs for details
        losses.append(aux_loss)
    losses_sum = sum(losses)  # here we sum, but we could also take the mean, for instance
    losses_sum.backward()
    opt.step()
    if lr_scheduler is not None:
        lr_scheduler.step()
    lr = opt.param_groups[0]['lr']
    # .item() stores a plain float; appending the tensor itself would keep a reference to its autograd graph
    loss_hist.append(losses_sum.item())
    if batch_num == num_iters:
        break
```
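The per-feature losses in the loop above can be computed and recorded as plain floats. A self-contained sketch with random logits and targets; the vocabulary sizes, `T`, `N`, and using `nn.CrossEntropyLoss` for every feature are assumptions, and `per_feature` is a hypothetical name for the detached values you might keep in `loss_hist` or a checkpoint:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, N = 8, 2               # sequence length, batch size
vocab_sizes = [16, 8, 4]  # hypothetical C* per token type (pitch, velocity, duration)
Z = len(vocab_sizes)

loss_fn = nn.CrossEntropyLoss()
# y: list of Z logit tensors of shape (T, N, C_j), as described in the question
y = [torch.randn(T, N, c, requires_grad=True) for c in vocab_sizes]
# tgt: integer targets of shape (N, T, Z); column j holds class indices < vocab_sizes[j]
tgt = torch.stack([torch.randint(0, c, (N, T)) for c in vocab_sizes], dim=-1)

# (N, C_j, T) logits against (N, T) targets, one loss per token type
losses = [loss_fn(y[j].permute(1, 2, 0), tgt[..., j]) for j in range(Z)]
losses_sum = sum(losses)
losses_sum.backward()

# Detached per-feature values, safe to log or store in a checkpoint
per_feature = [t.item() for t in losses]
print(per_feature)
```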
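Thanks in advance.
Edit, a few hours later: the solution to my specific problem
The problem was that I wasn't reloading the model correctly: I loaded only the model parameters, not the optimizer state. Now, at the start of the loop, my code does the following: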
```python
if loaded:
    print('Loading model and optimizer...')
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    opt.load_state_dict(checkpoint['optimizer_state_dict'])
    print('Loaded successfully!')
```
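With `strict=False`, `load_state_dict` does not raise on missing or unexpected keys, so a checkpoint that doesn't match the model can still appear to load. The method returns the mismatched key names, which can be checked. A round-trip sketch, assuming a small `nn.Sequential` stand-in for the transformer, an Adam optimizer, and an in-memory buffer in place of the `.pth` file (`make_model` is a hypothetical helper):

```python
import io
import torch
import torch.nn as nn

def make_model():
    # Hypothetical stand-in for the transformer
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

model = make_model()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
# One dummy step so Adam has per-parameter state (exp_avg, exp_avg_sq, step)
model(torch.randn(3, 4)).sum().backward()
opt.step()

buffer = io.BytesIO()  # in real code: "model_pop909_checkpoint.pth"
torch.save({'model_state_dict': model.state_dict(),
            'optimizer_state_dict': opt.state_dict()}, buffer)
buffer.seek(0)
checkpoint = torch.load(buffer, map_location='cpu')

# Restore into freshly built objects, as when resuming in a new process
new_model = make_model()
new_opt = torch.optim.Adam(new_model.parameters(), lr=1e-3)
result = new_model.load_state_dict(checkpoint['model_state_dict'], strict=False)
new_opt.load_state_dict(checkpoint['optimizer_state_dict'])
print(result.missing_keys, result.unexpected_keys)  # [] []

# A model with an extra layer still "loads" under strict=False
bigger = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2), nn.Linear(2, 2))
res2 = bigger.load_state_dict(checkpoint['model_state_dict'], strict=False)
print(res2.missing_keys)  # ['3.weight', '3.bias']
```

Unless the architecture has deliberately changed, the default `strict=True` surfaces such mismatches as errors.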
I also load the epoch:
```python
epoch = 0
if loaded:
    print('Loading epoch value...')
    epoch = checkpoint['epoch']
    print('Loaded successfully!')
```
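The epoch loop saves `epoch` after that epoch has finished, so the stored value is the last completed epoch. If the loaded value is meant to set where training restarts (the loop shown earlier still uses `range(0, epochs)`), starting at `checkpoint['epoch'] + 1` avoids repeating it. A minimal sketch with a hypothetical helper `resume_range`:

```python
def resume_range(checkpoint, epochs):
    """Epochs still to run: all of them for a fresh run, or the ones after
    the last completed epoch stored in the checkpoint."""
    start_epoch = 0
    if checkpoint is not None:
        start_epoch = checkpoint['epoch'] + 1
    return range(start_epoch, epochs)

# Fresh run
print(list(resume_range(None, 3)))          # [0, 1, 2]
# Checkpoint written at the end of epoch 4 (0-based), 7 epochs in total
print(list(resume_range({'epoch': 4}, 7)))  # [5, 6]
```

In the question's code this would read `for epoch in resume_range(checkpoint if loaded else None, epochs):`.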
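As far as I can tell from your code, your loss function has no custom learnable parameters; it is recomputed at every iteration. So there is no need to save its value except to keep its history; it is not needed to resume training from a checkpoint.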
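This is easy to check: a standard loss module such as `nn.CrossEntropyLoss` has no parameters, so there is no loss-function state to restore. Optional class weights are stored as a buffer, but they are fixed values set in code, not learned. A small sketch (the class weights are an arbitrary example):

```python
import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()
print(list(loss_fn.parameters()))  # []: nothing is learned
print(len(loss_fn.state_dict()))   # 0: nothing to save

# Fixed class weights become a buffer; still no learnable parameters
weighted = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 0.5]))
print(list(weighted.state_dict().keys()))  # ['weight']
print(len(list(weighted.parameters())))    # 0
```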