简体   繁体   English

数据类的子类,带有一些断言

[英]Subclass of dataclass, with some assertions

I have a frozen dataclass MyData that holds data.我有一个保存数据的冻结数据类MyData I would like a distinguished subclass MySpecialData can only hold data of length 1. Here is a working implementation.我想要一个杰出的子类MySpecialData只能保存长度为 1 的数据。这是一个有效的实现。

from dataclasses import dataclass, field


@dataclass(frozen=True)
class MyData:
    id: int = field()
    data: list[float] = field()

    def __len__(self) -> int:
        return len(self.data)


@dataclass(frozen=True)
class MySpecialData(MyData):
    def __post_init__(self):
        assert len(self) == 1


# correctly throws exception
special_data = MySpecialData(id=1, data=[2, 3])

I spent some time messing with __new__ and __init__ , but couldn't reach a working solution.我花了一些时间弄乱__new____init__ ,但无法找到可行的解决方案。 The code works, but I am a novice and am soliciting the opinion of someone experienced if this is the "right" way to accomplish this.该代码有效,但我是一个新手,如果这是完成此任务的“正确”方式,我正在征求有经验的人的意见。 Any critiques or suggestions on how to do this better or more correctly would be appreciated.任何关于如何更好或更正确地做到这一点的批评或建议将不胜感激。

For examples not using dataclasses , I imagine the correct way would be overriding __new__ in the subclass.对于不使用dataclasses的示例,我想正确的方法是在子类中覆盖__new__ I suspect my attempts at overriding __new__ fail here because of the special way dataclasses works.我怀疑由于dataclasses的特殊工作方式,我尝试覆盖__new__在这里失败了。 Would you agree?你会同意吗?

Thank you for your opinion.谢谢您的意见。

Don't use assert .不要使用assert Use利用

if len(self) != 1:
    raise ValueError

assert can be turned off with the -O switch ie., if you run your script like可以使用-O开关关闭assert ,即,如果您像这样运行脚本

python -O my_script.py

it will no longer raise an error.它不会再引发错误。

Another option is to use a custom user-defined list subclass, which checks the len of the list upon instantiation.另一种选择是使用自定义的用户定义list子类,它会在实例化时检查列表的len

from dataclasses import dataclass, field
from typing import Sequence, TypeVar, Generic

T = TypeVar('T')


class ConstrainedList(list, Generic[T]):

    def __init__(self, seq: Sequence[T] = (), desired_len: int = 1):
        super().__init__(seq)

        if len(self) != desired_len:
            raise ValueError(f'expected length {desired_len}, got {len(self)}. items={self}')


@dataclass(frozen=True)
class MyData:
    id: int = field()
    data: ConstrainedList[float] = field(default_factory=ConstrainedList)


@dataclass(frozen=True)
class MySpecialData(MyData):
    ...


# correctly throws exception
special_data = MySpecialData(id=1, data=ConstrainedList([2, 3]))

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM