簡體   English   中英

如何在 nn.LSTM 中取得 R2 分數 pytorch

[英]How to make R2 score in nn.LSTM pytorch

我試圖用 R2in nn.LSTM 損失 function 但我找不到任何關於它的文檔。 我已經使用了 pytorch 中的 RMSE 和 MAE 損失。

我的數據是一個時間序列,我正在做時間序列預測

這是我在數據訓練中使用 RMSE 損失 function 的代碼

model = LSTM_model(input_size=1, output_size=1, hidden_size=512, num_layers=2, dropout=0).to(device)
criterion = nn.MSELoss(reduction="sum")
optimizer = optim.Adam(model.parameters(), lr=0.001)
callback = Callback(model, early_stop_patience=10 ,outdir="model/lstm", plot_every=20,)


from tqdm.auto import tqdm

def loop_fn(mode, dataset, dataloader, model, criterion, optimizer,device):
    if mode =="train":
        model.train()
    elif mode =="test":
        model.eval()
    cost = 0
    for feature, target in tqdm(dataloader, desc=mode.title()):
        feature, target = feature.to(device), target.to(device)
        output , hidden = model(feature,None)
        loss = torch.sqrt(criterion(output,target))
        
        if mode =="train":
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        
        cost += loss.item() * feature.shape[0]
    cost = cost / len(dataset)
    return cost

這是開始數據訓練的代碼

while True :
    train_cost = loop_fn("train", train_set, trainloader, model, criterion, optimizer,device)
    with torch.no_grad():
        test_cost  = loop_fn("test", test_set, testloader, model, criterion, optimizer,device)
        
    callback.log(train_cost, test_cost)
    
    callback.save_checkpoint()
    
    callback.cost_runtime_plotting()
   
    
    if callback.early_stopping(model, monitor="test_cost"):
        callback.plot_cost()
        break

誰能幫我解決 R2 損失 function? 先感謝您

這是一個實現,

"""
From https://en.wikipedia.org/wiki/Coefficient_of_determination
"""
def r2_loss(output, target):
    target_mean = torch.mean(target)
    ss_tot = torch.sum((target - target_mean) ** 2)
    ss_res = torch.sum((target - output) ** 2)
    r2 = 1 - ss_res / ss_tot
    return r2

您可以按如下方式使用它,

loss = r2_loss(output, target)
loss.backward()

以下庫 function 已經實現了我對 Melike 的解決方案所做的評論:

from torchmetrics.functional import r2_score
loss = r2_score(output, target)
loss.backward()

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM