Error: Dimension out of range (expected to be in range of [-1, 0], but got 1) when training a CNN model

I am following a code snippet that uses ResNet-18 to train a binary classification model. Since I need to train a multi-class classification model instead, I modified the code slightly by changing the loss function and the activation function:

import torch
import torchvision
import pytorch_lightning as pl
import torchmetrics

class CheastCancer(pl.LightningModule):

  def __init__(self,init_weights=True):
    super().__init__()

    self.model = torchvision.models.resnet18()
    self.model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    self.model.fc = torch.nn.Linear(in_features=512, out_features=3, bias=True)

    self.optimizer = torch.optim.Adam(self.model.parameters(), lr = 1e-4)
    self.loss_fn = torch.nn.CrossEntropyLoss()

    self.train_acc = torchmetrics.Accuracy()
    self.val_acc = torchmetrics.Accuracy()

  def forward(self, data):
    pred = self.model(data)
    return pred

  def training_step(self, batch, batch_idx):
    img, label = batch
    label = label.float()
    pred = self(img)[:,0]
    loss = self.loss_fn(pred,label)

    self.log("Train Loss", loss)
    self.log("Step Train ACC", self.train_acc(torch.softmax(pred), label.int()))

    return loss
  
  def training_epoch_end(self, outs):
    self.log("Train ACC", self.train_acc.compute())

  def validation_step(self, batch, batch_idx):
    img, label = batch
    label = label.float()
    pred = self(img)[:,0]
    loss = self.loss_fn(pred,label)

    self.log("Train Loss", loss)
    self.log("Step Train ACC", self.val_acc(torch.poisson(pred), label.int()))

    return loss
  
  def training_epoch_end(self, outs):
    self.log("Train ACC", self.val_acc.compute())

  def configure_optimizers(self):
      return [self.optimizer]

However, I received the following error message:

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | ResNet           | 11.2 M
1 | loss_fn   | CrossEntropyLoss | 0     
2 | train_acc | Accuracy         | 0     
3 | val_acc   | Accuracy         | 0     
-----------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.687    Total estimated model params size (MB)
Sanity Checking:
0/? [00:00<?, ?it/s]
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:245: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  category=PossibleUserWarning,
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-132-5af489ac4112> in <module>()
----> 1 trainer.fit(model, train_loader, val_loader)

17 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
   2844     if size_average is not None or reduce is not None:
   2845         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2846     return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
   2847 
   2848 

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

From what I have found so far, this error usually appears when the dimensions of the labels do not match those of the predictions, but I am not sure how to fix it. Could somebody help, please?

According to the docs, the targets passed to CrossEntropyLoss are normally class indices, not one-hot vectors. My guess is that your labels are in one-hot format, in which case you need to convert them to indices first:

import torch

y = torch.tensor([[0, 0, 1], [0, 1, 0]])  # one-hot labels, shape (2, 3)
y = y.argmax(dim=1)  # class indices, shape (2,): tensor([2, 1])
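As a rough illustration of that conversion in context, the sketch below feeds the converted labels into the loss. The logits are made-up values standing in for a model output with 3 classes; they are an assumption for the example and do not come from the question.

```python
import torch

# Hypothetical model output for 2 samples and 3 classes (raw logits)
logits = torch.tensor([[0.1, 0.2, 3.0], [0.5, 2.0, 0.1]])
# One-hot labels, as guessed in the answer
one_hot = torch.tensor([[0, 0, 1], [0, 1, 0]])

# argmax along the class dimension yields int64 class indices
targets = one_hot.argmax(dim=1)

loss_fn = torch.nn.CrossEntropyLoss()
loss = loss_fn(logits, targets)  # scalar tensor
print(targets, loss.item())
```

Note that the logits keep their full (batch, classes) shape; only the labels are reduced to one index per sample.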

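Beyond that, check the shape of your model output against the shape of your labels. With class-index labels, CrossEntropyLoss expects the output to have shape (batch, num_classes) and the labels to have shape (batch,).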
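A quick way to run that shape check is sketched below. The random logits stand in for the ResNet output and the label values are invented for the example. The last lines mimic the `self(img)[:,0]` slice from the question: it collapses the (batch, 3) output to a 1-D tensor, which is a plausible cause of the "Dimension out of range" error, although that is an inference from the posted snippet rather than something the asker confirmed.

```python
import torch

torch.manual_seed(0)
batch_size, num_classes = 4, 3

# Stand-in for the model output: raw logits, shape (batch, classes)
logits = torch.randn(batch_size, num_classes)
# Class-index labels: integer dtype, shape (batch,)
labels = torch.tensor([0, 2, 1, 2])

loss_fn = torch.nn.CrossEntropyLoss()
loss = loss_fn(logits, labels)  # works: (4, 3) logits vs (4,) labels

# Mimics `self(img)[:, 0]` from the question: the class dimension is gone
sliced = logits[:, 0]

print("logits:", tuple(logits.shape))
print("labels:", tuple(labels.shape), labels.dtype)
print("sliced:", tuple(sliced.shape))
```

If the printed shapes show a 1-D prediction tensor, passing the full `self(img)` output to the loss and keeping the labels as integers (for example with `label.long()` instead of `label.float()`) would be the first thing to try.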
