简体   繁体   English

PyTorch逐元素过滤层

[英]PyTorch element-wise filter layer

Hi, I want to add element-wise multiplication layer to duplicate the input to multi-channels like this figure. 嗨,我想添加逐元素乘法层,以将输入复制到多通道,如图所示。 (So, the input size M x N and multiplication filter size M x N is same), as illustrated in this figure (因此,输入大小M x N和乘法滤波器大小M x N相同),如图所示

I want to add custom initialization value to filter, and also want them to get gradient while training. 我想添加自定义初始化值进行过滤,还希望他们在训练时获得渐变。 However, I can't find element-wise filter layer in PyTorch. 但是,我在PyTorch中找不到按元素分类的滤镜层。 Can I make it? 我能做到吗? Or is it just impossible in PyTorch? 还是在PyTorch中是不可能的?

In pytorch you can always implement your own layers, by making them subclasses of nn.Module . 在pytorch中,您始终可以通过将其nn.Module子类来实现自己的图层。 You can also have trainable parameters in your layer, by using nn.Parameter . 您还可以使用nn.Parameter在您的图层中设置可训练的参数。
Possible implementation of such layer might look like 这种层的可能实现可能看起来像

import torch
from torch import nn

class TrainableEltwiseLayer(nn.Module)
  def __init__(self, n, h, w):
    super(TrainableEltwiseLayer, self).__init__()
    self.weights = nn.Parameter(torch.Tensor(1, n, h, w))  # define the trainable parameter

  def forward(self, x):
    # assuming x is of size b-1-h-w
    return x * self.weights  # element-wise multiplication

You still need to worry about initializing the weights. 您仍然需要担心初始化权重。 look into nn.init on ways to init weights. 查看nn.init有关初始化权重的方法。 Usually one init the weights of all the net prior to training and prior to loading any stored model (so partially trained models can override random init). 通常是在训练之前和加载任何存储的模型之前初始化所有网络的权重(因此,部分训练的模型可以覆盖随机初始化)。 Something like 就像是

model = mymodel(*args, **kwargs)  # instantiate a model
for m in model.modules():
  if isinstance(m, nn.Conv2d):
     nn.init.normal_(m.weights.data)  # init for conv layers
  if isinstance(m, TrainableEltwiseLayer):
     nn.init.constant_(m.weights.data, 1)  # init your weights here...

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM