
How to pass parameters to forward function of my torch nn.module from skorch.NeuralNetClassifier.fit()

I have extended nn.Module to implement my network, whose forward function looks like this ...

def forward(self, X, **kwargs):
    batch_size, seq_len = X.size()

    # use .get() so the None check below can actually fire
    length = kwargs.get('length')
    embedded = self.embedding(X)  # [batch_size, seq_len, embedding_dim]
    if self.use_padding:
        if length is None:
            raise AttributeError("Length must be a tensor when using padding")
        embedded = nn.utils.rnn.pack_padded_sequence(embedded, length, batch_first=True)
        # print("Size of Embedded packed", embedded[0].size())

    hidden, cell = self.init_hidden(batch_size)
    if self.rnn_unit == 'rnn':
        out, _ = self.rnn(embedded, hidden)
    elif self.rnn_unit == 'lstm':
        out, (hidden, cell) = self.rnn(embedded, (hidden, cell))

    # unpack if padding was used
    if self.use_padding:
        out, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)

I initialized a skorch NeuralNetClassifier like this,

from skorch import NeuralNetClassifier
from torch.optim import Adam

net = NeuralNetClassifier(
    model,
    criterion=nn.CrossEntropyLoss,
    optimizer=Adam,
    max_epochs=8,
    lr=0.01,
    batch_size=32
)

Now if I call net.fit(X, y, length=X_len), it throws an error:

TypeError: __call__() got an unexpected keyword argument 'length'

According to the documentation, the fit function expects a fit_params dictionary:

**fit_params : dict — Additional parameters passed to the ``forward`` method of the module and to the ``self.train_split`` call.

and the source code always sends these parameters to train_split, where my keyword argument would obviously not be recognized.

Is there any way to pass these arguments through to my forward function?

The fit_params parameter is intended for passing information that is relevant to data splits and the model alike, such as split groups.

In your case, you are passing additional data to the module via fit_params, which is not what it is intended for. In fact, you can easily run into trouble this way: if you, for example, enable batch shuffling on the train data loader, your lengths and your data become misaligned, as the sketch below illustrates.
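A minimal sketch of that hazard, assuming the same model and imports as in the question (iterator_train__shuffle is skorch's standard way of routing an argument to the training DataLoader):

net = NeuralNetClassifier(
    model,
    criterion=nn.CrossEntropyLoss,
    optimizer=Adam,
    max_epochs=8,
    lr=0.01,
    batch_size=32,
    iterator_train__shuffle=True,  # the DataLoader now draws batches in random order
)
# X is batched and shuffled by the loader, but anything passed via
# fit_params reaches forward() whole and unshuffled on every batch,
# so per-sample data like `length` would no longer line up with X.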

The best way to do this is already described in the answer to your question on the issue tracker:

X_dict = {'X': X, 'length': X_len}
net.fit(X_dict, y)

Since skorch supports dicts, you can simply add the lengths to your input dict and have them passed to the module, nicely batched and run through the same data loader. In your module you can then access them as a parameter of forward:

def forward(self, X, length):
    return ...
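Putting it all together, here is a minimal, self-contained sketch of the pattern; the PaddedRNN module, its sizes, and the random data are hypothetical stand-ins rather than the asker's actual model:

import numpy as np
import torch.nn as nn
from skorch import NeuralNetClassifier

class PaddedRNN(nn.Module):
    def __init__(self, vocab_size=100, embedding_dim=16, hidden_dim=32, n_classes=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, n_classes)

    def forward(self, X, length):
        # `length` arrives here batched and aligned with X
        embedded = self.embedding(X)
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, length, batch_first=True, enforce_sorted=False)
        _, (hidden, _) = self.rnn(packed)
        return self.fc(hidden[-1])

X = np.random.randint(1, 100, size=(64, 10)).astype('int64')   # padded token ids
X_len = np.random.randint(1, 11, size=64).astype('int64')      # true sequence lengths
y = np.random.randint(0, 2, size=64).astype('int64')

net = NeuralNetClassifier(
    PaddedRNN,
    criterion=nn.CrossEntropyLoss,
    max_epochs=2,
    batch_size=16,
)
net.fit({'X': X, 'length': X_len}, y)

Because the lengths travel inside the input dict, they are split, shuffled, and batched in lockstep with the sequences they describe.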

Further documentation of this behaviour can be found in the docs.
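One related note: a plain dict cannot be sliced row-wise the way sklearn utilities such as GridSearchCV expect. For that case skorch ships skorch.helper.SliceDict, a dict subclass that supports such slicing and is otherwise used exactly like the plain dict above:

from skorch.helper import SliceDict

X_dict = SliceDict(X=X, length=X_len)  # drop-in for {'X': X, 'length': X_len}
net.fit(X_dict, y)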
