简体   繁体   中英

How to limit the values of a dataclass attribute?

I have the following dataclass `Gear`, and I want to restrict `gear_level` to the range 0 to 5. But as you can see, when I increment `gear_level` it goes higher than 5, which is not what I want. I tried a setter method as well as `__postinit__`. How do I fix this problem?

from dataclasses import dataclass

@dataclass
class Gear:
    """Gear state whose ``gear_level`` should stay within 0..5.

    NOTE: clamping only happens in ``__post_init__`` (once, at construction)
    and in ``set_gear_level``. Plain attribute assignment such as
    ``g.gear_level += 1`` bypasses both, which is why the demo below still
    reaches 6 and -1.
    """
    gear_level: int = 0
    direction: str = None
    # more codes ...

    def __post_init__(self):
        # dataclasses only call a hook spelled exactly ``__post_init__``; a
        # method named ``__postinit__`` is never invoked, which would leave
        # out-of-range constructor arguments stored unchanged.
        self.set_gear_level(self.gear_level)

    def set_gear_level(self, level):
        """Store ``level`` clamped to the inclusive range [0, 5]."""
        if level <= 0:
            self.gear_level = 0
        elif level < 5:
            self.gear_level = level
        else:
            self.gear_level = 5


g = Gear()

g.set_gear_level(6)     # clamped down to 5 by set_gear_level

print(g)

g.gear_level += 1       # direct assignment is not clamped -> 6

print(g)

g.set_gear_level(-1)    # clamped up to 0 by set_gear_level

print(g)

g.gear_level -= 1       # direct assignment is not clamped -> -1
print(g)

Ideally, I'd like to use the `g.gear_level += 1` notation, because I want to increment `gear_level`; it should not jump from gear level 1 to 5. Likewise, when it is decremented, it should stop at 0. It should accept an assignment of 0 and also be allowed to decrement down to 0. Can this be done?

Gear(gear_level=5, direction=None)
Gear(gear_level=6, direction=None)
Gear(gear_level=0, direction=None)
Gear(gear_level=-1, direction=None)

The link suggested in the comments provides an elegant solution to this issue, e.g. a custom descriptor class, which should work with minimal changes on your end.

For example, here's how I'd define a `BoundsValidator` descriptor class to check that a class attribute stays within expected lower and upper bounds (note that either bound is optional in this case):

from typing import Optional

try:
    from typing import get_args
except ImportError:  # Python 3.7
    from typing_extensions import get_args


class BoundsValidator:
    """Descriptor to validate an attribute x remains within a specified bounds.

    That is, checks the constraint `low <= x <= high` is satisfied (both ends
    inclusive). Note that both low and high are optional: pass ``None`` to
    drop a bound. If neither is provided, no bounds will be applied.

    The allowed type(s) are taken from the owner class's annotation for the
    attribute, for example ``int`` or ``Union[int, float]``.

    NOTE: accessing the attribute on the class returns the descriptor itself,
    and per the dataclasses docs @dataclass uses that as the field's default.
    Constructing such a dataclass without a value for the field therefore
    passes the descriptor to ``__set__``, which normally fails with TypeError.
    """
    __slots__ = ('name',
                 'type',
                 'validator')

    def __init__(self, min_val: Optional[float] = None,
                 max_val: Optional[float] = float('inf')):
        """Build the bounds check applied on every assignment.

        Args:
            min_val: inclusive lower bound, or None for no lower bound.
            max_val: inclusive upper bound, or None for no upper bound.
                Defaults to ``inf``, i.e. effectively unbounded above.
        """
        if min_val is None and max_val is None:   # no bounds at all
            def validator(name, val):
                # Comparing against None would raise TypeError, so every
                # value is accepted.
                return None

        elif max_val is None:     # only minimum
            def validator(name, val):
                if val < min_val:
                    raise ValueError(f"values for {name!r}  have to be >= {min_val}; got {val!r}")

        elif min_val is None:   # only maximum
            def validator(name, val):
                if val > max_val:
                    raise ValueError(f"values for {name!r}  have to be <= {max_val}; got {val!r}")

        else:                   # both upper and lower bounds are given
            def validator(name, val):
                if not min_val <= val <= max_val:
                    raise ValueError(f"values for {name!r}  have to be within the range "
                                     f"[{min_val}, {max_val}]; got {val!r}")

        self.validator = validator

    def __set_name__(self, owner, name):
        """Record the attribute name and the type(s) allowed by its annotation."""
        # save the attribute name on an initial run
        self.name = name

        # set the valid types based on the annotation for the attribute
        #   for example, `int` or `Union[int, float]`
        tp = getattr(owner, '__annotations__', {}).get(name)
        if tp is None or isinstance(tp, str):
            # No annotation, or a postponed string annotation (e.g. under
            # `from __future__ import annotations`) that isinstance() cannot
            # use: skip the type check rather than fail on every assignment.
            tp = object
        self.type = get_args(tp) or tp

    def __get__(self, instance, owner):
        # Compare against None explicitly: a falsy instance (e.g. one whose
        # __len__ returns 0) must still get its stored value.
        if instance is None:
            return self
        try:
            return instance.__dict__[self.name]
        except KeyError:
            # Report a missing value the way Python does for normal
            # attributes so hasattr()/getattr(..., default) keep working.
            raise AttributeError(
                f"{type(instance).__name__!r} object has no attribute {self.name!r}"
            ) from None

    def __delete__(self, instance):
        try:
            del instance.__dict__[self.name]
        except KeyError:
            raise AttributeError(self.name) from None

    def __set__(self, instance, value):
        """Type-check and bounds-check ``value`` before storing it.

        Raises:
            TypeError: if ``value`` is not an instance of the annotated type(s).
            ValueError: if ``value`` falls outside the configured bounds.
        """
        # can be removed if you don't need the type validation
        if not isinstance(value, self.type):
            raise TypeError(f"{self.name!r} values must be of type {self.type!r}")

        # validate that the value is within expected bounds
        self.validator(self.name, value)

        # finally, set the value on the instance
        instance.__dict__[self.name] = value

Finally, here's the sample code I came up with to test that it's working as we'd expect:

from dataclasses import dataclass
from typing import Union


@dataclass
class Person:
    # Every field is guarded by a BoundsValidator descriptor: assigning a value
    # outside its bounds raises ValueError, both in __init__ and afterwards.
    age: int = BoundsValidator(min_val=1)   # at least 1 year old; no upper limit
    num: Union[int, float] = BoundsValidator(min_val=-1, max_val=1)
    gear_level: int = BoundsValidator(min_val=0, max_val=5)


def main():
    """Show BoundsValidator rejecting out-of-range values on a Person."""
    person = Person(10, 0.7, 5)
    print(person)

    def shift_gear(delta):
        person.gear_level += delta

    # Every attempt below violates a bound, so each one prints a ValueError
    # message: one step past the upper gear limit, then 5 - 7 = -2 below the
    # lower gear limit, then an age of 0, then a num of -3.1.
    failing_attempts = (
        lambda: shift_gear(1),
        lambda: shift_gear(-7),
        lambda: Person(0, 0, 2),
        lambda: Person(120, -3.1, 2),
    )
    for attempt in failing_attempts:
        try:
            attempt()
        except ValueError as err:
            print(err)


if __name__ == '__main__':
    main()

This provides the output below when we run the code:

Person(age=10, num=0.7, gear_level=5)
values for 'gear_level'  have to be within the range [0, 5]; got 6
values for 'gear_level'  have to be within the range [0, 5]; got -2
values for 'age'  have to be within the range [1, inf]; got 0
values for 'num'  have to be within the range [-1, 1]; got -3.1

In this case I would simply use a property:

@dataclass
class Gear:
    """Gear whose ``gear_level`` is clamped to 0..5 on every assignment."""
    gear_level: int

    # Rest of the class excluded for simplicity

    @property
    def gear_level(self) -> int:
        """Current gear, always within the inclusive range 0..5."""
        return self._gear_level

    @gear_level.setter
    def gear_level(self, value: int) -> None:
        # Redefining ``gear_level`` as a property replaces the class attribute,
        # so @dataclass sees the property object itself as the field default.
        # ``Gear()`` then passes that object here; treat it as "no value
        # given" and start at gear 0 instead of crashing inside max().
        if isinstance(value, property):
            value = 0
        # Clamp rather than reject, so ``g.gear_level += 1`` stops at 5 and
        # ``g.gear_level -= 1`` stops at 0.
        self._gear_level = min(max(value, 0), 5)

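This way you don't need to write a `__post_init__` or remember to call specific methods: every assignment to `gear_level` is clamped so that `0 <= gear_level <= 5`, even with `+=` and `-=`. Unlike the descriptor approach above, out-of-range values are clamped to the nearest bound rather than rejected with an error.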

The technical post webpages of this site follow the CC BY-SA 4.0 license. If you need to reprint, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM