[英]How to overwrite attributes in OOP inheritance with Python dataclass
目前,我有一些代码看起来像这样,删除了不相关的方法。
import math
import numpy as np
from decimal import Decimal
from dataclasses import dataclass, field
from typing import Optional, List
@dataclass
class A:
S0: int
K: int
r: float = 0.05
T: int = 1
N: int = 2
StockTrees: List[float] = field(init=False, default_factory=list)
pu: Optional[float] = 0
pd: Optional[float] = 0
div: Optional[float] = 0
sigma: Optional[float] = 0
is_put: Optional[bool] = field(default=False)
is_american: Optional[bool] = field(default=False)
is_call: Optional[bool] = field(init=False)
is_european: Optional[bool] = field(init=False)
def __post_init__(self):
self.is_call = not self.is_put
self.is_european = not self.is_american
@property
def dt(self):
return self.T/float(self.N)
@property
def df(self):
return math.exp(-(self.r - self.div) * self.dt)
@dataclass
class B(A):
u: float = field(init=False)
d: float = field(init=False)
qu: float = field(init=False)
qd: float = field(init=False)
def __post_init__(self):
super().__post_init__()
self.u = 1 + self.pu
self.d = 1 - self.pd
self.qu = (math.exp((self.r - self.div) * self.dt) - self.d)/(self.u - self.d)
self.qd = 1 - self.qu
@dataclass
class C(B):
def __post_init__(self):
super().__post_init__()
self.u = math.exp(self.sigma * math.sqrt(self.dt))
self.d = 1/self.u
self.qu = (math.exp((self.r - self.div)*self.dt) - self.d)/(self.u - self.d)
self.qd = 1 - self.qu
基本上,我有一个 class A
,它定义了一些其所有子类将共享的属性,因此它只是真正意味着通过其子类的实例化进行初始化,其属性将由其子类继承。 子进程 class B
是一个执行某些计算的进程,该进程由C
继承,后者执行相同计算的变体。 C
基本上继承了B
的所有方法,唯一不同的是self.u
和self.d
的计算不同。
可以通过使用需要 arguments pu
和pd
的B
计算或需要参数sigma
的C
计算来运行代码,如下所示
if __name__ == "__main__":
am_option = B(50, 52, r=0.05, T=2, N=2, pu=0.2, pd=0.2, is_put=True, is_american=True)
print(f"{am_option.sigma = }")
print(f"{am_option.pu = }")
print(f"{am_option.pd = }")
print(f"{am_option.qu = }")
print(f"{am_option.qd = }")
eu_option2 = C(50, 52, r=0.05, T=2, N=2, sigma=0.3, is_put=True)
print(f"{am_option.sigma = }")
print(f"{am_option.pu = }")
print(f"{am_option.pd = }")
print(f"{am_option.qu = }")
print(f"{am_option.qd = }")
这给出了 output
am_option.pu = 0.2
am_option.pd = 0.2
am_option.qu = 0.6281777409400603
am_option.qd = 0.3718222590599397
Traceback (most recent call last):
File "/home/dazza/option_pricer/test.py", line 136, in <module>
eu_option2 = C(50, 52, r=0.05, T=2, N=2, sigma=0.3, is_put=True)
File "<string>", line 15, in __init__
File "/home/dazza/option_pricer/test.py", line 109, in __post_init__
super().__post_init__()
File "/home/dazza/option_pricer/test.py", line 55, in __post_init__
self.qu = (math.exp((self.r - self.div) * self.dt) - self.d)/(self.u - self.d)
ZeroDivisionError: float division by zero
因此实例化B
工作正常,因为它成功计算了值pu
、 pd
、 qu
和qd
。 但是,当C
的实例化无法计算qu
时,我的问题就来了,因为默认情况下pu
和pd
为零,使其除以 0。
我的问题:如何解决这个问题,以便C
继承 A 的所有属性初始化(包括__post_init__
)和A
所有方法,同时计算B
self.u = math.exp(self.sigma * math.sqrt(self.dt))
和self.d = 1/self.u
覆盖self.u = 1 + self.pu
和self.d = 1 - self.pd
B
的 self.pd,同时保留self.qu
和self.qd
相同吗?( B
和C
相同)
定义另一个方法来初始化u
和d
,这样您就可以覆盖B
的那部分而不用覆盖qu
和qd
的定义方式。
@dataclass
class B(A):
u: float = field(init=False)
d: float = field(init=False)
qu: float = field(init=False)
qd: float = field(init=False)
def __post_init__(self):
super().__post_init__()
self._define_u_and_d()
self.qu = (math.exp((self.r - self.div) * self.dt) - self.d)/(self.u - self.d)
self.qd = 1 - self.qu
def _define_u_and_d(self):
self.u = 1 + self.pu
self.d = 1 - self.pd
@dataclass
class C(B):
def _define_u_and_d(self):
self.u = math.exp(self.sigma * math.sqrt(self.dt))
self.d = 1/self.u
Python 支持多个 inheritance。您可以先从A
继承B
,这意味着任何重叠的方法都将从A
中获取(例如__post_init__
)。 您在 class C
中编写的任何代码都将覆盖从A
和B
继承的内容。 如果您需要更好地控制哪些方法来自哪个 class,您始终可以在C
中定义该方法并调用 function 调用A
或B
(如A.dt(self)
)。
class C(A, B):
...
另一个编辑:我刚刚看到A
在C
中初始化了一些你想要的东西。 因为C
的父级现在是A
(如果您使用我上面的代码),您可以在super().__post_init__()
行中添加回C
的__post_init__
以便它调用A
的__post_init__
。 如果这不起作用,您总是可以将A.__post_init__(self)
放在__post_init__
的C
中。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.