繁体   English   中英

PyTorch中如何高效实现非全连接线性层?

[英]How to efficiently implement a non-fully connected Linear Layer in PyTorch?

我制作了一个我正在尝试实现的缩小版本的示例图:

网络图

所以顶部的两个输入节点只与顶部的三个 output 节点完全连接,同样的设计适用于底部的两个节点。 到目前为止,我已经提出了两种在 PyTorch 中实现这一点的方法,这两种方法都不是最佳的。

第一个是创建一个包含许多较小线性层的 nn.ModuleList,并在前向传递期间通过它们迭代输入。 对于图表的示例,它看起来像这样:

class Module(nn.Module):
  def __init__(self):
    self.layers = nn.Module([nn.Linear(2, 3) for i in range(2)])
  
  def forward(self, input):
    output = torch.zeros(2, 3)
    for i in range(2):
      output[i, :] = self.layers[i](input.view(2, 2)[i, :])
    return output.flatten()

所以这完成了图中的网络,主要问题是它非常慢。 我认为这是因为 PyTorch 必须按顺序处理 for 循环,并且不能并行处理输入张量。

为了“矢量化”模块以便 PyTorch 可以更快地运行它,我有这个实现:

class Module(nn.Module):
  def __init__(self):
    self.layer = nn.Linear(4, 6)
    self.mask = # create mask of ones and zeros to "block" certain layer connections
  
  def forward(self, input):
    prune.custom_from_mask(self.layer, name='weight', mask=self.mask)
    return self.layer(input)

这也完成了图的网络,通过使用权重修剪来确保完全连接层中的某些权重始终为零(例如,连接顶部输入节点和底部输出节点的权重将始终为零,因此它有效地“断开连接”) . 这个模块比前一个模块快得多,因为没有 for 循环。 现在的问题是这个模块占用了更多的 memory。 这可能是因为即使大多数层的权重为零,PyTorch 仍然将网络视为存在。 这种实现基本上保持了比它需要的更多的权重。

有没有人遇到过这个问题并想出一个有效的解决方案?

如果权重共享没问题,那么一维卷积应该可以解决问题:

class Module(nn.Module):
  def __init__(self):
    self.layers = nn.Conv1d(in_channels=2, out_channels=3, kernel_size=1)
    self._n_splits = 2

  
  def forward(self, input):
    
    B, C = input.shape
    output = self.layers(input.view(B, C//self._n_splits, -1))
    return output.view(B, C)

如果权重共享不好,那么您可以使用组卷积: self.layers = nn.Conv1d(in_channels=4, out_channels=3, kernel_size=1, stride=1, groups=2) 但是,我不确定这是否可以实现任意数量的通道拆分,您可以查看文档: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html

一维卷积是输入所有通道上的全连接层。 组卷积会将通道分成组并对它们执行单独的卷积操作(这是您想要的)。

实现将类似于:

class Module(nn.Module):
  def __init__(self):
    self.layers = nn.Conv1d(in_channels=2, out_channels=3, kernel_size=1, groups=2)

  
  def forward(self, input):
    
    B, C = input.shape
    output = self.layers(input.unsqueeze(-1)
    return output.squeeze()

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM