繁体   English   中英

PyTorch 中的 L1/L2 正则化

[英]L1/L2 regularization in PyTorch

如何在不手动计算的情况下在 PyTorch 中添加 L1/L2 正则化?

使用weight_decay > 0进行 L2 正则化:

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

请参阅文档 为 L2 正则化的优化器添加weight_decay参数。

以前的答案虽然在技术上是正确的,但在性能方面效率低下并且不是太模块化(很难在每层的基础上应用,例如keras层提供的)。

PyTorch L2 实现

为什么 PyTorch 在torch.optim.Optimizer实例中实现L2

我们看一下torch.optim.SGD源码(目前为功能优化程序),特别是这部分:

for i, param in enumerate(params):
    d_p = d_p_list[i]
    # L2 weight decay specified HERE!
    if weight_decay != 0:
        d_p = d_p.add(param, alpha=weight_decay)
  • 可以看到, d_p (参数的导数,梯度)被修改并重新分配以加快计算速度(不保存临时变量)
  • 它具有O(N)复杂度,没有任何复杂的数学,例如pow
  • 它不涉及autograd扩展图形而无需任何需要

将其与O(n) **2操作、加法以及参与反向传播进行比较。

数学

让我们看看具有alpha正则化因子的L2方程(L1 ofc 也可以这样做):

L2

如果我们用L2正则化 wrt 参数w对任何损失求导(它与损失无关),我们得到:

L2 衍生

所以它只是为每个权重的梯度添加了alpha * weight weight! 这正是 PyTorch 在上面所做的!

L1 正则化层

使用这个(和一些 PyTorch 魔法),我们可以提出非常通用的 L1 正则化层,但让我们先看看L1的一阶导数( sgn是符号函数,返回1表示正输入,返回-1表示负, 0表示0 ) :

L1 导数

带有WeightDecay接口的完整代码位于torchlayers 第三方库中,提供了诸如仅对权重/偏差/特定命名的参数进行正则化等内容(免责声明:我是作者),但下面概述的想法的本质(见评论):

class L1(torch.nn.Module):
    def __init__(self, module, weight_decay):
        super().__init__()
        self.module = module
        self.weight_decay = weight_decay

        # Backward hook is registered on the specified module
        self.hook = self.module.register_full_backward_hook(self._weight_decay_hook)

    # Not dependent on backprop incoming values, placeholder
    def _weight_decay_hook(self, *_):
        for param in self.module.parameters():
            # If there is no gradient or it was zeroed out
            # Zeroed out using optimizer.zero_grad() usually
            # Turn on if needed with grad accumulation/more safer way
            # if param.grad is None or torch.all(param.grad == 0.0):

            # Apply regularization on it
            param.grad = self.regularize(param)

    def regularize(self, parameter):
        # L1 regularization formula
        return self.weight_decay * torch.sign(parameter.data)

    def forward(self, *args, **kwargs):
        # Simply forward and args and kwargs to module
        return self.module(*args, **kwargs)

如果需要,请在此答案或相应的 PyTorch 文档中阅读有关钩子的更多信息。

而且用法也很简单(应该使用梯度累积和 PyTorch 层):

layer = L1(torch.nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3))

边注

此外,作为旁注, L1正则化没有实现,因为它实际上并没有引起稀疏性(丢失引用,我认为这是 PyTorch 存储库上的一些 GitHub 问题,如果有人有它,请编辑),正如权重等于零所理解的那样.

更常见的是,如果权重值达到某个小的预定义量级(例如0.001 ),则会对其进行阈值处理(只需为其分配零值)

对于 L2 正则化,

l2_lambda = 0.01
l2_reg = torch.tensor(0.)
for param in model.parameters():
    l2_reg += torch.norm(param)
loss += l2_lambda * l2_reg

参考:

开箱即用的 L2 正则化

是的,pytorch优化器有一个名为weight_decay的参数,它对应于 L2 正则化因子:

sgd = torch.optim.SGD(model.parameters(), weight_decay=weight_decay)

L1 正则化实现

L1 没有类似的论点,但是这很容易手动实现:

loss = loss_fn(outputs, labels)
l1_lambda = 0.001
l1_norm = sum(torch.linalg.norm(p, 1) for p in model.parameters())

loss = loss + l1_lambda * l1_norm

L2 的等效手动实现将是:

l2_norm = sum(torch.linalg.norm(p, 2) for p in model.parameters())

资料来源: 使用 PyTorch 进行深度学习(8.5.2)

对于 L1 正则化和仅包括weight

L1_reg = torch.tensor(0., requires_grad=True)
for name, param in model.named_parameters():
    if 'weight' in name:
        L1_reg = L1_reg + torch.norm(param, 1)

total_loss = total_loss + 10e-4 * L1_reg

有趣的torch.norm在 CPU 上较慢,在 GPU 上较直接方法更快。

import torch
x = torch.randn(1024,100)
y = torch.randn(1024,100)

%timeit torch.sqrt((x - y).pow(2).sum(1))
%timeit torch.norm(x - y, 2, 1)

出去:

1000 loops, best of 3: 910 µs per loop
1000 loops, best of 3: 1.76 ms per loop

另一方面:

import torch
x = torch.randn(1024,100).cuda()
y = torch.randn(1024,100).cuda()

%timeit torch.sqrt((x - y).pow(2).sum(1))
%timeit torch.norm(x - y, 2, 1)

出去:

10000 loops, best of 3: 50 µs per loop
10000 loops, best of 3: 26 µs per loop

扩展好的答案:正如所说,如果您使用没有动量的普通 SGD,则添加到损失中的 L2 范数相当于权重衰减。 否则,例如对于 Adam,它并不完全相同。 AdamW 论文 [1] 指出权重衰减实际上更稳定。 这就是为什么你应该使用权重衰减,这是优化器的一个选项。 并考虑使用AdamW而不是Adam

另请注意,您可能不希望所有参数 ( model.parameters() ) 的权重衰减,而只是在一个子集上。 有关示例,请参见此处:

[1]解耦权重衰减正则化 (AdamW),2017

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM