繁体   English   中英

如何将 Pytorch 的深度学习可训练参数分解为两部分?

[英]How to decompose deep learning trainable parameter into two parts for Pytorch?

当我阅读论文“Federated Semi-Supervised Learning with Inter-Client Consistency & Disjoint Learning”时,我遇到了这个问题。 我很想知道脱节的学习。 作者说'对于给定的 model,它应该分解为 model = a + b的两部分。 当我们在数据集 A 上训练 model 时,我们只更新a并保持b不变。 当我们在数据集 B 上训练 model 时,我们只更新b并保持一个常数。 我们如何使用 pytorch 做到这一点?

非常感谢你的帮助!!!

好吧,现在更有意义了。 在超参数的上下文中使用“+”确实是有道理的,但它令人困惑。 无论如何,这就是我将如何做到的。 让我们制作一个 model ,它必须连续层,它们的权重分别是 σ 和 ψ (这是一个虚拟示例,不要尝试实际训练它)

class DecomposedModel(nn.Module):
    def __init__(self):
        self.layer_A = nn.Linear(2, 3)
        self.layer_B = nn.Linear(3, 5)

    def forward(self, input_t):
        """ input_t expected shape is (B, 2), output will be (B,5) """
        return self.layer_B(nn.ReLU(self.layer_A(input_t)))

    def parameters_A(self):
        """ returns σ parameters"""
        return self.layer_A.parameters()

    def parameters_B(self):
        """ returns ψ parameters """
        return self.layer_B.parameters()

model = DecomposedModel()
opt_A = Adam(model.parameters_A(), lr=1e-3)
opt_B = Adam(model.parameters_B(), lr=1e-3)

使用这种结构,您可以选择在opt_Aopt_B上交替应用step ,一次只训练一层因此当您在数据集 A 上进行训练时,您调用opt_A.step() ,在数据集 B 上调用opt_B.step() (当然还有对zero_grad的相应调用)

论文作者的做法:用两个变量的sigma和phi创建一个自定义层。

class CustomLinear(nn.Module):
    def __init__(self, in_shape: int, out_shape: int, supervised: bool):
        super().__init__()

        self.sigma = nn.Parameter(torch.randn(
            in_shape, out_shape), requires_grad=supervised)
        self.phi = nn.Parameter(torch.randn(
            in_shape, out_shape), requires_grad=not supervised)

        self.sigma_bias = nn.Parameter(
            torch.randn(out_shape), requires_grad=supervised)
        self.phi_bias = nn.Parameter(
            torch.randn(out_shape), requires_grad=not supervised)

    def forward(self, X):
        theta = self.phi + self.sigma
        bias = self.phi_bias + self.sigma_bias
        return X @ theta + bias

使用 sigma 和 phi,您还可以使用两个优化器来控制它,而不是requires_grad

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM