How can I have submodules of a PyTorch Module that are not attributes of the module?

I would like to have a PyTorch sub-class of Module that keeps sub-modules in a list (because there may be a variable number of sub-modules depending on the constructor's arguments). I set this list in the following way:

self.hidden_layers = [torch.nn.Linear(i, o) for i, o in pairwise(self.layer_sizes)]

According to two other questions on this topic, a submodule is only registered by __setattr__, when a Module object is assigned to an attribute of self. Because hidden_layers is assigned a plain Python list rather than an object of type Module, the submodules in the list are not registered, and as a result self.parameters() does not iterate over their parameters.

I suppose I could explicitly call __setattr__ for each element of the list, but that would be quite ugly. Is there a more correct way to register a submodule that is not a direct attribute of the Module?

Use nn.ModuleList.

self.hidden_layers = nn.ModuleList([torch.nn.Linear(i, o) for i, o in pairwise(self.layer_sizes)])
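To show the line above in context, here is a minimal sketch of a complete module. The class name MLP, the layer sizes, and the ReLU activation are assumptions made for illustration, and zip(sizes, sizes[1:]) stands in for the question's pairwise helper, whose origin is not shown.

```python
import torch
from torch import nn


class MLP(nn.Module):
    """Hypothetical example module; name and sizes are illustrative."""

    def __init__(self, layer_sizes):
        super().__init__()
        self.layer_sizes = list(layer_sizes)
        # zip(sizes, sizes[1:]) stands in for the question's pairwise helper.
        pairs = zip(self.layer_sizes, self.layer_sizes[1:])
        # nn.ModuleList registers every contained layer as a submodule.
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(i, o) for i, o in pairs]
        )

    def forward(self, x):
        # ModuleList defines no forward of its own, so loop explicitly.
        for layer in self.hidden_layers:
            x = torch.relu(layer(x))  # ReLU is an assumed activation
        return x


model = MLP([4, 8, 2])
# Each Linear layer contributes one weight and one bias tensor.
num_param_tensors = len(list(model.parameters()))
output = model(torch.randn(3, 4))
```

With a plain Python list in place of nn.ModuleList, model.parameters() would yield nothing here, which is the problem described in the question.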

As already answered, nn.ModuleList is what you want.

You can also use nn.Sequential. Create a list of layers and pass them to nn.Sequential, which acts as a wrapper and combines all the layers into essentially one module. The advantage is that a single call forwards the input through all the layers in order, which is convenient when the number of modules is dynamic, because you don't have to write the loop yourself.

One example is the torchvision ResNet code: https://github.com/pytorch/vision/blob/497744b9d510ff2df756f479ee5a19fce0d579b6/torchvision/models/resnet.py#L177
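A minimal sketch of the nn.Sequential approach, assuming illustrative layer sizes and a ReLU after each linear layer (neither comes from the answer itself):

```python
import torch
from torch import nn

layer_sizes = [4, 8, 2]  # illustrative sizes, not from the original post

layers = []
for i, o in zip(layer_sizes, layer_sizes[1:]):
    layers.append(nn.Linear(i, o))
    layers.append(nn.ReLU())  # assumed activation

# Unpack the list; nn.Sequential registers each layer as a submodule
# and applies them in order when called.
network = nn.Sequential(*layers)

output = network(torch.randn(3, 4))  # a single call runs every layer
```

Unlike nn.ModuleList, nn.Sequential fixes the order of execution, so it fits best when each layer simply feeds the next.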
