
Getting an Error in Pytorch: IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

I seem to be having a problem with my code. The error occurs at:

x, predicted = torch.max(net(value).data.squeeze(), 1)

I'm not sure what the issue is, and I've tried everything to fix it. From my understanding, there seems to be a problem with the tensor dimensions. I'm not sure what else to do. Can anyone give me any suggestions or solutions for how to fix this problem? Thank you in advance.

import torch
import torch.nn as nn
import torch.optim as optim

class Network(nn.Module): # Class for the neural network
    def __init__(self):
        super(Network, self).__init__()
        self.layer1 = nn.Linear(6, 10) # First number is the number of inputs, second number is the size of the hidden layer (can be any number).
        self.hidden = nn.Softmax() # Activation function
        self.layer2 = nn.Linear(10, 1) # First number is the hidden layer size (same as the first layer), second number is the number of outputs.
        self.layer3 = nn.Sigmoid()

    def forward(self, x): # Feed-forward part of the neural network. We feed the input through every layer of our network.
        y = self.layer1(x)
        y = self.hidden(y)
        y = self.layer2(y)
        y = self.layer3(y)
        return y # Returns the result

net = Network()
loss_function = nn.BCELoss()
optimizer = optim.SGD(net.parameters(), lr=0.01)

for x in range(1): # Number of epochs over the dataset
    for index, value in enumerate(new_train_loader): # This loop goes over every image in the dataset
        print(value)
        #actual = value[0]
        actual_value = value[5]
        #print(value.size())
        #print(net(value).size())
        print("Actual", actual_value)
        net(value)
        loss = loss_function(net(value), actual_value.unsqueeze(0)) # Updating our loss for every image
        # Backpropagation
        optimizer.zero_grad() # Sets gradients to zero.
        loss.backward() # Computes gradients.
        optimizer.step() # Updates the parameters.
        print("Loop #: ", str(x+1), "Index #: ", str(index+1), "Loss: ", loss.item())


right = 0
total = 0
for value in new_test_loader:
    actual_value = value[5]
    #print(torch.max(net(value).data, 1))
    print(net(value).shape)
    x, predicted = torch.max(net(value).data.squeeze(), 1)
    total += actual_value.size(0)
    right += (predicted == actual_value).sum().item()
print("Accuracy: " + str(100 * right / total))

I should also mention that I'm using the latest versions.

You are calling .squeeze() on the model's output, which removes all singleton dimensions (dimensions of size 1). Your model's output has size [batch_size, 1], so .squeeze() removes the second dimension entirely, resulting in size [batch_size]. After that, you try to take the maximum value across dimension 1, but the only dimension left is dimension 0.
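A minimal sketch that reproduces the error (the batch size of 4 is an arbitrary example):

import torch

output = torch.rand(4, 1)    # stand-in for the model's output: shape [batch_size, 1]
squeezed = output.squeeze()  # the size-1 dimension is removed -> shape [4]
print(squeezed.shape)        # torch.Size([4])
# torch.max(squeezed, 1) now raises:
# IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
# because the tensor is one-dimensional, so only dimension 0 (or -1) exists.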

You don't need to take the maximum value in this case, since you have only one output, and with the sigmoid at the end of your model you get values between [0, 1]. Since you are doing binary classification, that single output acts as two classes: it can be seen as the probability of class 1. Then you just need to use a threshold of 0.5, meaning a probability over 0.5 is class 1 and a probability under 0.5 is class 0. That is exactly what rounding does, so you can use torch.round.

output = net(value)
predicted = torch.round(output.squeeze())
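Applied to your evaluation loop, a sketch could look like this (keeping your value[5] indexing as-is, and assuming actual_value holds 0/1 labels):

right = 0
total = 0
with torch.no_grad():                              # no gradients needed during evaluation
    for value in new_test_loader:
        actual_value = value[5]
        output = net(value)                        # compute the output once and reuse it
        predicted = torch.round(output.squeeze())  # threshold the probabilities at 0.5
        total += actual_value.size(0)
        right += (predicted == actual_value).sum().item()
print("Accuracy: " + str(100 * right / total))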

On a side note, you are calling net(value) multiple times with the same value, which means its output is calculated multiple times as well, because the input has to go through the entire network each time. That is unnecessary, and you should just save the output in a variable. With this small network it isn't noticeable, but with larger networks it will waste a lot of time recalculating the same output.
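For example, in your training loop the forward pass can be computed once and reused:

output = net(value)                                      # single forward pass
loss = loss_function(output, actual_value.unsqueeze(0))  # reuse the stored output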
