
Weight Normalization in PyTorch

An important weight normalization technique was introduced in this paper and has been available in PyTorch for a long time. It is used as follows:

    import torch.nn as nn
    from torch.nn.utils import weight_norm

    weight_norm(nn.Conv2d(in_channels, out_channels, kernel_size))

From the docs I understand that weight_norm does the re-parametrization before each forward() pass. But I am not sure whether this re-parameterization also happens during inference, when everything is running inside with torch.no_grad() and the model is set to eval() mode.
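For concreteness, here is a minimal sketch of what the wrapping does to a layer (the channel and kernel sizes below are arbitrary): the original weight parameter is replaced by weight_g and weight_v, and a pre-forward hook recomputes weight from them.

import torch.nn as nn
from torch.nn.utils import weight_norm

m = weight_norm(nn.Conv2d(3, 8, kernel_size=3))

# 'weight' is no longer a Parameter; it is recomputed from 'weight_g' and
# 'weight_v' by a WeightNorm pre-forward hook before every forward call.
print([name for name, _ in m.named_parameters()])  # contains 'weight_g' and 'weight_v'
print(m._forward_pre_hooks)                        # the WeightNorm hook doing the recomputation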

Can someone please tell me whether weight_norm is active only during training, or also during inference as described above?

Thank you

I tested the "no_grad" part, and it works!

As for "remove_weight_norm", I am still confused. I use weight_norm(Conv1d) a lot in my model. To export the model, I use the following code, with and without the "remove_weight_norm" method, which calls nn.utils.remove_weight_norm on all the related layers.

model.load_state_dict(checkpoint)
model = model.eval()
model.remove_weight_norm()  # with and without this line
remove_hooks(model)
scripted_module = torch.jit.script(model)
torch.jit.save(scripted_module, 'model.pt')
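For reference, a model-level remove_weight_norm helper like the one mentioned above typically just walks the submodules and strips the reparameterization from each wrapped layer. A minimal sketch of such a helper (the name remove_all_weight_norm is hypothetical, not the model's actual method):

import torch.nn as nn

def remove_all_weight_norm(model: nn.Module) -> None:
    for module in model.modules():
        try:
            # folds weight_g / weight_v back into a plain 'weight' parameter
            nn.utils.remove_weight_norm(module)
        except ValueError:
            pass  # this module was not weight-normalized; skip it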

Then I tested two models using C++ code with libtorch. But the results are not the same.

I am wondering what weight_norm does during inference. Is it useful?

I have finally figured out the problem.

Batch normalization tracks two running statistics (mean and variance) during training and uses them for inference. Thus it is necessary to change its behaviour with eval(), so that it stops updating them and uses the stored values instead of the batch statistics.
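As a quick illustration of that, here is a minimal sketch (with arbitrary shapes) showing that a BatchNorm layer really does behave differently in train() and eval() mode, because eval() switches it to the stored running statistics:

import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(1)
x = torch.rand(4, 1, 5, 5)

bn.train()
out_train = bn(x)  # normalizes with batch statistics and updates the running statistics

bn.eval()
out_eval = bn(x)   # normalizes with the stored running statistics

print(torch.allclose(out_train, out_eval))  # False: eval() changes BatchNorm's output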

I then carefully checked the weight normalization paper and found that it is inherently deterministic. It simply decouples the original weight vector into the product of two quantities, as shown below.

w = g * v / ||v||

Obviously, whether you use the LHS or the RHS to compute the output does not matter. However, by decoupling w into the magnitude g and the direction v, passing those two to the optimizer, and deleting the original w parameter, better training is achieved. For the reasons, refer to the paper, where things are nicely described.
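The decomposition can also be checked directly on a freshly wrapped layer. A minimal sketch, assuming the default dim=0 so that the norm of v is taken per output channel:

import torch
import torch.nn as nn
from torch.nn.utils import weight_norm

m = weight_norm(nn.Conv2d(1, 4, kernel_size=3))

# reconstruct the effective weight from the decoupled quantities: w = g * v / ||v||
v = m.weight_v
g = m.weight_g
v_norm = v.reshape(v.size(0), -1).norm(dim=1).view(-1, 1, 1, 1)

print(torch.allclose(m.weight, g * v / v_norm))  # True: 'weight' is just g * v / ||v||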

Thus it does not matter whether weight normalization is removed or not during testing. To validate this, I tried the following small piece of code.

import torch
import torch.nn as nn
from torch.nn.utils import weight_norm as wn
from torch.nn.utils import remove_weight_norm as wnr

# define the model 'm'
m = wn(nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, padding=1, bias=True))

ip = torch.rand(1,1,5,5)
target = torch.rand(1,1,5,5)
l1 = torch.nn.L1Loss()
optimizer = torch.optim.Adam(m.parameters())



# begin training
for _ in range(5):
    optimizer.zero_grad()  # clear gradients from the previous step
    out = m(ip)
    loss = l1(out, target)
    loss.backward()
    optimizer.step()

with torch.no_grad():
    m.eval()
    print('\no/p after training with wn: {}'.format(m(ip)))
    wnr(m)  # remove weight_norm: fold weight_g / weight_v back into a single 'weight'
    print('\no/p after training without wn: {}'.format(m(ip)))

# begin testing: load the folded weights into a plain Conv2d (no weight_norm)
m2 = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, padding=1, bias=True)
m2.load_state_dict(m.state_dict())  # works because remove_weight_norm was applied to m above

with torch.no_grad():
    m2.eval()
    out = m2(ip)
    print('\nOutput during testing and without weight_norm: {}'.format(out))

And the output is below:

o/p after training with wn: 
tensor([[[[0.0509, 0.3286, 0.4612, 0.1795, 0.0307],
          [0.1846, 0.3931, 0.5713, 0.2909, 0.4026],
          [0.1716, 0.5971, 0.4297, 0.0845, 0.6172],
          [0.2938, 0.2389, 0.4478, 0.5828, 0.6276],
          [0.1423, 0.2065, 0.5024, 0.3979, 0.3127]]]])

o/p after training without wn: 
tensor([[[[0.0509, 0.3286, 0.4612, 0.1795, 0.0307],
          [0.1846, 0.3931, 0.5713, 0.2909, 0.4026],
          [0.1716, 0.5971, 0.4297, 0.0845, 0.6172],
          [0.2938, 0.2389, 0.4478, 0.5828, 0.6276],
          [0.1423, 0.2065, 0.5024, 0.3979, 0.3127]]]])

Output during testing and without weight_norm: 
tensor([[[[0.0509, 0.3286, 0.4612, 0.1795, 0.0307],
          [0.1846, 0.3931, 0.5713, 0.2909, 0.4026],
          [0.1716, 0.5971, 0.4297, 0.0845, 0.6172],
          [0.2938, 0.2389, 0.4478, 0.5828, 0.6276],
          [0.1423, 0.2065, 0.5024, 0.3979, 0.3127]]]])

Please see that all the values are exactly the same, since only a reparameterization is happening.

Regarding,

Then I tested two models using C++ code with libtorch. But the results are not the same.

See https://github.com/pytorch/pytorch/issues/21275 which reports a bug with TorchScript.

And regarding,

I am wondering what weight_norm does during inference. Is it useful?

The answer is that it does nothing: whether you compute x * 2 or x * (1 + 1) does not matter. It is not useful, but it is not harmful either. So it is better to remove it.

It should be active. .eval() affects certain network layers (e.g. Dropout and BatchNorm layers); see the eval documentation.

.no_grad() reduces memory usage and speeds up computation during inference; see the no_grad documentation. I think weight_norm isn't affected by either of these.

Greetings
