PyTorch training loop fails with exception: 'int' object has no attribute 'size'

The code I am posting below is just a small part of the application:

import time

import torch
import torch.nn as nn

def train(self, training_reviews, training_labels):

    # make sure that we have a matching number of reviews and labels
    assert len(training_reviews) == len(training_labels)

    # Keep track of correct predictions to display accuracy during training
    correct_so_far = 0

    # Remember when we started for printing time statistics
    start = time.time()

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(self.parameters(), lr=self.learning_rate)

    # loop through all the given reviews and run a forward and backward pass,
    # updating weights for every item
    for i in range(len(training_reviews)):

        # Get the next review and its correct label
        review = training_reviews[i]
        label = training_labels[i]
        print('processing item ', i)
        self.update_input_layer(review)
        output = self.forward(torch.from_numpy(self.layer_0).float())
        target = self.get_target_for_label(label)  # a plain Python int here
        print('output ', output)
        print('target ', target)
        loss = criterion(output, target)

...
mlp = SentimentNetwork(reviews[:-1000], labels[:-1000], learning_rate=0.1)
mlp.train(reviews[:-1000], labels[:-1000])

It fails with the exception from the title when evaluating:

loss = criterion(output, target)

Just before that line, the two variables print as follows:

output  tensor([[0.5803]], grad_fn=<SigmoidBackward>)
target  1
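The failure can be reproduced in isolation from those two printed values. This is a minimal sketch; the exact exception depends on the PyTorch version. The 'size' message in the title comes from older releases, where the loss code called target.size() on the int, while newer releases reject the int earlier with a TypeError.

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
output = torch.tensor([[0.5803]])  # same value and shape as the printed output
target = 1                         # a plain Python int, like the printed target

try:
    criterion(output, target)
    error = None
except (AttributeError, TypeError) as exc:
    # Older PyTorch: AttributeError from calling .size() on the int.
    # Newer PyTorch: TypeError because target is not a Tensor.
    error = exc

print(type(error).__name__, error)
```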

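The target must be a torch.Tensor, not a plain Python int; use torch.tensor([target]). Note, however, that nn.CrossEntropyLoss expects raw scores with one column per class, shape (N, C), and class indices in [0, C). With the single sigmoid column printed above, a target of 1 is out of range, so for this output nn.BCELoss with a float target of the same shape, torch.tensor([[float(target)]]), is the matching loss.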
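A minimal sketch of a loss call whose target type and shape match the output, assuming integer labels 0 and 1. Option A keeps the (1, 1) sigmoid output and uses nn.BCELoss; option B keeps nn.CrossEntropyLoss, which needs one raw logit per class. The two logits in option B are made-up values standing in for a hypothetical two-unit output layer without a sigmoid.

```python
import torch
import torch.nn as nn

output = torch.tensor([[0.5803]], requires_grad=True)  # stand-in for the (1, 1) sigmoid output
label = 1                                             # int label, as in the question

# Option A: binary loss on the sigmoid output.
# BCELoss wants a float target with exactly the same shape as the output.
bce_target = torch.tensor([[float(label)]])
bce_loss = nn.BCELoss()(output, bce_target)

# Option B: CrossEntropyLoss on raw logits of shape (N, C) with C = 2,
# plus a long tensor of class indices of shape (N,).
logits = torch.tensor([[0.1, 0.7]], requires_grad=True)  # hypothetical logits
ce_target = torch.tensor([label])                        # dtype torch.int64
ce_loss = nn.CrossEntropyLoss()(logits, ce_target)

# Both losses are scalar tensors, so backward() works directly.
bce_loss.backward()
ce_loss.backward()
print(bce_loss.item(), ce_loss.item())
```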

Additionally, you may want to use batches, so that each step processes N samples: the output tensor then has shape (N,), and the target has the same shape (N,).
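A minimal batched sketch, assuming a hypothetical single-layer sigmoid model in place of SentimentNetwork and random inputs in place of the vectorized reviews. The .squeeze(1) turns the model's (N, 1) output into (N,) so that it matches the target built from a list of int labels.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, n_features = 8, 10  # hypothetical batch size and input width

model = nn.Sequential(nn.Linear(n_features, 1), nn.Sigmoid())
criterion = nn.BCELoss()

inputs = torch.rand(N, n_features)                   # one row per sample
labels = [1, 0, 1, 1, 0, 0, 1, 0]                    # plain ints, e.g. from a label list
targets = torch.tensor(labels, dtype=torch.float32)  # shape (N,)

output = model(inputs).squeeze(1)  # (N, 1) -> (N,) to match targets
loss = criterion(output, targets)  # mean binary cross-entropy over the batch

print(output.shape, targets.shape, loss.item())
```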

Also see an introductory PyTorch tutorial: you're not using batches, you're not running the optimizer, and you're not using torch.utils.data.Dataset and torch.utils.data.DataLoader, as you probably should be.
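A sketch of the usual loop built from TensorDataset and DataLoader, with the optimizer actually stepping. The data, model, batch size, learning rate and epoch count are all hypothetical placeholders, not values from the question.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical data: 64 already-vectorized reviews with binary labels.
features = torch.rand(64, 10)
labels = torch.randint(0, 2, (64,)).float()

loader = DataLoader(TensorDataset(features, labels), batch_size=16, shuffle=True)

model = nn.Sequential(nn.Linear(10, 1), nn.Sigmoid())
criterion = nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

losses = []
for epoch in range(3):
    for batch_x, batch_y in loader:
        optimizer.zero_grad()               # clear gradients from the previous step
        output = model(batch_x).squeeze(1)  # shape (batch,)
        loss = criterion(output, batch_y)   # target is already a float tensor
        loss.backward()                     # backpropagate
        optimizer.step()                    # update the weights
        losses.append(loss.item())

print(f"{len(losses)} steps, last loss {losses[-1]:.4f}")
```

With 64 samples and a batch size of 16 there are 4 batches per epoch, so the loop runs 12 optimizer steps in total.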
