![](/img/trans.png)
[英]What are the PyTorch's model.eval() + no_grad() equivalent in TensorFlow?
[英]What does model.eval() do in pytorch?
我正在使用此代碼,並且在某些情況下看到了model.eval()
。
我知道它應該允許我“評估我的模型”,但我不明白什么時候應該和不應該使用它,或者如果關閉如何關閉。
我想運行上面的代碼來訓練網絡,並且還能夠在每個時期運行驗證。 我還是做不到。
model.eval()
是模型的某些特定層/部分的一種開關,在訓練和推理(評估)期間表現不同。 例如,Dropouts Layers、BatchNorm Layers 等。您需要在模型評估期間關閉它們,而.eval()
會為您完成。 此外,評估/驗證的常見做法是使用torch.no_grad()
與model.eval()
配對使用以關閉梯度計算:
# evaluate model:
model.eval()
with torch.no_grad():
...
out_data = model(data)
...
但是,不要忘記在 eval 步驟后回到training
模式:
# training step
...
model.train()
...
model.train() |
model.eval() |
---|---|
在訓練模式下設置模型: • 標准化層1使用每批統計數據 • 激活 Dropout 層2 |
集模型中的eval uation(推斷)模式: • 規范化層使用運行統計 • 停用 Dropout 層相當於 model.train(False) 。 |
您可以通過運行model.train()
關閉評估模式。 您應該在將模型作為推理引擎運行時使用它 - 即在測試、驗證和預測時(盡管實際上如果您的模型不包含任何不同行為的層,它不會有任何區別)。
BatchNorm
, InstanceNorm
model.eval
是的方法torch.nn.Module
:
eval()
將模塊設置為評估模式。
這僅對某些模塊有任何影響。 如果它們受到影響,請參閱特定模塊的文檔以了解其在訓練/評估模式下的行為的詳細信息,例如
Dropout
、BatchNorm
等。這相當於
self.train(False)
。
相反方法model.train
通過曼·古普塔很好地說明。
對上述答案的額外補充:
我最近開始使用Pytorch-lightning ,它將大部分樣板包裝在訓練-驗證-測試管道中。
除此之外,它通過允許包裝eval
和train
的train_step
和validation_step
回調使model.eval()
和model.train()
幾乎是多余的,所以你永遠不會忘記。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.