I am using this code, and saw model.eval() in some cases.
I understand it is supposed to allow me to "evaluate my model", but I don't understand when I should and shouldn't use it, or how to turn it off.
I would like to run the above code to train the network, and also be able to run validation every epoch, but I haven't managed to do it yet.
model.eval() is a kind of switch for specific layers/parts of the model that behave differently during training and inference (evaluation), for example Dropout layers, BatchNorm layers, etc. During model evaluation these layers must be switched to their inference behavior (Dropout is disabled, BatchNorm uses its running statistics), and .eval() does this for you. In addition, the common practice for evaluation/validation is to use torch.no_grad() together with model.eval() to turn off gradient computation:
# evaluate model:
model.eval()
with torch.no_grad():
    ...
    out_data = model(data)
    ...
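To see why the two are used as a pair, here is a minimal sketch, assuming a small made-up nn.Sequential model and random input: eval() alone changes layer behavior but autograd still records the graph, and only torch.no_grad() stops that.

```python
import torch
import torch.nn as nn

# A tiny stand-in model (hypothetical, for illustration only).
model = nn.Sequential(nn.Linear(4, 8), nn.Dropout(p=0.5), nn.Linear(8, 2))
data = torch.randn(3, 4)

model.eval()
# eval() changes layer behavior (Dropout is now a no-op)
# but does NOT stop autograd from recording operations:
out_eval_only = model(data)
print(out_eval_only.requires_grad)  # True

with torch.no_grad():
    # no_grad() stops graph construction, saving memory and compute.
    out_no_grad = model(data)
print(out_no_grad.requires_grad)  # False

# Same values either way, because Dropout is already disabled by eval().
print(torch.allclose(out_eval_only, out_no_grad))  # True
```

In other words, model.eval() controls *what the layers compute*, while torch.no_grad() controls *whether gradients can be computed*; validation usually wants both.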
BUT, don't forget to switch back to training mode after the eval step:
# training step
...
model.train()
...
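Putting the two steps together answers the "validate every epoch" part of the question. The sketch below is an illustration under assumptions: the synthetic regression data, the model, and the names train_loader, val_loader and num_epochs are all hypothetical stand-ins for your own code. Calling model.train() at the start of every epoch is what guarantees you are back in training mode after each validation pass.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Synthetic regression data standing in for a real dataset (assumption).
X = torch.randn(256, 10)
y = X.sum(dim=1, keepdim=True) + 0.1 * torch.randn(256, 1)
train_loader = DataLoader(TensorDataset(X[:200], y[:200]), batch_size=32, shuffle=True)
val_loader = DataLoader(TensorDataset(X[200:], y[200:]), batch_size=32)

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Dropout(0.2), nn.Linear(32, 1))
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

num_epochs = 5
val_history = []
for epoch in range(num_epochs):
    # --- training step ---
    model.train()  # Dropout active (BatchNorm, if present, would use batch statistics)
    for xb, yb in train_loader:
        optimizer.zero_grad()
        loss = loss_fn(model(xb), yb)
        loss.backward()
        optimizer.step()

    # --- validation step, once per epoch ---
    model.eval()  # Dropout off (BatchNorm, if present, would use running statistics)
    total, count = 0.0, 0
    with torch.no_grad():  # no graph is built during validation
        for xb, yb in val_loader:
            # Weight each batch's mean loss by its size to get a per-sample average.
            total += loss_fn(model(xb), yb).item() * xb.size(0)
            count += xb.size(0)
    val_loss = total / count
    val_history.append(val_loss)
    print(f"epoch {epoch}: val_loss={val_loss:.4f}")
```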
| model.train() | model.eval() |
|---|---|
| Sets model in training mode: • normalisation layers¹ use per-batch statistics • activates Dropout layers | Sets model in evaluation (inference) mode: • normalisation layers use running statistics • de-activates Dropout layers • equivalent to model.train(False) |
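The two rows of the table can be checked directly on single layers. This is a minimal sketch using standalone nn.Dropout and nn.BatchNorm1d layers with made-up inputs; the printed values in the comments follow from PyTorch's defaults (Dropout scales survivors by 1/(1-p), BatchNorm momentum is 0.1).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.ones(1, 1000)

drop = nn.Dropout(p=0.5)
drop.train()
y_train = drop(x)  # about half the entries zeroed, survivors scaled by 1/(1-p) = 2
drop.eval()
y_eval = drop(x)   # identity: input passes through unchanged
zero_frac = (y_train == 0).float().mean().item()
print(zero_frac)               # roughly 0.5
print(torch.equal(y_eval, x))  # True

bn = nn.BatchNorm1d(3)
batch = torch.randn(16, 3) * 5 + 10
bn.train()
out_train = bn(batch)  # normalised with this batch's mean/var; running stats get updated
bn.eval()
out_eval = bn(batch)   # normalised with running_mean / running_var; no update
print(out_train.mean(dim=0))  # close to 0 for each feature
print(bn.running_mean)        # 0.1 * batch mean, since it started at 0 with momentum 0.1
```

Note that out_eval is far from zero-mean here: after a single training batch the running statistics are still close to their initial values, which is exactly why switching modes matters.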
You can turn off evaluation mode by running model.train(). You should use evaluation mode when running your model as an inference engine, i.e. when testing, validating, and predicting (though in practice it makes no difference if your model does not include any of the differently behaving layers).
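Both points can be seen with the training flag that every nn.Module carries. A minimal sketch, assuming a hypothetical model built only from Linear and ReLU layers, i.e. with no mode-dependent layers:

```python
import torch
import torch.nn as nn

# Model with no Dropout/BatchNorm: only Linear and ReLU layers.
plain = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
x = torch.randn(5, 4)

plain.eval()
print(plain.training)  # False: evaluation mode is on
out_eval = plain(x)

plain.train()          # this is how you "turn off" evaluation mode
print(plain.training)  # True
out_train = plain(x)

# Without mode-dependent layers, both modes give identical results.
print(torch.equal(out_eval, out_train))  # True
```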
¹ e.g. BatchNorm, InstanceNorm
model.eval is a method of torch.nn.Module:
eval()
Sets the module in evaluation mode.
This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. Dropout, BatchNorm, etc. This is equivalent with self.train(False).
The opposite method is model.train(), explained nicely by Umang Gupta.
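The equivalence quoted from the docs, and the fact that the mode is applied to every submodule, can be verified with a short sketch (the nested model below is made up for illustration):

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Dropout(0.1), nn.BatchNorm1d(4)))

returned = model.eval()
# eval() returns the module itself, so calls can be chained.
print(returned is model)  # True
# The flag is set recursively on every submodule, including nested ones.
print(all(not m.training for m in model.modules()))  # True

model.train(False)  # same effect as model.eval()
eval_flags = [m.training for m in model.modules()]
model.train()       # the opposite: same as model.train(True)
train_flags = [m.training for m in model.modules()]
print(eval_flags)   # all False
print(train_flags)  # all True
```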
An extra addition to the above answers:
I recently started working with PyTorch Lightning, which wraps much of the boilerplate in the training/validation/testing pipelines.
Among other things, it makes calling model.eval() and model.train() yourself nearly redundant: you implement the training_step and validation_step hooks, and the Trainer puts the model in eval mode (with gradients disabled) before validation and back in train mode afterwards, so you never forget to.
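Without Lightning, a similar "never forget" effect can be approximated in plain PyTorch with a small context manager. This is a minimal sketch, not part of PyTorch or Lightning; evaluating is a hypothetical helper name. It switches to eval mode, disables gradients, and restores the previous mode on exit, even if the block raises.

```python
import contextlib
import torch
import torch.nn as nn

@contextlib.contextmanager
def evaluating(model: nn.Module):
    """Hypothetical helper: run a block in eval mode with gradients disabled,
    then restore the mode the model was in before."""
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            yield model
    finally:
        model.train(was_training)  # restores the previous top-level mode

model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5))
model.train()
with evaluating(model):
    out = model(torch.randn(2, 4))
    inside = model.training
print(inside, model.training, out.requires_grad)  # False True False
```

One design caveat: model.train(was_training) applies a single flag to all submodules, so if you deliberately keep some layers in a different mode (for example frozen BatchNorm layers left in eval mode), you would need to save and restore each submodule's flag instead.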