繁体   English   中英

pytorch nn.module如何保存子模块

[英]how pytorch nn.module save submodule

我对pytorch nn.module的工作方式有疑问

import torch
import torch.nn as nn



class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.sub_module = nn.Linear(10, 5)
        self.value = 3

net = Net()
print(net.__dict__)

产量

{'_modules': OrderedDict([('sub_module', Linear (10 -> 5))]),  'value': 3, ...}

我知道一个类的每个属性都应存储在__dict__中 ,为什么要在其中包含value(一个int值),而不要包含sub_module(一个nn.Module),而是将sub_module存储在_modules中

我读了nn.Module实现的代码,但我没有弄清楚。 有人有什么想法吗?

谢谢 !!

我会尽量保持简单。

例如,每次在Net类中创建一个新项目: self.sub_module = nn.Linear(10, 5)它都会调用其父类的方法__setattr__ ,在本例中为nn.Module 然后,在__setattr__方法内部,将参数存储到它们所属的字典中。 在这种情况下,由于nn.Linear是模块,因此将其存储到_modules字典。

这是在Modulehttps://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L389内执行此操作的代码

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM