Using a list when creating a PyTorch NN module
This code runs fine and creates a simple feed-forward neural network. Each layer (torch.nn.Linear) is assigned to an attribute on self.
class MultipleRegression3L(torch.nn.Module):
    def __init__(self, num_features):
        super(MultipleRegression3L, self).__init__()
        self.layer_1 = torch.nn.Linear(num_features, 16)
        ## more layers
        self.relu = torch.nn.ReLU()

    def forward(self, inputs):
        x = self.relu(self.layer_1(inputs))
        x = self.relu(self.layer_2(x))
        x = self.relu(self.layer_3(x))
        x = self.layer_out(x)
        return (x)

    def predict(self, test_inputs):
        return self.forward(test_inputs)
However, when I tried to store the layers in a list:
class MultipleRegression(torch.nn.Module):
    def __init__(self, num_features, params):
        super(MultipleRegression, self).__init__()
        number_of_layers = 3 if not 'number_of_layers' in params else params['number_of_layers']
        number_of_neurons_in_each_layer = [16, 32, 16] if not 'number_of_neurons_in_each_layer' in params else params['number_of_neurons_in_each_layer']
        activation_function = "relu" if not 'activation_function' in params else params['activation_function']
        self.layers = []
        v1 = num_features
        for i in range(0, number_of_layers):
            v2 = number_of_neurons_in_each_layer[i]
            self.layers.append(torch.nn.Linear(v1, v2))
            v1 = v2
        self.layer_out = torch.nn.Linear(v2, 1)
        if activation_function == "relu":
            self.act_func = torch.nn.ReLU()
        else:
            raise Exception("Activation function %s is not supported" % (activation_function))

    def forward(self, inputs):
        x = self.act_func(self.layers[0](inputs))
        for i in range(1, len(self.layers)):
            x = self.act_func(self.layers[i](x))
        x = self.layer_out(x)
        return (x)
The two models do not behave the same way. What can be wrong here?
PyTorch needs to keep track of the modules in the model, so using a plain Python list does not work. Layers stored in a list are not registered as submodules, so their parameters never appear in model.parameters() and the optimizer does not update them. Using self.layers = torch.nn.ModuleList() fixed the problem.
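The difference can be checked by counting the parameters the model reports. The following is a minimal sketch rather than the original poster's code: the class name `Net`, the `use_module_list` flag, and the helper `count_params` are hypothetical names introduced for illustration, and the layer sizes reuse the defaults from the question with an assumed 8 input features.

```python
import torch


class Net(torch.nn.Module):
    """Hypothetical model that stores hidden layers in a list or a ModuleList."""

    def __init__(self, num_features, sizes=(16, 32, 16), use_module_list=True):
        super().__init__()
        # A ModuleList registers each layer as a submodule; a plain list does not.
        self.layers = torch.nn.ModuleList() if use_module_list else []
        in_dim = num_features
        for out_dim in sizes:
            self.layers.append(torch.nn.Linear(in_dim, out_dim))
            in_dim = out_dim
        self.layer_out = torch.nn.Linear(in_dim, 1)
        self.act_func = torch.nn.ReLU()

    def forward(self, inputs):
        x = inputs
        for layer in self.layers:
            x = self.act_func(layer(x))
        return self.layer_out(x)


def count_params(model):
    """Total number of scalar parameters visible through model.parameters()."""
    return sum(p.numel() for p in model.parameters())


list_net = Net(8, use_module_list=False)
module_list_net = Net(8, use_module_list=True)

# With a plain list only layer_out is registered: 16 weights + 1 bias.
print(count_params(list_net))         # 17
# With ModuleList all four Linear layers are registered.
print(count_params(module_list_net))  # 1233
```

Both variants run a forward pass without error, which is why the problem is easy to miss. With the plain list, though, an optimizer built from `model.parameters()` only ever sees `layer_out`, so the hidden layers stay at their random initial values during training.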