
Training and testing a CNN with PyTorch, with and without model.eval()

I have two questions:

  1. I am trying to train a convolutional neural network initialized with some pre-trained weights (the network also contains batch normalization layers), taking reference from here. Before training, I want to calculate the validation error using loss_fn = torch.nn.MSELoss().cuda(). In the reference, the author calls model.eval() before calculating the validation error. But when I do that, the CNN output is off from what it should be, whereas when I comment out model.eval(), the output is good (what it should be with the pre-trained weights). What could be the reason for this? I have read in many posts that model.eval() should be used before testing a model and model.train() before training it.

  2. When calculating the validation error with the pre-trained weights and the loss function above, what should the batch size be? Shouldn't it be 1, since I want the output for each input, compute the error against the ground truth, and finally take the average of all results? If I use a larger batch size, the error increases. So can I use a larger batch size, and if so, what is the right way to do it? In the code I use err = float(loss_local) / num_samples, but I also tried it without averaging, i.e. err = float(loss_local). The error is different for different batch sizes. I am doing all of this without model.eval() right now.
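Part of question 2 is plain arithmetic. Because torch.nn.MSELoss defaults to reduction='mean', each loop iteration adds one batch-averaged loss, so the raw sum shrinks as the batch size grows (there are fewer batches), while dividing by the number of batches recovers the same per-sample mean as long as every batch has the same size. A small sketch with made-up numbers; per_sample and the helper batch_means are illustrative only:

```python
# Made-up per-sample squared errors for 8 validation samples.
per_sample = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]


def batch_means(values, batch_size):
    """One averaged loss per batch, like MSELoss(reduction='mean').
    Assumes len(values) is a multiple of batch_size (as with drop_last=True)."""
    return [sum(values[i:i + batch_size]) / batch_size
            for i in range(0, len(values), batch_size)]


for bs in (1, 2, 4):
    means = batch_means(per_sample, bs)
    summed = sum(means)             # like err = float(loss_local)
    averaged = summed / len(means)  # like err = float(loss_local) / num_samples (counts batches)
    print(bs, summed, averaged)
# 1 36.0 4.5
# 2 18.0 4.5
# 4 9.0 4.5
```

If batches can differ in size, weighting each batch mean by its number of samples before dividing by the total sample count keeps the result a true per-sample mean.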

    import cv2
    import numpy as np
    import torch
    from torchvision import transforms

    # NyuDepthLoader, Model, load_weights, val_lists and flow_transforms are
    # part of the asker's own project and are not shown here.
    import flow_transforms

    batch_size = 1
    data_path = 'path_to_data'
    dtype = torch.FloatTensor
    weight_file = 'path_to_weight_file'
    val_loader = torch.utils.data.DataLoader(NyuDepthLoader(data_path, val_lists), batch_size=batch_size, shuffle=True, drop_last=True)
    model = Model(batch_size)
    model.load_state_dict(load_weights(model, weight_file, dtype))
    loss_fn = torch.nn.MSELoss().cuda()
    # model.eval()

    loss_local = 0.0  # running sum of per-batch losses (must be initialized)
    num_samples = 0   # counts batches; equals the sample count only when batch_size == 1

    with torch.no_grad():
        for input, depth in val_loader:
            # Variable is deprecated since PyTorch 0.4; plain tensors work directly
            input_var = input.type(dtype)
            depth_var = depth.type(dtype)

            output = model(input_var)

            input_rgb_image = input_var[0].permute(1, 2, 0).cpu().numpy().astype(np.uint8)
            input_gt_depth_image = depth_var[0][0].cpu().numpy().astype(np.float32)
            # Only output[0] (the first sample of the batch) is used below,
            # so with batch_size > 1 the other predictions are ignored
            pred_depth_image = output[0].squeeze().cpu().numpy().astype(np.float32)
            print(type(depth_var))
            pred_depth_image_resize = cv2.resize(pred_depth_image, dsize=(608, 456), interpolation=cv2.INTER_LINEAR)
            target_depth_transform = transforms.Compose([flow_transforms.ArrayToTensor()])
            pred_depth_image_tensor = target_depth_transform(pred_depth_image_resize)
            # both inputs to loss_fn are 'torch.Tensor'
            loss_local += loss_fn(pred_depth_image_tensor, depth_var)

            num_samples += 1
            print('num_samples {}'.format(num_samples))

    err = float(loss_local) / num_samples
    print('val_error before train:', err)
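A minimal sketch of a validation loop whose result does not depend on the batch size, assuming a toy nn.Sequential model and random TensorDataset data as hypothetical stand-ins for the asker's Model and NyuDepthLoader (validation_error is likewise a hypothetical helper). It puts the model in evaluation mode, weights each batch's mean loss by the number of samples in it, and divides by the total sample count. For comparison it also runs the same loop in training mode, where BatchNorm normalizes with per-batch statistics, so each sample's output, and therefore the error, changes with the batch size.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical stand-ins for the asker's Model and NyuDepthLoader:
# a tiny conv net with BatchNorm, and random "images" with random "depth" targets.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Conv2d(8, 1, kernel_size=3, padding=1),
)
dataset = TensorDataset(torch.randn(16, 3, 8, 8), torch.randn(16, 1, 8, 8))
loss_fn = nn.MSELoss()  # default reduction='mean': averages over all elements of the batch


def validation_error(model, batch_size, eval_mode=True):
    """Per-element mean squared error over the whole dataset."""
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    was_training = model.training
    model.train(not eval_mode)  # model.train(False) is what model.eval() does
    total, count = 0.0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            loss = loss_fn(model(inputs), targets)
            # Weight each batch mean by the batch size so that a smaller
            # final batch is not over-counted.
            total += loss.item() * inputs.size(0)
            count += inputs.size(0)
    model.train(was_training)  # restore whatever mode the model was in
    return total / count


# Evaluate in eval mode first: forward passes in training mode update
# BatchNorm's running statistics even under torch.no_grad().
eval_errors = {bs: validation_error(model, bs) for bs in (1, 4, 16)}
train_errors = {bs: validation_error(model, bs, eval_mode=False) for bs in (1, 4, 16)}

for bs in (1, 4, 16):
    print(f"batch_size={bs:2d}  eval: {eval_errors[bs]:.6f}  train mode: {train_errors[bs]:.6f}")
```

In evaluation mode the three errors should agree up to floating-point rounding; in training mode they generally differ. The comment about running statistics is also a reason to call model.eval() before validating: otherwise the validation pass itself silently changes the stored BatchNorm statistics.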

What could be reason behind it as I have read on many posts that model.eval should be used before testing the model and model.train() before training it.

Note: testing the model is called inference.

As explained in the official documentation:

Remember that you must call model.eval() to set dropout and batch normalization layers to evaluation mode before running inference. Failing to do this will yield inconsistent inference results.

So once you have loaded the model from a file, this code must be present before you run inference:

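    import torch

    # Model class must be defined somewhere
    # (on recent PyTorch versions, loading a whole pickled model may
    # require torch.load(PATH, weights_only=False))
    model = torch.load(PATH)
    model.eval()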

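This is because dropout acts as a regularizer to prevent overfitting during training and is not needed for inference. Batch normalization also behaves differently in the two modes: in training mode it normalizes each batch with that batch's own mean and variance (and updates its running estimates), while in evaluation mode it uses the running mean and variance stored in the model. Calling model.eval() simply sets each module's training flag to False (it is equivalent to model.train(False)), and this only changes the behavior of modules that check that flag, in particular Dropout and BatchNorm.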
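To see what model.eval() actually switches, here is a small sketch on a hypothetical toy network (net, with arbitrary layer sizes) that contains a BatchNorm and a Dropout layer. The expected outputs in the comments follow from the layers' documented behavior; the last one relies on two random dropout masks differing, which is overwhelmingly likely.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy network containing the two layer types discussed above.
net = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4), nn.Dropout(p=0.5))
x = torch.randn(8, 4)

net.eval()  # same as net.train(False): sets .training = False on every submodule
print([m.training for m in net])  # [False, False, False]

with torch.no_grad():
    out1 = net(x)
    out2 = net(x)
    # Dropout is the identity and BatchNorm uses its stored running_mean/running_var,
    # so repeated calls agree and each sample is processed independently of the batch.
    print(torch.equal(out1, out2))                          # True
    print(torch.allclose(net(x[:1]), out1[:1], atol=1e-6))  # True

net.train()  # back to training mode
print([m.training for m in net])  # [True, True, True]
with torch.no_grad():
    # Dropout now zeroes random elements, and BatchNorm normalizes with the
    # current batch's statistics (and updates its running estimates).
    print(torch.equal(net(x), net(x)))  # almost surely False
```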
