
Code a tensor view layer in nn.Sequential

I have an nn.Sequential container, and inside it I want to use the Tensor.view function. My current solution looks like this:

import torch
import torch.nn as nn

class Reshape(nn.Module):
    def __init__(self, *args):
        super().__init__()
        self.my_shape = args  # target shape, e.g. (-1, 128, 8, 8)

    def forward(self, x):
        # reshape the incoming tensor to the stored shape
        return x.view(self.my_shape)

and in my AutoEncoder class I have:

self.decoder = nn.Sequential(
                torch.nn.Linear(self.bottleneck_size, 4096*2),
                Reshape(-1, 128, 8, 8),
                
                nn.UpsamplingNearest2d(scale_factor=2), 
                ...

Is there a way to reshape the tensor directly in the sequential block so that I do not need to use the externally created Reshape class? Thank you

You can use the nn.Unflatten layer. From the PyTorch docs:

Unflattens a tensor dim expanding it to a desired shape. For use with Sequential.
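For instance, a quick illustration of the dim argument (the sizes here are arbitrary): unflattening dim 1 of a (batch, features) tensor expands that single dimension into several, while dim 0 stays untouched:

import torch
import torch.nn as nn

x = torch.randn(16, 8192)                     # (batch, features), e.g. a Linear output
print(nn.Unflatten(1, (128, 8, 8))(x).shape)  # torch.Size([16, 128, 8, 8])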

So you would have:

self.decoder = nn.Sequential(
            torch.nn.Linear(self.bottleneck_size, 4096*2),
            # The first argument is the dimension you would like to unflatten;
            # dimension 0 is usually your batch size, so here we need dimension 1.
            # (128, 8, 8) multiplies back to 4096*2 = 8192 and reproduces the
            # original Reshape(-1, 128, 8, 8). Note that nn.Unflatten(1, (1, 128, 8, 8))
            # would instead give a 5D tensor of shape (N, 1, 128, 8, 8).
            nn.Unflatten(1, (128, 8, 8)),
            nn.UpsamplingNearest2d(scale_factor=2),
            ...
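If you want to confirm that the shapes line up, here is a quick check (bottleneck_size is just a placeholder value, not taken from your model):

import torch
import torch.nn as nn

bottleneck_size = 256                        # hypothetical value; use your own
decoder_head = nn.Sequential(
    nn.Linear(bottleneck_size, 4096 * 2),
    nn.Unflatten(1, (128, 8, 8)),
    nn.UpsamplingNearest2d(scale_factor=2),
)
z = torch.randn(4, bottleneck_size)          # dummy batch of 4 latent vectors
print(decoder_head(z).shape)                 # torch.Size([4, 128, 16, 16])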

You should also check this discussion on the PyTorch forum if you have not already. Also, here is how torchvision models used to be implemented in PyTorch: you can see they kept Tensor.view separate from the rest of the Sequential modules and applied it in forward. The current version of the same code uses flatten instead, which suggests that using Unflatten here is just as reasonable.
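To illustrate the two styles mentioned above, here is a minimal sketch (the layer sizes are made up and not taken from any torchvision model):

import torch
import torch.nn as nn

# Style 1: keep the reshape out of the Sequential and do it in forward()
class NetWithView(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Conv2d(3, 128, 3, padding=1),
                                      nn.AdaptiveAvgPool2d((8, 8)))
        self.classifier = nn.Linear(128 * 8 * 8, 10)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)   # flatten manually, outside the Sequential
        return self.classifier(x)

# Style 2: express the reshape as a layer so everything lives in one Sequential
net = nn.Sequential(
    nn.Conv2d(3, 128, 3, padding=1),
    nn.AdaptiveAvgPool2d((8, 8)),
    nn.Flatten(),                   # (N, 128, 8, 8) -> (N, 8192)
    nn.Linear(128 * 8 * 8, 10),
)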
