I want to replace the linear layer of the 3D ResNet, which can be downloaded from PyTorch Hub.
I can get the name of the linear layer by using the following code:
for name, layer in model.named_modules():
    if isinstance(layer, torch.nn.Linear):
        print(name, layer)
blocks.5.proj Linear(in_features=2048, out_features=400, bias=True)
I cannot simply use model.blocks.5.proj = nn.Linear(2048, 10), because the .5. part is not valid attribute access in Python and throws a syntax error. Instead, I tried to iterate over the modules and replace the linear layer:
for name, layer in model.named_modules():
    if isinstance(layer, torch.nn.Linear):
        model._modules[name] = torch.nn.Linear(2048, 10)
For some reason, this also doesn't work. Instead, it simply creates an additional linear layer with the same name:
blocks.5.proj Linear(in_features=2048, out_features=400, bias=True)
blocks.5.proj Linear(in_features=2048, out_features=10, bias=True)
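The duplicate entry can be reproduced with a small sketch. It uses a hypothetical toy model (the class name Head and the 6-block layout are illustrative, not the hub model) whose blocks container is an nn.ModuleList, so the linear layer gets the same dotted name blocks.5.proj. The point it shows: _modules only holds a module's direct children, so assigning under the key "blocks.5.proj" does not descend into blocks[5]; it registers a new top-level child whose name happens to contain dots. (Doing this inside a named_modules() loop, as in the attempt above, can additionally raise RuntimeError: dictionary changed size during iteration, because the generator is still iterating over that same dict.)

```python
import torch
from torch import nn


# Hypothetical stand-in for the hub model's classification head.
class Head(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(2048, 400)


model = nn.Module()
# ModuleList children are registered under the keys "0", "1", ..., "5".
model.blocks = nn.ModuleList([nn.Identity() for _ in range(5)] + [Head()])

# _modules only contains *direct* children, keyed by their own names.
print(list(model._modules.keys()))  # ['blocks']

# The dotted key is treated as a plain name: a new top-level child is added
# and the original blocks[5].proj is left untouched.
model._modules["blocks.5.proj"] = nn.Linear(2048, 10)
print(list(model._modules.keys()))  # ['blocks', 'blocks.5.proj']

for name, layer in model.named_modules():
    if isinstance(layer, nn.Linear):
        print(name, layer)
# blocks.5.proj Linear(in_features=2048, out_features=400, bias=True)
# blocks.5.proj Linear(in_features=2048, out_features=10, bias=True)
```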
Can someone help me out here?
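The integer in the printed name (blocks.5.proj) indicates that blocks is an indexable container module such as nn.Sequential (or nn.ModuleList), whose children are registered under the names "0", "1", and so on. You can access a specific layer in such a container with regular array indexing.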
Try something like:
model.blocks[5].proj = torch.nn.Linear(2048, 10)
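If you want to replace layers by the dotted names that named_modules() reports, without hard-coding the index, one possible sketch is below. It assumes PyTorch 1.9 or newer for nn.Module.get_submodule; the helper replace_submodule and the classes Head and ToyNet are hypothetical names used only for illustration. The idea is to split the name into the parent path and the final attribute, fetch the parent, and setattr the new layer on it, which replaces the existing child under the same key. The names are collected into a list first so the module tree is not modified while named_modules() is still iterating over it.

```python
import torch
from torch import nn


def replace_submodule(model: nn.Module, dotted_name: str, new_module: nn.Module) -> None:
    """Hypothetical helper: replace the submodule at e.g. "blocks.5.proj" in place."""
    # "blocks.5.proj" -> parent path "blocks.5", attribute "proj".
    parent_name, _, child_name = dotted_name.rpartition(".")
    # get_submodule("") returns the model itself (PyTorch >= 1.9).
    parent = model.get_submodule(parent_name)
    # setattr registers new_module under the same key, replacing the old child.
    setattr(parent, child_name, new_module)


# Hypothetical stand-in for the hub model, mirroring the name "blocks.5.proj".
class Head(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(2048, 400)

    def forward(self, x):
        return self.proj(x)


class ToyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Identity() for _ in range(5)] + [Head()])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


model = ToyNet()
# Collect the names first; do not mutate the tree inside the named_modules() loop.
linear_names = [name for name, m in model.named_modules() if isinstance(m, nn.Linear)]
for name in linear_names:
    replace_submodule(model, name, nn.Linear(2048, 10))

for name, m in model.named_modules():
    if isinstance(m, nn.Linear):
        print(name, m)
# blocks.5.proj Linear(in_features=2048, out_features=10, bias=True)
```

For the hub model itself, the equivalent call under the same assumptions would be replace_submodule(model, "blocks.5.proj", torch.nn.Linear(2048, 10)).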