
model.eval() for a class with a network as a field - PyTorch

I have a model class with a pre-trained ResNet as a field, something like:

import torch.nn as nn

class A(nn.Module):
    def __init__(self, **kwargs):
        super(A, self).__init__()
        self.resnet = get_resnet()  # user-defined helper returning a pre-trained ResNet

    def forward(self, x):
        return self.resnet(x)

...

Now I'm doing:

model = A()
...
model.eval()

Is this OK, or should I override the eval() and train() methods?

Short answer

It's OK.

Long answer

nn.Module.train() works recursively. Simplified, it looks like this:

def train(self, mode=True):
    self.training = mode
    for module in self.children():
        module.train(mode)
    return self

and nn.Module.eval() simply calls self.train(False).
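To make the recursion concrete without needing PyTorch installed, here is a toy pure-Python sketch. ToyModule and FakeResNet are hypothetical names, and this is not PyTorch's actual code: it only mimics the two pieces that matter, namely that assigning a module as an attribute registers it as a child, and that train()/eval() walk those children.

```python
# Toy imitation (NOT PyTorch itself) of how nn.Module registers children
# and propagates train()/eval() to them recursively.
class ToyModule:
    def __init__(self):
        self._modules = {}   # registered child modules, keyed by attribute name
        self.training = True

    def __setattr__(self, name, value):
        # Assigning a ToyModule as an attribute registers it as a child,
        # similar to what nn.Module.__setattr__ does for nn.Module values.
        if isinstance(value, ToyModule):
            self._modules[name] = value
        object.__setattr__(self, name, value)

    def children(self):
        return iter(self._modules.values())

    def train(self, mode=True):
        # Same loop as the simplified nn.Module.train() above.
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self

    def eval(self):
        return self.train(False)


class FakeResNet(ToyModule):
    """Hypothetical stand-in for whatever get_resnet() returns."""
    def __init__(self):
        super().__init__()
        self.layer1 = ToyModule()
        self.layer2 = ToyModule()


class A(ToyModule):
    def __init__(self):
        super().__init__()
        self.resnet = FakeResNet()  # registered as a child of A


model = A()
model.eval()
print(model.training, model.resnet.training, model.resnet.layer1.training)
# prints: False False False
```

Because eval() reaches every registered descendant, there is nothing extra A has to do.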

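So as long as self.resnet is an instance of an nn.Module subclass, you don't need to worry about it. Assigning it as an attribute in __init__ registers it as a child module, so the loop above reaches it, and practically every nn.Module method other than forward (train, eval, to, cuda, and so on) affects all submodules.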
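Here is a small PyTorch sketch of that condition, assuming torch is installed. The get_resnet() below is a hypothetical stand-in (a tiny conv/batch-norm stack, not a real pre-trained ResNet). The class B shows the failure mode the "as long as" guards against: a module hidden inside a plain Python list is never registered, so eval() never reaches it (nn.ModuleList is the registered alternative).

```python
import torch.nn as nn

def get_resnet():
    # Hypothetical stand-in for the question's get_resnet().
    return nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.BatchNorm2d(8),
        nn.ReLU(),
    )

class A(nn.Module):
    def __init__(self):
        super().__init__()
        self.resnet = get_resnet()  # nn.Module attribute: registered as a child

    def forward(self, x):
        return self.resnet(x)

class B(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [get_resnet()]  # plain list: NOT registered

model = A()
print([name for name, _ in model.named_children()])  # ['resnet']
model.eval()
print(all(not m.training for m in model.modules()))   # True

b = B()
b.eval()
print(b.layers[0].training)  # True: eval() never reached it
```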

You can check this with:

model = A()
...
model.eval()
print(model.resnet.training)  # should be False

If it prints False, everything is fine. If it prints anything else, something is wrong with get_resnet().
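Checking only model.resnet.training tests one level of the tree. To check every submodule at once, you could walk named_modules(). modules_still_training is a hypothetical helper name, and the nn.Sequential below is just a stand-in for A() so the snippet runs on its own (assuming torch is installed).

```python
from typing import List

import torch.nn as nn

def modules_still_training(model: nn.Module) -> List[str]:
    """Hypothetical helper: names of all submodules whose training flag is True."""
    # named_modules() yields the root itself under the empty name "".
    return [name or "<root>" for name, m in model.named_modules() if m.training]

model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5))  # stand-in for A()
model.eval()
print(modules_still_training(model))  # []

model[1].train()  # re-enable training mode on the Dropout layer only
print(modules_still_training(model))  # ['1']
```

An empty list means eval() reached the whole module tree.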

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.
