簡體   English   中英

PyTorch:nn.Sequential() 中特定模塊的訪問權重

[英]PyTorch: access weights of a specific module in nn.Sequential()

當我在 PyTorch 中使用預定義模塊時,我通常可以很容易地訪問它的權重。 但是,如果我nn.Sequential()模塊包裝在nn.Sequential()我該如何訪問它們? rg:

class My_Model_1(nn.Module):
    def __init__(self,D_in,D_out):
        super(My_Model_1, self).__init__()
        self.layer = nn.Linear(D_in,D_out)
    def forward(self,x):
        out = self.layer(x)
        return out

class My_Model_2(nn.Module):
    def __init__(self,D_in,D_out):
        super(My_Model_2, self).__init__()
        self.layer = nn.Sequential(nn.Linear(D_in,D_out))
    def forward(self,x):
        out = self.layer(x)
        return out

model_1 = My_Model_1(10,10)
print(model_1.layer.weight)
model_2 = My_Model_2(10,10)

我現在如何打印重量? model_2.layer.0.weight不起作用。

訪問權重的一種簡單方法是使用模型的state_dict()

這應該適用於您的情況:

for k, v in model_2.state_dict().iteritems():
    print("Layer {}".format(k))
    print(v)

另一種選擇是獲取modules()迭代器。 如果您事先知道圖層的類型,這也應該有效:

for layer in model_2.modules():
   if isinstance(layer, nn.Linear):
        print(layer.weight)

PyTorch 論壇,這是推薦的方式:

model_2.layer[0].weight

您可以使用_modules按名稱訪問模塊:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.conv1 = nn.Conv2d(3, 3, 3)

    def forward(self, input):
        return self.conv1(input)

model = Net()
print(model._modules['conv1'])

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM