简体   繁体   English

使用来自 __init__ 的参数定义 python 类方法

[英]Defining python class method using arguments from __init__

What I want to achieve我想要达到的目标

I want to define a class with two modes A and B, so that the forward method of the class changes accordingly.我想定义一个有A和B两种模式的类,这样类的forward方法会相应的改变。

class MyClass():
    def __init__(self, constant):
        self.constant=constant
            
    def forward(self, x1,x2,function):
        if function=='A':
            return x1+self.constant
        
        elif function=='B':
            return x1*x2+self.constant
        
        else:
            print('please provide the correct function')
model1 = MyClass(2)
model1.forward(2, None, 'A')
output>>>4
model2 = MyClass(2)
model2.forward(2, 2, 'B')
output>>>6

It works, but it is not optimal, since every time when calling the forward method, it will check which function to use.它可以工作,但不是最优的,因为每次调用forward方法时,它都会检查要使用哪个函数。 However, the forward function is already set and will never be changed once the class is define, therefore, checking which function to use inside forward is super redundant in my case.但是,forward 函数已经设置并且一旦定义了类就永远不会更改,因此,在我的情况下,检查在forward使用哪个函数是非常多余的。 (For those who notice this, I am writing my neural network model using PyTorch, two models share 90% of the network architecture, the only 10% differences is the way they do feedforward). (对于那些注意到这一点的人,我正在使用 PyTorch 编写我的神经网络模型,两个模型共享 90% 的网络架构,唯一的 10% 差异是它们的前馈方式)。

My desired version我想要的版本

I want to set the forward method when the class is defined, so that I can achieve this我想在定义类的时候设置forward方法,这样我就可以实现这个

model1 = MyClass(2, 'A')
model1.forward(2)
output>>>4

model2 = MyClass(2, 'B')
model2.forward(2, 2)
output>>>6

So I rewrote my class to be:所以我把我的班级改写为:

class MyClass():
    def __init__(self, constant, function):
        self.constant=constant # There would be a lot of shared parameters for the two methods
        self.function=function # This controls the feedforward method of this class
        
    if self.function=='A':
        def forward(self, x1):
            return x1+self.constant
        
    elif self.function=='B':
        def forward(self, x1, x2):
            return x1*x2+self.constant
        
    else:
        print('please provide the correct function')

However, it gives me the following error.但是,它给了我以下错误。

NameError: name 'self' is not defined NameError: 名称 'self' 未定义

How do I write the class for that it defines different forward method based on the args from __init__ ?我如何编写基于__init__参数定义不同forward方法的类?

You have been trying to redefine the class with your code, such that each new object would change the definition of forward for all objects, before and after.您一直在尝试用您的代码重新定义,这样每个新对象都会在之前和之后更改所有对象的 forward 定义。 Fortunately, you didn't figure out how to do that.幸运的是,你没有弄清楚如何做到这一点。

Instead, make the chosen function an attribute of the object.相反,使所选函数成为对象的属性。 Code the two functions you want, and then assign the desired variant as you create each instance.对您想要的两个函数进行编码,然后在创建每个实例时分配所需的变体。

class MyClass():
    def __init__(self, constant, function):
        self.constant=constant
        if function == 'A':
            self.forward = self.forwardA
        elif function=='B':
            self.forward = self.forwardB
        else:
            print('please provide the correct function')
            
    def forwardA(self, x1):
        return x1+self.constant
        
    def forwardB(self, x1, x2):
        return x1*x2+self.constant

# Main
model1 = MyClass(2, 'A')
print(model1.forward(2))

model2 = MyClass(2, 'B')
print(model2.forward(2, 2))

Output:输出:

4
6

You can also try factoring out a base class.您还可以尝试分解出一个基类。 It will probably play nicer with mypy and will make it easier to not mix up whatever class you're using.它可能会与mypy一起玩得更好,并且可以更轻松地不混淆您正在使用的任何类。

class MyClassBase():                                      
    def __init__(self, constant):                         
         self.constant=constant                           
                                                          
    def forward(self, *args, **kwargs):              
        raise NotImplementedError('use a derived class')  
                                                          
class MyClassA(MyClassBase):                              
    def __init__(self, constant):                         
        super().__init__(constant)                        
                                                          
    def forward(self,x1):                                 
        return x1 + self.constant                         
                                                          
class MyClassB(MyClassBase):                              
    def __init__(self, constant):                         
        super().__init__(constant)                        
                                                          
    def forward(self, x1, x2):                            
        return x1*x2 + self.constant                      
                                                          
a = MyClassA(2)                                           
b = MyClassB(2)                                           
                                                          
print(a.forward(2))                                       
print(b.forward(2,2))                                     

I would say if you're going to use the constant variable names across returns in a class method, we can define it in this way to tweak the logic.我想说的是,如果您要在类方法中跨返回使用常量变量名称,我们可以通过这种方式定义它来调整逻辑。 (We can do this in a different way, play only with Params) Hope this looks good. (我们可以用不同的方式来做到这一点,只使用 Params)希望这看起来不错。

class MyClass():
    def __init__(self, constant, function):
        self.constant=constant
        self.function=function
    
    def forward(self, x1 = None, x2 = None):
        if self.function=='A':
            return x1+self.constant
        elif self.function=='B':
            return x1*x2+self.constant
        else:
            print('please provide the correct function')

model1 = MyClass(2, 'A')
model1.forward(2)

model2 = MyClass(2, 'B')
model2.forward(2, 2)

Definition定义

I would also go for inheritance in this case;在这种情况下,我也会去继承; given what you want:给定你想要的:

model2 = MyClass(2, 'B')

It would be much easier (and way more readable for others) to just do:这样做会更容易(并且对其他人更具可读性):

model2 = MyClassB(2)

Given that, similar to what @Nathan Chappell provided in his answer but shorter (no need for redefining __init__ for example):鉴于此,类似于@Nathan Chappell 在他的回答中提供的内容,但较短(例如,无需重新定义__init__ ):

import torch


class Base(torch.nn.Module):
    def __init__(self, constant):
        super().__init__()
        self.constant = constant


class MyClassA(Base):
    def forward(self, x1):
        return x1 + self.constant


class MyClassB(Base):
    def forward(self, x1, x2):
        return x1 * x2 + self.constant

Calling打电话

You should use torch.nn.Module 's __call__ method instead of forward as it works correctly with hooks (see this answer ), hence it should be:您应该使用torch.nn.Module__call__方法而不是forward因为它可以与钩子一起正常工作(请参阅此答案),因此它应该是:

model1 = MyClassA(5)
model1(torch.randn(10, 5))

instead of:代替:

model1 = MyClassA(5)
model1.forward(torch.randn(10, 5))

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM