[英]IndexError: Target 1 is out of bounds
當我運行下面的程序時,它給了我一個錯誤。 問題似乎在於丟失 function 但我找不到。 我已經閱讀了 nn.CrossEntropyLoss 的 Pytorch 文檔,但仍然找不到問題。
圖像尺寸為 (1 x 256 x 256),批量大小為 1
我是 PyTorch 的新手,謝謝。
import torch
import torch.nn as nn
from PIL import Image
import numpy as np
torch.manual_seed(0)
x = np.array(Image.open("cat.jpg"))
x = np.expand_dims(x, axis = 0)
x = np.expand_dims(x, axis = 0)
x = torch.from_numpy(x)
x = x.type(torch.FloatTensor) # shape = (1, 1, 256, 256)
def Conv(in_channels, out_channels, kernel=3, stride=1, padding=0):
return nn.Conv2d(in_channels, out_channels, kernel, stride, padding)
class model(nn.Module):
def __init__(self):
super(model, self).__init__()
self.sequential = nn.Sequential(
Conv(1, 3),
Conv(3, 5),
nn.Flatten(),
nn.Linear(317520, 1),
nn.Sigmoid()
)
def forward(self, x):
y = self.sequential(x)
return y
def compute_loss(y_hat, y):
return nn.CrossEntropyLoss()(y_hat, y)
model = model()
y_hat = model(x)
loss = compute_loss(y_hat, torch.tensor([1]))
錯誤:
Traceback (most recent call last):
File "D:/Me/AI/Models/test.py", line 38, in <module>
**loss = compute_loss(y, torch.tensor([1]))**
File "D:/Me/AI/Models/test.py", line 33, in compute_loss
return nn.CrossEntropyLoss()(y_hat, y)
File "D:\Softwares\Anaconda\envs\deeplearning\lib\site-packages\torch\nn\modules\module.py", line 1054, in _call_impl
return forward_call(*input, **kwargs)
File "D:\Softwares\Anaconda\envs\deeplearning\lib\site-packages\torch\nn\modules\loss.py", line 1120, in forward
return F.cross_entropy(input, target, weight=self.weight,
File "D:\Softwares\Anaconda\envs\deeplearning\lib\site-packages\torch\nn\functional.py", line 2824, in cross_entropy
return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
**IndexError: Target 1 is out of bounds.**
Process finished with exit code 1
這看起來像一個二進制分類器 model:cat or not cat。 但是您正在使用 CrossEntropyLoss,當您有超過 2 個目標類時使用它。 所以你應該使用的是Binary Cross Entropy Loss 。
def compute_loss(y_hat, y):
return nn.BCELoss()(y_hat, y)
試試loss = compute_loss(y_hat, torch.tensor([0]))
我認為這種變化
nn.Linear(317520, 1)
-> nn.Linear(317520, 2)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.