繁体   English   中英

加载和冻结预训练模型以与新网络结合

[英]Loading & Freezing a Pretrained Model to Combine with a New Network

我有一个预训练模型,并想在它之上构建一个分类器。 我正在尝试加载和冻结预训练模型的权重,并将其输出传递给我想要优化的新分类器。 这是我到目前为止所拥有的,我有nn.SequentialTypeError: forward() missing 1 required positional argument: 'x' error from the nn.Sequential行:

import model #model.py contains the architecture of the pretrained model

class Classifier(nn.Module):
    def __init__(self):
        ...
    def forward(self, x):
        ...

net = model.Model()
net.load_state_dict(checkpoint["net"])

for c in net.children():
    for param in child.parameters():
        params.requires_grad = False

model = nn.Sequential(nn.ModuleList(net()), Classifier())

TL; 博士

model = nn.Sequential(nn.ModuleList(net), Classifier())

您通过net() “调用” net.forward ,而不是Classifier()Classifier__init__方法。

在与 PyTorch 论坛的 @ptrblck 讨论后,我终于解决了这个问题。 该解决方案类似于夏嘉曦的答案,只是因为net包含的实例model.Model类,一个应该做的model = nn.Sequential(net, Classifier())来代替,而不调用nn.ModuleList()

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM