[英]Multiclass classification, IndexError: Target 2 is out of bounds
我正在使用 PyTorch 神經網路處理一個與某些藥物活性相關的多類分類問題。我有三個活性類別(0、1 和 2),為了解決這個問題,我採用了一對一(one-vs-one)方法,因此建立了三個二元分類器:0 vs. 1、1 vs. 2 和 2 vs. 0。當我訓練第二個分類器(類別 1 vs. 類別 2)時,出現以下錯誤:
IndexError: Target 2 is out of bounds.
有沒有一種方法可以在不重新分配標簽的情況下解決它? 謝謝你們!
這是我的網路,是一個使用 PyTorch Geometric 建構的圖同構網路(GIN):
class GIN1(torch.nn.Module):
    """Graph Isomorphism Network that emits per-graph class log-probabilities.

    Five ``GINConv`` layers are applied in sequence. The output of every
    layer is pooled per graph with ``global_add_pool``, the five pooled
    embeddings are concatenated, and a two-layer classifier head ending in
    ``log_softmax`` produces the result.

    Args:
        h: Hidden width of every GINConv MLP.
        num_classes: Width of the output layer. Defaults to 2, the value
            that used to be hard-coded. Index-based losses such as
            ``NLLLoss`` / ``CrossEntropyLoss`` require every target to lie
            in ``[0, num_classes - 1]``; training on the raw labels
            ``{1, 2}`` therefore needs ``num_classes=3`` (or relabelling
            to ``{0, 1}``), otherwise "Target 2 is out of bounds" is raised.
        dropout_p: Dropout probability of the classifier head. When
            ``None`` (default) the module-level ``hp_gin1['p']`` is looked
            up on every forward call, exactly as before.
    """

    def __init__(self, h, num_classes=2, dropout_p=None):
        super().__init__()
        dim_h_conv = h
        # The readout concatenates the pooled output of all five conv layers.
        dim_h_fc = dim_h_conv * 5
        self.dropout_p = dropout_p

        # Convolutional layers. Attribute names are kept as conv1..conv5 so
        # existing state_dict checkpoints keep loading.
        # 14 is the input width expected for the last dimension of ``x``.
        self.conv1 = self._make_conv(14, dim_h_conv)
        self.conv2 = self._make_conv(dim_h_conv, dim_h_conv)
        self.conv3 = self._make_conv(dim_h_conv, dim_h_conv)
        self.conv4 = self._make_conv(dim_h_conv, dim_h_conv)
        self.conv5 = self._make_conv(dim_h_conv, dim_h_conv)

        # Fully connected layers
        self.lin1 = Linear(dim_h_fc, dim_h_fc)
        self.lin2 = Linear(dim_h_fc, num_classes)

        self.initialize_w()

    @staticmethod
    def _make_conv(in_dim, out_dim):
        """Build one GINConv wrapping a Linear-BatchNorm-ReLU-Linear-ReLU MLP."""
        return GINConv(Sequential(
            Linear(in_dim, out_dim), BatchNorm1d(out_dim), ReLU(),
            Linear(out_dim, out_dim), ReLU()))

    def forward(self, x, edge_index, batch):
        """Return log-probabilities with one row per graph in ``batch``.

        Args:
            x: Input features passed to the first conv layer.
            edge_index: Connectivity passed unchanged to every conv layer.
            batch: Assignment vector handed to ``global_add_pool``.
        """
        hidden = x
        pooled = []
        for conv in (self.conv1, self.conv2, self.conv3,
                     self.conv4, self.conv5):
            hidden = conv(hidden, edge_index)
            # Graph level readout of this layer's embedding.
            pooled.append(global_add_pool(hidden, batch))

        # Concatenate graph embeddings
        h = torch.cat(pooled, dim=1)

        # Classifier
        h = self.lin1(h).relu()
        # Resolve the fallback lazily so later edits to hp_gin1 are honoured.
        p = hp_gin1['p'] if self.dropout_p is None else self.dropout_p
        h = F.dropout(h, p=p, training=self.training)
        h = self.lin2(h)
        return F.log_softmax(h, dim=1)

    def initialize_w(self):
        """Re-initialise every Linear (Kaiming normal) and BatchNorm1d layer."""
        for m in self.modules():
            if isinstance(m, Linear):
                torch.nn.init.kaiming_normal_(
                    m.weight, mode='fan_in', nonlinearity='relu')
                torch.nn.init.constant_(m.bias, 0)
            elif isinstance(m, BatchNorm1d):
                torch.nn.init.constant_(m.weight, 1)
                torch.nn.init.constant_(m.bias, 0)
