[英]Learnable LeakyReLU activation function with Pytorch
I'm trying to write a class for Invertible trainable LeakyReLu in which the model modifies the negative_slope in each iteration,我正在尝试为可逆可训练 LeakyReLu 编写 class ,其中 model 在每次迭代中修改negative_slope,
class InvertibleLeakyReLU(nn.Module):
def __init__(self, negative_slope):
super(InvertibleLeakyReLU, self).__init__()
self.negative_slope = torch.tensor(negative_slope, requires_grad=True)
def forward(self, input, logdet = 0, reverse = False):
if reverse == True:
input = torch.where(input>=0.0, input, input *(1/self.negative_slope))
log = - torch.where(input >= 0.0, torch.zeros_like(input), torch.ones_like(input) * math.log(self.negative_slope))
logdet = (sum(log, dim=[1, 2, 3]) +logdet).mean()
return input, logdet
else:
input = torch.where(input>=0.0, input, input *(self.negative_slope))
log = torch.where(input >= 0.0, torch.zeros_like(input), torch.ones_like(input) * math.log(self.negative_slope))
logdet = (sum(log, dim=[1, 2, 3]) +logdet).mean()
return input, logdet
However I set requires_grad=True
, the negative slope wouldn't update.但是我设置requires_grad=True
,负斜率不会更新。 Are there any other points that I must modify?还有其他需要修改的地方吗?
Does your optimizer know it should update InvertibleLeakyReLU.negative_slope
?您的优化器是否知道它应该更新InvertibleLeakyReLU.negative_slope
?
My guess is - no:我的猜测是 - 不:
self.negative_slope
is not defined as nn.Parameter
, and therefore, by default, when you initialize your optimizer with model.parameters()
negative_slope
is not one of the optimization parameters. self.negative_slope
未定义为nn.Parameter
,因此,默认情况下,当您使用model.parameters()
初始化优化器时, negative_slope
不是优化参数之一。
You can either define negative_slope
as a nn.Parameter
:您可以将negative_slope
定义为nn.Parameter
:
self.negative_slope = nn.Parameter(data=torch.tensor(negative_slope), requires_grad=True)
Or, explicitly pass negative_slope
from all InvertibleLeakyReLU
in your model to the optimizer.或者,将 model 中所有InvertibleLeakyReLU
中的negative_slope
显式传递给优化器。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.