[英]model.eval for class with field of network - pytorch
我有一個 class model 具有預訓練的 resnet 字段,例如:
class A(nn.Module):
def __init__(self, **kwargs):
super(A, self).__init__()
self.resnet = get_resnet()
def forward(self, x):
return self.resnet(x)
...
現在我在做
model = A()
...
model.eval()
我可以覆蓋eval
、 train
函數嗎?
沒關系。
由於nn.Module.train()
像這樣遞歸運行。
self.training = mode
for module in self.children():
module.train(mode)
return self
而nn.Module.eval()
只是調用self.train(False)
所以只要self.resnet
是一個nn.Module
子類。 您不需要為此煩惱,實際上nn.Module
中的每個方法除了forward
都會影響所有子模塊。
您可以通過以下方式進行測試
model = A()
...
model.eval()
print(model.resnet.training) # should be False
如果你得到False
那么一切都很好。 如果你得到其他東西,那么get_resnet()
有問題。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.