這是我的訓練循環:
# Model, optimizer and loss used by train(), test_loss() and test().
# NOTE(review): GIN2 is not defined in the visible code; presumably it is
# analogous to GIN1 — confirm.
gin2 = GIN2(h=hp_gin2['h']) #40
optimizer = torch.optim.Adam(gin2.parameters(), lr=hp_gin2['lr'])
# CrossEntropyLoss takes class-index targets in [0, C-1], where C is the
# width of the model output; a label equal to C raises
# "IndexError: Target ... is out of bounds".
criterion = torch.nn.CrossEntropyLoss()
def train(train_loader):
    """Run one optimisation pass over ``train_loader``.

    Each batch contributes ``criterion(output, y)`` plus an explicit L2
    penalty (``hp_gin2['lambda']`` times the summed squares of all model
    parameters). Returns the penalised loss averaged over
    ``len(train_loader.dataset)``, weighting each batch by its
    ``num_graphs``.
    """
    gin2.train()
    running_loss = 0
    for batch in train_loader:
        prediction = gin2(batch.x, batch.edge_index, batch.batch)
        data_term = criterion(prediction, batch.y)

        # Manual L2 regularisation over every parameter of the model.
        penalty_weight = hp_gin2['lambda']
        squared_norm = 0
        for param in gin2.parameters():
            squared_norm = squared_norm + param.pow(2.0).sum()
        objective = data_term + penalty_weight * squared_norm

        optimizer.zero_grad()
        objective.backward()
        optimizer.step()

        running_loss += objective.item() * batch.num_graphs
    return running_loss / len(train_loader.dataset)
def test_loss(loader):
    """Return the mean ``criterion`` loss of ``gin2`` over ``loader``.

    The model is switched to evaluation mode first. Without that, a call
    made right after ``train()`` would run with dropout active and would
    update BatchNorm running statistics from the evaluation data, even
    under ``torch.no_grad()``. ``train()`` re-enables training mode itself,
    so leaving the model in eval mode here is safe.

    Unlike ``train()``, no L2 penalty is added. Each batch is weighted by
    its ``num_graphs`` and the total is divided by ``len(loader.dataset)``.
    """
    gin2.eval()
    total_loss_val = 0
    with torch.no_grad():
        for data in loader:
            output = gin2(data.x, data.edge_index, data.batch)
            batch_loss = criterion(output, data.y)
            total_loss_val += batch_loss.item() * data.num_graphs
    return total_loss_val / len(loader.dataset)
def test(loader):
    """Return the macro-averaged accuracy of ``gin2`` over all of ``loader``.

    A single metric object is updated with every batch and evaluated once
    at the end. The previous version built a new metric per batch and
    overwrote the result, so only one batch was reflected in the return
    value, and an empty loader raised ``UnboundLocalError``.
    """
    gin2.eval()
    # One metric instance for the whole pass so its state accumulates.
    accuracy = Accuracy(average='macro', num_classes=2)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for data in loader:
            output = gin2(data.x, data.edge_index, data.batch)
            accuracy.update(output, data.y)
    return accuracy.compute()
提問者(OP)需要讓模型的輸出維度與標籤類別的數量相匹配(參見討論):`CrossEntropyLoss` 要求每個標籤值都落在 [0, C-1] 範圍內(C 為輸出維度),因此當輸出維度為 2 時,標籤 2 就會超出範圍。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.