I have a model class with a pre-trained ResNet as one of its fields, something like:
import torch.nn as nn

class A(nn.Module):
    def __init__(self, **kwargs):
        super(A, self).__init__()
        self.resnet = get_resnet()  # my helper that returns a pre-trained ResNet

    def forward(self, x):
        return self.resnet(x)
...
Now I'm doing:
model = A()
...
model.eval()
Is this OK, or should I override the eval() and train() methods?
It's OK, because nn.Module.train() works recursively. Its core is:
self.training = mode
for module in self.children():
    module.train(mode)
return self
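And nn.Module.eval() simply calls self.train(False).
So as long as self.resnet is an nn.Module (subclass) assigned as an attribute of your model, which registers it as a child, you don't need to worry about it. The same holds for most other nn.Module methods that act on module state, such as to(), cuda(), apply() and zero_grad(): they affect all submodules as well. forward is the exception, since it only runs the computation you write in it.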
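To see why the recursion reaches every nested layer, here is a minimal, self-contained imitation in plain Python (no PyTorch needed). TinyModule, FakeResNet and the helpers attribute are hypothetical names; the class copies only two behaviours of nn.Module, registering submodules when they are assigned as attributes and the recursive train()/eval() switch, and it is a sketch, not PyTorch's actual implementation.

```python
class TinyModule:
    """Hypothetical, heavily simplified stand-in for torch.nn.Module.

    It only imitates child registration on attribute assignment and the
    recursive train()/eval() switch.
    """

    def __init__(self):
        # object.__setattr__ keeps these bookkeeping fields out of _children.
        object.__setattr__(self, "_children", {})
        object.__setattr__(self, "training", True)

    def __setattr__(self, name, value):
        # Assigning a TinyModule to an attribute registers it as a child,
        # roughly what nn.Module does for submodules.
        if isinstance(value, TinyModule):
            self._children[name] = value
        object.__setattr__(self, name, value)

    def children(self):
        return iter(self._children.values())

    def train(self, mode=True):
        # Same shape as the loop quoted above: set own flag, then recurse.
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self

    def eval(self):
        # Like nn.Module.eval(): just train(False).
        return self.train(False)


class FakeResNet(TinyModule):
    def __init__(self):
        super().__init__()
        self.layer1 = TinyModule()
        self.layer2 = TinyModule()


class A(TinyModule):
    def __init__(self):
        super().__init__()
        self.resnet = FakeResNet()     # registered child
        self.helpers = [TinyModule()]  # plain list: NOT registered


model = A().eval()
print(model.resnet.training)         # False
print(model.resnet.layer2.training)  # False
print(model.helpers[0].training)     # True: eval() never reaches it
```

The last line shows the one real caveat: only registered children are visited. In real PyTorch the same applies to modules kept in a plain Python list, which is why nn.ModuleList exists.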
You can check this with:
model = A()
...
model.eval()
print(model.resnet.training) # should be False
If it prints False, everything is fine. If it prints anything else, something is wrong with get_resnet().
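Checking model.resnet.training only looks at the top-level field. A minimal sketch of a stricter check, assuming PyTorch is installed, walks the whole tree with nn.Module.modules(), which yields the module itself plus every nested submodule. Here get_resnet() is a hypothetical stand-in (a tiny nn.Sequential with BatchNorm and Dropout, the layers whose behaviour depends on train/eval mode), and all_in_eval_mode is a hypothetical helper name; in practice you would use your real pre-trained ResNet.

```python
import torch.nn as nn


def get_resnet():
    # Hypothetical stand-in for the question's get_resnet(): any nn.Module
    # with train/eval-sensitive layers (BatchNorm, Dropout) works here.
    return nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.BatchNorm2d(8),
        nn.ReLU(),
        nn.Dropout(p=0.5),
    )


class A(nn.Module):
    def __init__(self, **kwargs):
        super(A, self).__init__()
        self.resnet = get_resnet()

    def forward(self, x):
        return self.resnet(x)


def all_in_eval_mode(module: nn.Module) -> bool:
    # modules() yields the module itself and every nested submodule,
    # so this checks the whole tree, not just model.resnet.
    return all(not m.training for m in module.modules())


model = A()
model.eval()
print(model.resnet.training)    # False
print(all_in_eval_mode(model))  # True
```

If all_in_eval_mode returns False after model.eval(), some layer is either overriding train() itself or is not registered as a submodule.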