簡體   English   中英

model.eval 用於 class 與網絡領域 - pytorch

[英]model.eval for class with field of network - pytorch

我有一個 class model 具有預訓練的 resnet 字段,例如:

class A(nn.Module):
    def __init__(self, **kwargs):
        super(A, self).__init__()
        self.resnet = get_resnet()

    def forward(self, x):
        return self.resnet(x)

...

現在我在做

model = A()
...
model.eval()

我可以覆蓋evaltrain函數嗎?

簡短的回答

沒關系。

長答案

由於nn.Module.train()像這樣遞歸運行。

self.training = mode
for module in self.children():
    module.train(mode)
return self

nn.Module.eval()只是調用self.train(False)

所以只要self.resnet是一個nn.Module子類。 您不需要為此煩惱,實際上nn.Module中的每個方法除了forward都會影響所有子模塊。

您可以通過以下方式進行測試

model = A()
...
model.eval()
print(model.resnet.training)  # should be False

如果你得到False那么一切都很好。 如果你得到其他東西,那么get_resnet()有問題。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM