簡體   English   中英

PyTorch 中的 L1/L2 正則化

[英]L1/L2 regularization in PyTorch

如何在不手動計算的情況下在 PyTorch 中添加 L1/L2 正則化?

使用weight_decay > 0進行 L2 正則化:

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

請參閱文檔 為 L2 正則化的優化器添加weight_decay參數。

以前的答案雖然在技術上是正確的,但在性能方面效率低下並且不是太模塊化(很難在每層的基礎上應用,例如keras層提供的)。

PyTorch L2 實現

為什么 PyTorch 在torch.optim.Optimizer實例中實現L2

我們看一下torch.optim.SGD源碼(目前為功能優化程序),特別是這部分:

for i, param in enumerate(params):
    d_p = d_p_list[i]
    # L2 weight decay specified HERE!
    if weight_decay != 0:
        d_p = d_p.add(param, alpha=weight_decay)
  • 可以看到, d_p (參數的導數,梯度)被修改並重新分配以加快計算速度(不保存臨時變量)
  • 它具有O(N)復雜度,沒有任何復雜的數學,例如pow
  • 它不涉及autograd擴展圖形而無需任何需要

將其與O(n) **2操作、加法以及參與反向傳播進行比較。

數學

讓我們看看具有alpha正則化因子的L2方程(L1 ofc 也可以這樣做):

L2

如果我們用L2正則化 wrt 參數w對任何損失求導(它與損失無關),我們得到:

L2 衍生

所以它只是為每個權重的梯度添加了alpha * weight weight! 這正是 PyTorch 在上面所做的!

L1 正則化層

使用這個(和一些 PyTorch 魔法),我們可以提出非常通用的 L1 正則化層,但讓我們先看看L1的一階導數( sgn是符號函數,返回1表示正輸入,返回-1表示負, 0表示0 ) :

L1 導數

帶有WeightDecay接口的完整代碼位於torchlayers 第三方庫中,提供了諸如僅對權重/偏差/特定命名的參數進行正則化等內容(免責聲明:我是作者),但下面概述的想法的本質(見評論):

class L1(torch.nn.Module):
    def __init__(self, module, weight_decay):
        super().__init__()
        self.module = module
        self.weight_decay = weight_decay

        # Backward hook is registered on the specified module
        self.hook = self.module.register_full_backward_hook(self._weight_decay_hook)

    # Not dependent on backprop incoming values, placeholder
    def _weight_decay_hook(self, *_):
        for param in self.module.parameters():
            # If there is no gradient or it was zeroed out
            # Zeroed out using optimizer.zero_grad() usually
            # Turn on if needed with grad accumulation/more safer way
            # if param.grad is None or torch.all(param.grad == 0.0):

            # Apply regularization on it
            param.grad = self.regularize(param)

    def regularize(self, parameter):
        # L1 regularization formula
        return self.weight_decay * torch.sign(parameter.data)

    def forward(self, *args, **kwargs):
        # Simply forward and args and kwargs to module
        return self.module(*args, **kwargs)

如果需要,請在此答案或相應的 PyTorch 文檔中閱讀有關鈎子的更多信息。

而且用法也很簡單(應該使用梯度累積和 PyTorch 層):

layer = L1(torch.nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3))

邊注

此外,作為旁注, L1正則化沒有實現,因為它實際上並沒有引起稀疏性(丟失引用,我認為這是 PyTorch 存儲庫上的一些 GitHub 問題,如果有人有它,請編輯),正如權重等於零所理解的那樣.

更常見的是,如果權重值達到某個小的預定義量級(例如0.001 ),則會對其進行閾值處理(只需為其分配零值)

對於 L2 正則化,

l2_lambda = 0.01
l2_reg = torch.tensor(0.)
for param in model.parameters():
    l2_reg += torch.norm(param)
loss += l2_lambda * l2_reg

參考:

開箱即用的 L2 正則化

是的,pytorch優化器有一個名為weight_decay的參數,它對應於 L2 正則化因子:

sgd = torch.optim.SGD(model.parameters(), weight_decay=weight_decay)

L1 正則化實現

L1 沒有類似的論點,但是這很容易手動實現:

loss = loss_fn(outputs, labels)
l1_lambda = 0.001
l1_norm = sum(torch.linalg.norm(p, 1) for p in model.parameters())

loss = loss + l1_lambda * l1_norm

L2 的等效手動實現將是:

l2_norm = sum(torch.linalg.norm(p, 2) for p in model.parameters())

資料來源: 使用 PyTorch 進行深度學習(8.5.2)

對於 L1 正則化和僅包括weight

L1_reg = torch.tensor(0., requires_grad=True)
for name, param in model.named_parameters():
    if 'weight' in name:
        L1_reg = L1_reg + torch.norm(param, 1)

total_loss = total_loss + 10e-4 * L1_reg

有趣的torch.norm在 CPU 上較慢,在 GPU 上較直接方法更快。

import torch
x = torch.randn(1024,100)
y = torch.randn(1024,100)

%timeit torch.sqrt((x - y).pow(2).sum(1))
%timeit torch.norm(x - y, 2, 1)

出去:

1000 loops, best of 3: 910 µs per loop
1000 loops, best of 3: 1.76 ms per loop

另一方面:

import torch
x = torch.randn(1024,100).cuda()
y = torch.randn(1024,100).cuda()

%timeit torch.sqrt((x - y).pow(2).sum(1))
%timeit torch.norm(x - y, 2, 1)

出去:

10000 loops, best of 3: 50 µs per loop
10000 loops, best of 3: 26 µs per loop

擴展好的答案:正如所說,如果您使用沒有動量的普通 SGD,則添加到損失中的 L2 范數相當於權重衰減。 否則,例如對於 Adam,它並不完全相同。 AdamW 論文 [1] 指出權重衰減實際上更穩定。 這就是為什么你應該使用權重衰減,這是優化器的一個選項。 並考慮使用AdamW而不是Adam

另請注意,您可能不希望所有參數 ( model.parameters() ) 的權重衰減,而只是在一個子集上。 有關示例,請參見此處:

[1]解耦權重衰減正則化 (AdamW),2017

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM