
How to implement current pytorch activation functions with parameters?

I am looking for a simple way to use an activation function that already exists in the PyTorch library, but with some kind of parameter. For example:

Tanh(x/10)

The only solution I came up with was implementing the custom function completely from scratch. Is there a better/more elegant way to do this?

Edit:

I am looking for a way to add the function Tanh(x/10) to my model, rather than plain Tanh(x). Here is the relevant code block:

    self.model = nn.Sequential()
    for i in range(len(self.layers)-1):
        self.model.add_module("linear_layer_" + str(i), nn.Linear(self.layers[i], self.layers[i + 1]))
        if activations is None:
            self.model.add_module("activation_" + str(i), nn.Tanh())
        else:
            if activations[i] == "T":
                self.model.add_module("activation_" + str(i), nn.Tanh())
            elif activations[i] == "R":
                self.model.add_module("activation_" + str(i), nn.ReLU())
            else:
                #no activation
                pass

Instead of defining it as a specific function, you could inline it in a custom layer.

For instance, your solution could look like this:


import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(4, 10)
        self.fc2 = nn.Linear(10, 3)
        self.fc3 = nn.Softmax(dim=1)  # softmax over the class (last) dimension

    def forward(self, x):
        return self.fc3(self.fc2(torch.tanh(self.fc1(x)/10)))

where torch.tanh(self.fc1(x) / 10) is inlined in the forward method of your module.
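
As a quick sanity check, you could run the module on a random batch like this (the batch size of 8 below is an arbitrary choice; the feature size 4 matches fc1):

import torch

net = Net()
x = torch.randn(8, 4)   # batch of 8 samples with 4 features each
out = net(x)            # applies tanh(fc1(x) / 10), then fc2, then softmax
print(out.shape)        # torch.Size([8, 3])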

You can create a layer with the multiplying parameter:

import torch
import torch.nn as nn

class CustomTanh(nn.Module):

    # the init method stores the multiplier:
    def __init__(self, multiplier):
        super().__init__()  # required so nn.Module initializes correctly
        self.multiplier = multiplier

    # the forward scales the input before applying tanh:
    def forward(self, x):
        x = self.multiplier * x
        return torch.tanh(x)

Add it to your models with CustomTanh(1/10) instead of nn.Tanh().
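
For example, the Sequential loop from the question could register the scaled activation like this (a sketch that only covers the "T" branch; the layer sizes are placeholders):

import torch.nn as nn

layers = [4, 10, 3]  # placeholder layer sizes
model = nn.Sequential()
for i in range(len(layers) - 1):
    model.add_module("linear_layer_" + str(i), nn.Linear(layers[i], layers[i + 1]))
    model.add_module("activation_" + str(i), CustomTanh(1 / 10))  # CustomTanh as defined above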
