![](/img/trans.png)
[英]Defining a class' __init__ arguments dynamically (using a previously defined dictionary)
[英]Defining python class method using arguments from __init__
我想定义一个有A和B两种模式的类,这样类的forward
方法会相应的改变。
class MyClass():
def __init__(self, constant):
self.constant=constant
def forward(self, x1,x2,function):
if function=='A':
return x1+self.constant
elif function=='B':
return x1*x2+self.constant
else:
print('please provide the correct function')
model1 = MyClass(2)
model1.forward(2, None, 'A')
output>>>4
model2 = MyClass(2)
model2.forward(2, 2, 'B')
output>>>6
它可以工作,但不是最优的,因为每次调用forward
方法时,它都会检查要使用哪个函数。 但是,forward 函数已经设置并且一旦定义了类就永远不会更改,因此,在我的情况下,检查在forward
使用哪个函数是非常多余的。 (对于那些注意到这一点的人,我正在使用 PyTorch 编写我的神经网络模型,两个模型共享 90% 的网络架构,唯一的 10% 差异是它们的前馈方式)。
我想在定义类的时候设置forward
方法,这样我就可以实现这个
model1 = MyClass(2, 'A')
model1.forward(2)
output>>>4
model2 = MyClass(2, 'B')
model2.forward(2, 2)
output>>>6
所以我把我的班级改写为:
class MyClass():
def __init__(self, constant, function):
self.constant=constant # There would be a lot of shared parameters for the two methods
self.function=function # This controls the feedforward method of this class
if self.function=='A':
def forward(self, x1):
return x1+self.constant
elif self.function=='B':
def forward(self, x1, x2):
return x1*x2+self.constant
else:
print('please provide the correct function')
但是,它给了我以下错误。
NameError: 名称 'self' 未定义
我如何编写基于__init__
参数定义不同forward
方法的类?
您一直在尝试用您的代码重新定义类,这样每个新对象都会在之前和之后更改所有对象的 forward 定义。 幸运的是,你没有弄清楚如何做到这一点。
相反,使所选函数成为对象的属性。 对您想要的两个函数进行编码,然后在创建每个实例时分配所需的变体。
class MyClass():
def __init__(self, constant, function):
self.constant=constant
if function == 'A':
self.forward = self.forwardA
elif function=='B':
self.forward = self.forwardB
else:
print('please provide the correct function')
def forwardA(self, x1):
return x1+self.constant
def forwardB(self, x1, x2):
return x1*x2+self.constant
# Main
model1 = MyClass(2, 'A')
print(model1.forward(2))
model2 = MyClass(2, 'B')
print(model2.forward(2, 2))
输出:
4
6
您还可以尝试分解出一个基类。 它可能会与mypy一起玩得更好,并且可以更轻松地不混淆您正在使用的任何类。
class MyClassBase():
def __init__(self, constant):
self.constant=constant
def forward(self, *args, **kwargs):
raise NotImplementedError('use a derived class')
class MyClassA(MyClassBase):
def __init__(self, constant):
super().__init__(constant)
def forward(self,x1):
return x1 + self.constant
class MyClassB(MyClassBase):
def __init__(self, constant):
super().__init__(constant)
def forward(self, x1, x2):
return x1*x2 + self.constant
a = MyClassA(2)
b = MyClassB(2)
print(a.forward(2))
print(b.forward(2,2))
我想说的是,如果您要在类方法中跨返回使用常量变量名称,我们可以通过这种方式定义它来调整逻辑。 (我们可以用不同的方式来做到这一点,只使用 Params)希望这看起来不错。
class MyClass():
def __init__(self, constant, function):
self.constant=constant
self.function=function
def forward(self, x1 = None, x2 = None):
if self.function=='A':
return x1+self.constant
elif self.function=='B':
return x1*x2+self.constant
else:
print('please provide the correct function')
model1 = MyClass(2, 'A')
model1.forward(2)
model2 = MyClass(2, 'B')
model2.forward(2, 2)
在这种情况下,我也会去继承; 给定你想要的:
model2 = MyClass(2, 'B')
这样做会更容易(并且对其他人更具可读性):
model2 = MyClassB(2)
鉴于此,类似于@Nathan Chappell 在他的回答中提供的内容,但较短(例如,无需重新定义__init__
):
import torch
class Base(torch.nn.Module):
def __init__(self, constant):
super().__init__()
self.constant = constant
class MyClassA(Base):
def forward(self, x1):
return x1 + self.constant
class MyClassB(Base):
def forward(self, x1, x2):
return x1 * x2 + self.constant
您应该使用torch.nn.Module
的__call__
方法而不是forward
因为它可以与钩子一起正常工作(请参阅此答案),因此它应该是:
model1 = MyClassA(5)
model1(torch.randn(10, 5))
代替:
model1 = MyClassA(5)
model1.forward(torch.randn(10, 5))
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.