PyTorch: Converting a VGG model into a sequential model, but getting different outputs
Background: I'm working on an adversarial detection method that requires access to the outputs of each hidden layer. I loaded a pretrained VGG16 from torchvision.models.
To access the output of each hidden layer, I rebuilt it as a sequential model:
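import torch.nn as nn
import torchvision.models as models

vgg16 = models.vgg16(pretrained=True)
# children() yields features, avgpool, classifier in that order
vgg16_seq = nn.Sequential(*(
    list(list(vgg16.children())[0]) +                # convolutional feature layers
    [nn.AdaptiveAvgPool2d((7, 7)), nn.Flatten()] +   # pool, then flatten to (N, 25088)
    list(list(vgg16.children())[2])))                # fully connected classifier layers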
Without nn.Flatten(), the forward pass fails with an error saying the shapes of mat1 and mat2 don't match.
I looked into the torchvision VGG implementation: its forward pass follows the structure [features..., AvgPool, flatten, classifier...]. Since the AdaptiveAvgPool2d and Flatten layers have no parameters, I assumed this would work, but I get different outputs:
output1 = vgg16(X_small)
print(output1.size())
output2 = vgg16_seq(X_small)
print(output2.size())
torch.equal(output1, output2)
Problem: the outputs have the same shape but different values.
torch.Size([32, 1000])
torch.Size([32, 1000])
False
I tested the outputs right after the AdaptiveAvgPool2d layer, and they are equal:
output1 = nn.Sequential(*list(vgg16.children())[:2])(X_small)
print(output1.size())
output2 = nn.Sequential(*list(vgg16_seq)[:32])(X_small)
print(output2.size())
torch.equal(output1, output2)
torch.Size([32, 512, 7, 7])
torch.Size([32, 512, 7, 7])
True
Can someone point out what went wrong? Thank you.
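You need to switch the models to eval mode before running inference. A newly constructed module is in training mode, and the VGG16 classifier contains Dropout layers, so each forward pass applies a different random Dropout mask and the two outputs differ. The convolutional features contain no Dropout (or BatchNorm), which is why the outputs after AdaptiveAvgPool2d were identical. That is: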
vgg16.eval()
vgg16_seq.eval()