
Implementing a dropout layer inside nn.Sequential()

I am trying to implement a Dropout layer in PyTorch as follows:

class DropoutLayer(nn.Module):
    def __init__(self, p):
        super().__init__()
        self.p = p

    def forward(self, input):
        if self.training:
            u1 = (np.random.rand(*input.shape)<self.p) / self.p
            u1 *= u1
            return u1
        else:
            input *= self.p

I then use it in a simple nn.Sequential model:

model = nn.Sequential(nn.Linear(input_size,num_classes), DropoutLayer(.7), nn.Flatten())

opt = torch.optim.Adam(model.parameters(), lr=0.005)
train(model, opt, 5) #train(model, optimizer, epochs #)

But I'm getting the following error:

TypeError: flatten() takes at most 1 argument (2 given)

I'm not sure what I'm doing wrong. I'm still new to PyTorch. Thanks.

In the forward function of your DropoutLayer, the else branch has no return statement. In evaluation mode the layer therefore returns None, and the following layer (Flatten) receives no usable input. However, as pointed out in the comments, that is not the cause of this particular error.

The actual problem is that you are passing a NumPy array to your Flatten layer. Minimal code to reproduce the problem:

nn.Flatten()(np.random.randn(5,5))
>>> TypeError: flatten() takes at most 1 argument (2 given)

However, I cannot explain why the nn.Flatten layer reports the problem this way for a NumPy array; the error raised by the torch.flatten function is much easier to understand. I don't know what additional operations the layer performs.

torch.flatten(np.random.randn(5,5))
>>> TypeError: flatten(): argument 'input' (position 1) must be Tensor, not numpy.ndarray
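One plausible explanation for the different messages, offered as a sketch rather than a definitive account: nn.Flatten's forward appears to call the input's own .flatten(start_dim, end_dim) method, and on a NumPy array that resolves to ndarray.flatten, which accepts only an optional order argument. The snippet below reproduces that call directly and then shows that converting the array with torch.from_numpy avoids the error. The exact error text may differ between NumPy and PyTorch versions.

```python
import numpy as np
import torch
import torch.nn as nn

arr = np.random.randn(5, 5)  # NumPy array, dtype float64

# Calling ndarray.flatten with two positional arguments, which is
# (by assumption) what nn.Flatten ends up doing with a NumPy input.
try:
    arr.flatten(1, -1)
    raised = False
except TypeError:
    raised = True

# Converting to a tensor first makes both APIs work.
t = torch.from_numpy(arr)      # shares memory with arr
flat = nn.Flatten()(t)         # flattens dims 1..-1 and keeps dim 0
flat_all = torch.flatten(t)    # flattens every dim by default

print(raised)          # True
print(flat.shape)      # torch.Size([5, 5])
print(flat_all.shape)  # torch.Size([25])
```

Note that nn.Flatten keeps the first (batch) dimension by default, whereas torch.flatten collapses everything into one dimension.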

Your code raises this error because, in the forward pass, you create a NumPy array, perform some operations on it, and return it instead of returning a tensor. Note also that in the first (training) branch you never use the actual input tensor.
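The points above can be put together in a minimal sketch of a corrected layer. It assumes that p is meant as the keep probability (suggested by the `rand < p` mask in the question) and uses inverted dropout, so the scaling happens during training and evaluation returns the input unchanged. This is one possible fix, not necessarily what the asker ended up with.

```python
import torch
import torch.nn as nn


class DropoutLayer(nn.Module):
    """Inverted dropout; `p` is treated as the KEEP probability (assumption)."""

    def __init__(self, p: float):
        super().__init__()
        if not 0.0 < p <= 1.0:
            raise ValueError("keep probability p must be in (0, 1]")
        self.p = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:
            # Build the mask with torch so it matches the input's dtype and device.
            mask = (torch.rand_like(x) < self.p).to(x.dtype)
            # Apply the mask to the input and rescale the kept activations.
            return x * mask / self.p
        # Evaluation mode: every branch must return a tensor.
        return x


layer = DropoutLayer(0.7)
x = torch.ones(4, 5)

layer.eval()
y_eval = layer(x)    # identical to x

layer.train()
y_train = layer(x)   # entries are either 0 or 1 / 0.7
```

Because the mask is created with torch.rand_like, the output stays a tensor and any following layer, including nn.Flatten, receives the type it expects.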

The code has been fixed. The solution was simply to call Flatten() first inside nn.Sequential so that the input matrix is brought to n x 786 dimensions.
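A minimal sketch of that layer ordering follows. The batch size, input shape, and class count are hypothetical values chosen for illustration, and the built-in nn.Dropout stands in for the custom layer. Note that nn.Dropout's p is the drop probability, so p=0.3 roughly corresponds to a keep probability of 0.7.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, not taken from the question.
n, input_size, num_classes = 8, 48, 10
x = torch.randn(n, 3, 4, 4)  # 3 * 4 * 4 = 48 features per sample

model = nn.Sequential(
    nn.Flatten(),                        # (n, 3, 4, 4) -> (n, 48)
    nn.Linear(input_size, num_classes),  # (n, 48) -> (n, 10)
    nn.Dropout(p=0.3),                   # p is the DROP probability here
)

out = model(x)
print(out.shape)  # torch.Size([8, 10])
```

Placing Flatten first means the Linear layer always sees a 2-D input of shape (n, input_size).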
