[英]How to solve “RuntimeError: 1D target tensor expected, multi-target not supported” in multi-class classification?
[英]how to solve this (Pytorch RuntimeError: 1D target tensor expected, multi-target not supported)
我是 pytorch 和深度学习的新手
我的数据集 53502 x 58,
我的代码有问题
model = nn.Sequential(
nn.Linear(58,64),
nn.ReLU(),
nn.Linear(64,32),
nn.ReLU(),
nn.Linear(32,16),
nn.ReLU(),
nn.Linear(16,2),
nn.LogSoftmax(1)
)
criterion = nn.NLLLoss()
optimizer = optim.AdamW(model.parameters(), lr = 0.0001)
epoch = 500
train_cost, test_cost = [], []
for i in range(epoch):
model.train()
cost = 0
for feature, target in trainloader:
output = model(feature) #feedforward
loss = criterion(output, target) #loss
loss.backward() #backprop
optimizer.step() #update weight
optimizer.zero_grad() #zero grad
cost += loss.item() * feature.shape[0]
train_cost.append(cost / len(train_set))
with torch.no_grad():
model.eval()
cost = 0
for feature, target in testloader:
output = model(feature) #feedforward
loss = criterion(output, target) #loss
cost += loss.item() * feature.shape
test_cost.append(cost / len(test_set))
print(f'\repoch {i+1}/{epoch} | train_cost: {train_cost[-1]} | test_cost : {test_cost[-1]}', end = "")
然后我遇到这样的问题
2262 .format(input.size(0), target.size(0)))
2263 if dim == 2:
-> 2264 ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
2265 elif dim == 4:
2266 ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: 1D target tensor expected, multi-target not supported
怎么了? 如何解决这个问题呢? 为什么会这样?
非常感谢您!
使用NLLLoss
时,目标张量必须包含标签的索引表示,而不是 one-hot。 例如:
我想这就是你的目标的样子:
target = [0, 0, 1, 0]
只需将其转换为1
的索引数字即可:
[0, 0, 1, 0] -> [2]
[1, 0, 0, 0] -> [0]
[0, 0, 0, 1] -> [3]
然后将其转换为长张量,即:
target = [2]
target = torch.Tensor(target).type(torch.LongTensor)
可能令人困惑的是,您的 output 是一个具有类长度的张量,而您的目标是一个数字,但事实就是如此。
您可以在此处自行查看。
我得到了同样的错误信息,原因是targets
就像多维张量而不是一维:
tensor([[0],
[0],
[0],
...,
[9],
[9],
[9]], dtype=torch.int32)
并使用torch.flatten(targets)
解决了我的问题。 目标现在具有一维张量形状:
tensor([0, 0, 0, ..., 9, 9, 9], dtype=torch.int32)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.