[英]How to train features in different scales in deep learning model
我是深度学习的新手,我构建了一个非常简单的 model 来尝试训练我的数据。 我有两个特征输入: sex
和age
。 sex
为0
或1
, age
在25
到60
之间。 Output 只是0
表示此人没有这种疾病, 1
表示有这种疾病。
然而,当我训练我的 model 时,训练损失并没有减少。 看起来是因为我的两个功能在范围上非常不同。 那么我该如何解决这个问题呢? 任何建议将不胜感激。
我的代码在这里:
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.fc1 = nn.Sequential(
nn.Linear(2,50),
nn.ReLU(),
nn.Linear(50,2)
)
def forward(self,x):
x = self.fc1(x)
x = F.softmax(x, dim=1)
return x
#Inputs
X = np.column_stack((sex,age))
X = torch.from_numpy(X).type(torch.FloatTensor)
y = torch.from_numpy(y).type(torch.LongTensor)
#Initialize the model
model = Net()
#Define loss criterion
criterion = nn.CrossEntropyLoss()
#Define the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
epochs = 1000
losses = []
for i in range(epochs):
y_pred = model.forward(X)
#Compute Cross entropy loss
loss = criterion(y_pred,y)
#Add loss to the list
losses.append(loss.item())
#Clear the previous gradients
optimizer.zero_grad()
#Compute gradients
loss.backward()
#Adjust weights
optimizer.step()
_, predicted = torch.max(y_pred.data, 1)
if i % 50 == 0:
print(loss.item())
火车损失看起来像这样
0.9273738861083984
0.6992899179458618
0.6992899179458618
0.6992899179458618
0.6992899179458618
0.6992899179458618
0.6992899179458618
0.6992899179458618
0.6992899179458618
0.6992899179458618
0.6992899179458618
0.6992899179458618
0.6992899179458618
0.6992899179458618
编辑
谢谢您的意见。 抱歉,我没有清楚地解释我的问题。 这是我的网络的一部分,我的输入数据包含两部分:第一部分是一些信号数据,我使用 CNN model 对其进行训练,效果很好; 第二部分就是我上面提到的。 我的目标是合并两个模型以提高我的准确性。 我已经尝试过标准化,看起来它可以工作。 我想知道在预处理数据时是否总是需要进行规范化? 谢谢!
替代。
如果年龄在(25-60)
范围内具有离散值,那么一种可能的方法是学习这两个属性的嵌入,即sex
和age
。
例如,
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.sex_embed = nn.Embedding(2, 20)
self.age_embed = nn.Embedding(36, 50)
self.fc1 = nn.Sequential(
nn.Linear(70, 35),
nn.ReLU(),
nn.Linear(35, 2)
)
def forward(self, x):
# write the forward
在上面的示例中,我假设年龄值为 integer (25, 26, ..., 60)
,因此对于每个可能的值,我们可以学习向量表示。
所以,我建议学习20d
的性别表示和50d
的年龄表示。 您可以更改尺寸并进行实验以找到最佳值。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.