繁体   English   中英

使用 Pytorch 提取自动编码器隐藏层的特征

[英]Extracting features of the hidden layer of an autoencoder using Pytorch

我正在按照本教程训练自动编码器。

训练进行得很顺利。 接下来,我有兴趣从隐藏层(编码器和解码器之间)提取特征。

我该怎么做?

最干净和最直接的方法是添加用于创建部分输出的方法——这甚至可以在经过训练的 model 上进行后验。

from torch import Tensor

class AE(nn.Module):
    def __init__(self, **kwargs):
        ...

    def encode(self, features: Tensor) -> Tensor:
        h = torch.relu(self.encoder_hidden_layer(features))
        return torch.relu(self.encoder_output_layer(h))

    def decode(self, encoded: Tensor) -> Tensor:
        h = torch.relu(self.decoder_hidden_layer(encoded))
        return torch.relu(self.decoder_output_layer(h))

    def forward(self, features: Tensor) -> Tensor:
        encoded = self.encode(features)
        return self.decode(encoded)

您现在可以通过简单地使用相应的输入张量调用 encode 来查询 model 的编码器隐藏状态。

如果您不想在基础 class 中添加任何方法(我不明白为什么),您也可以编写一个外部 function:

def get_encoder_state(model: AE, features: Tensor) -> Tensor:
   return torch.relu(model.encoder_output_layer(torch.relu(model.encoder_hidden_layer(features))))

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM