繁体   English   中英

使用来自 __init__ 的参数定义 python 类方法

[英]Defining python class method using arguments from __init__

我想要达到的目标

我想定义一个有A和B两种模式的类,这样类的forward方法会相应的改变。

class MyClass():
    def __init__(self, constant):
        self.constant=constant
            
    def forward(self, x1,x2,function):
        if function=='A':
            return x1+self.constant
        
        elif function=='B':
            return x1*x2+self.constant
        
        else:
            print('please provide the correct function')
model1 = MyClass(2)
model1.forward(2, None, 'A')
output>>>4
model2 = MyClass(2)
model2.forward(2, 2, 'B')
output>>>6

它可以工作,但不是最优的,因为每次调用forward方法时,它都会检查要使用哪个函数。 但是,forward 函数已经设置并且一旦定义了类就永远不会更改,因此,在我的情况下,检查在forward使用哪个函数是非常多余的。 (对于那些注意到这一点的人,我正在使用 PyTorch 编写我的神经网络模型,两个模型共享 90% 的网络架构,唯一的 10% 差异是它们的前馈方式)。

我想要的版本

我想在定义类的时候设置forward方法,这样我就可以实现这个

model1 = MyClass(2, 'A')
model1.forward(2)
output>>>4

model2 = MyClass(2, 'B')
model2.forward(2, 2)
output>>>6

所以我把我的班级改写为:

class MyClass():
    def __init__(self, constant, function):
        self.constant=constant # There would be a lot of shared parameters for the two methods
        self.function=function # This controls the feedforward method of this class
        
    if self.function=='A':
        def forward(self, x1):
            return x1+self.constant
        
    elif self.function=='B':
        def forward(self, x1, x2):
            return x1*x2+self.constant
        
    else:
        print('please provide the correct function')

但是,它给了我以下错误。

NameError: 名称 'self' 未定义

我如何编写基于__init__参数定义不同forward方法的类?

您一直在尝试用您的代码重新定义,这样每个新对象都会在之前和之后更改所有对象的 forward 定义。 幸运的是,你没有弄清楚如何做到这一点。

相反,使所选函数成为对象的属性。 对您想要的两个函数进行编码,然后在创建每个实例时分配所需的变体。

class MyClass():
    def __init__(self, constant, function):
        self.constant=constant
        if function == 'A':
            self.forward = self.forwardA
        elif function=='B':
            self.forward = self.forwardB
        else:
            print('please provide the correct function')
            
    def forwardA(self, x1):
        return x1+self.constant
        
    def forwardB(self, x1, x2):
        return x1*x2+self.constant

# Main
model1 = MyClass(2, 'A')
print(model1.forward(2))

model2 = MyClass(2, 'B')
print(model2.forward(2, 2))

输出:

4
6

您还可以尝试分解出一个基类。 它可能会与mypy一起玩得更好,并且可以更轻松地不混淆您正在使用的任何类。

class MyClassBase():                                      
    def __init__(self, constant):                         
         self.constant=constant                           
                                                          
    def forward(self, *args, **kwargs):              
        raise NotImplementedError('use a derived class')  
                                                          
class MyClassA(MyClassBase):                              
    def __init__(self, constant):                         
        super().__init__(constant)                        
                                                          
    def forward(self,x1):                                 
        return x1 + self.constant                         
                                                          
class MyClassB(MyClassBase):                              
    def __init__(self, constant):                         
        super().__init__(constant)                        
                                                          
    def forward(self, x1, x2):                            
        return x1*x2 + self.constant                      
                                                          
a = MyClassA(2)                                           
b = MyClassB(2)                                           
                                                          
print(a.forward(2))                                       
print(b.forward(2,2))                                     

我想说的是,如果您要在类方法中跨返回使用常量变量名称,我们可以通过这种方式定义它来调整逻辑。 (我们可以用不同的方式来做到这一点,只使用 Params)希望这看起来不错。

class MyClass():
    def __init__(self, constant, function):
        self.constant=constant
        self.function=function
    
    def forward(self, x1 = None, x2 = None):
        if self.function=='A':
            return x1+self.constant
        elif self.function=='B':
            return x1*x2+self.constant
        else:
            print('please provide the correct function')

model1 = MyClass(2, 'A')
model1.forward(2)

model2 = MyClass(2, 'B')
model2.forward(2, 2)

定义

在这种情况下,我也会去继承; 给定你想要的:

model2 = MyClass(2, 'B')

这样做会更容易(并且对其他人更具可读性):

model2 = MyClassB(2)

鉴于此,类似于@Nathan Chappell 在他的回答中提供的内容,但较短(例如,无需重新定义__init__ ):

import torch


class Base(torch.nn.Module):
    def __init__(self, constant):
        super().__init__()
        self.constant = constant


class MyClassA(Base):
    def forward(self, x1):
        return x1 + self.constant


class MyClassB(Base):
    def forward(self, x1, x2):
        return x1 * x2 + self.constant

打电话

您应该使用torch.nn.Module__call__方法而不是forward因为它可以与钩子一起正常工作(请参阅此答案),因此它应该是:

model1 = MyClassA(5)
model1(torch.randn(10, 5))

代替:

model1 = MyClassA(5)
model1.forward(torch.randn(10, 5))

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM