简体   繁体   English

如何根据output张量去掉pytorch model的一个预测头?

[英]How to remove a prediction head from pytorch model based on the output tensor?

I am working on a ViT (Vision Transformer) related project and some low level definition is deep inside timm library, which I can not change.我正在从事与 ViT(Vision Transformer)相关的项目,一些低级定义在 timm 库的深处,我无法更改。 The low level library definition involves a linear classification prediction head, which is not a part of my.network.低级库定义涉及线性分类预测头,它不是 my.network 的一部分。

Every thing was fine until I switched to DDP parallel implementation.在我切换到 DDP 并行实现之前,一切都很好。 Pytorch complained about some parameters which didn't contribute to the loss, and it instructed me to use “find_unused_parameters=True”. Pytorch 抱怨一些参数对损失没有贡献,它指示我使用“find_unused_parameters=True”。 In fact, it is a common scenario and it worked again if I added this “find_unused_parameters=True” to the training routine.事实上,这是一个常见的场景,如果我将这个“find_unused_parameters=True”添加到训练例程中,它会再次起作用。 However, I am only allowed to change the model definition in our code base, but I cannot modify anything related to training …但是,我只能更改我们代码库中的 model 定义,但我不能修改任何与训练相关的内容……

So I guess the only thing I can do right now, is to “remove” the linear head from the model. Although I cannot dig into the low level definition of ViT, but I can output this tensor like this:所以我想我现在唯一能做的就是从 model 中“移除”线性头。虽然我无法深入研究 ViT 的低级定义,但我可以像这样 output 这个张量:

encoder_output,   linear_head_output =  ViT(input)

Is it possible to remove this linear prediction head based on this linear_head_output tensor?是否可以根据这个 linear_head_output 张量移除这个线性预测头?

Just set the num_classes=0 when you create your ViT model by calling timm.create_model() .只需在调用timm.create_model()创建 ViT model 时设置num_classes=0即可。

Here is an example from TIMM documentation on Feature Extraction :以下是TIMM 文档中有关特征提取的示例:

import torch
import timm
m = timm.create_model('resnet50', pretrained=True, num_classes=0, global_pool='')
o = m(torch.randn(2, 3, 224, 224))
print(f'Unpooled shape: {o.shape}')

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM