I have the following dataclass Gear, and I want to limit gear_level to values from 0 to 5. But as you can see, when I increment gear_level it goes higher than 5, which is not what I want. I tried a setter method as well as __post_init__. How do I fix this problem?
from dataclasses import dataclass
@dataclass
class Gear:
    """Gear state whose gear_level is intended to stay within [0, 5].

    The constructor (through __post_init__) and set_gear_level() clamp the
    value. Assigning the field directly -- including ``g.gear_level += 1`` --
    goes through neither of them and is therefore NOT clamped.
    """

    gear_level: int = 0
    direction: str = None  # NOTE(review): None default, so effectively Optional[str]
    # more codes ...

    def __post_init__(self):
        """Clamp the constructor-supplied gear_level into [0, 5]."""
        # Must be spelled exactly ``__post_init__``: @dataclass only calls a
        # hook with this name after its generated __init__ has run.
        self.set_gear_level(self.gear_level)

    def set_gear_level(self, level):
        """Store ``level`` clamped into the closed range [0, 5].

        Values <= 0 become 0, values >= 5 become 5, anything in between is
        stored unchanged.
        """
        if level <= 0:
            self.gear_level = 0
        elif level < 5:
            self.gear_level = level
        else:
            self.gear_level = 5
# Demonstrate the problem: set_gear_level() clamps the value, but plain
# attribute arithmetic (+= / -=) assigns the field directly and is not clamped.
g = Gear()
g.set_gear_level(6)  # clamped to the upper bound
print(g)  # Gear(gear_level=5, direction=None)
g.gear_level += 1  # direct assignment, bypasses set_gear_level
print(g)  # Gear(gear_level=6, direction=None)
g.set_gear_level(-1)  # clamped to the lower bound
print(g)  # Gear(gear_level=0, direction=None)
g.gear_level -= 1  # direct assignment again, no clamping
print(g)  # Gear(gear_level=-1, direction=None)
Ideally, I'd like to use the g.gear_level += 1 notation, because I want to increment gear_level. It should not jump from gear level 1 to 5. Likewise, when it is decremented, it should stop at 0: it should accept an assignment of 0 and allow decrementing down to 0, but never go below it. Can this be done? The code above prints:
Gear(gear_level=5, direction=None)
Gear(gear_level=6, direction=None)
Gear(gear_level=0, direction=None)
Gear(gear_level=-1, direction=None)
The suggested link in the comments provides an elegant solution to this issue, e.g. using a custom descriptor class, which should work with minimal changes on your end.
For example, here's how I'd define a BoundsValidator descriptor class to check that a class attribute is within expected lower and upper bounds (note that either bound is optional in this case):
from typing import Optional
try:
from typing import get_args
except ImportError: # Python 3.7
from typing_extensions import get_args
class BoundsValidator:
    """Descriptor to validate an attribute x remains within a specified bounds.

    That is, checks the constraint `low <= x <= high` is satisfied (both ends
    inclusive). Both bounds are optional: pass ``None`` for ``min_val`` and/or
    ``max_val`` to leave that side unbounded. If neither is provided, no
    bounds will be applied.

    The owner's annotation for the attribute (for example ``int`` or
    ``Union[int, float]``) is also used to type-check assigned values. When
    the attribute has no annotation, or only a string annotation (e.g. under
    ``from __future__ import annotations``), type checking is skipped.

    Note: when an instance is used as a ``@dataclass`` field default, the
    dataclass machinery calls ``__get__(None, owner)`` to find the default,
    which returns this descriptor itself. Such fields should therefore always
    be passed explicitly to the generated ``__init__``.
    """
    __slots__ = ('name',
                 'type',
                 'validator')

    def __init__(self, min_val: Optional[float] = None,
                 max_val: Optional[float] = float('inf')):
        """Build the bounds check once, so ``__set__`` only has to call it.

        Args:
            min_val: inclusive lower bound, or None for no lower bound.
            max_val: inclusive upper bound, or None for no upper bound.
                Defaults to +inf, which also means "no upper bound".
        """
        if min_val is None and max_val is None:  # no bounds at all
            def validator(name, val):
                return None
        elif max_val is None:  # only minimum
            def validator(name, val):
                if val < min_val:
                    raise ValueError(f"values for {name!r} have to be >= {min_val}; got {val!r}")
        elif min_val is None:  # only maximum
            def validator(name, val):
                if val > max_val:
                    raise ValueError(f"values for {name!r} have to be <= {max_val}; got {val!r}")
        else:  # both upper and lower bounds are given
            def validator(name, val):
                if not min_val <= val <= max_val:
                    raise ValueError(f"values for {name!r} have to be within the range "
                                     f"[{min_val}, {max_val}]; got {val!r}")
        self.validator = validator

    def __set_name__(self, owner, name):
        """Record the attribute name and the types allowed by its annotation."""
        # save the attribute name on an initial run
        self.name = name
        # set the valid types based on the annotation for the attribute,
        # for example `int` or `Union[int, float]`
        tp = getattr(owner, '__annotations__', {}).get(name)
        if tp is None or isinstance(tp, str):
            # Missing or string annotations can't be passed to isinstance();
            # accept any type instead of failing at class creation / on set.
            self.type = object
        else:
            self.type = get_args(tp) or tp

    def __get__(self, instance, owner):
        """Return the stored value, or the descriptor itself on class access."""
        # Compare against None explicitly: a falsy instance (e.g. one whose
        # __len__ returns 0) must still read its own stored value.
        if instance is None:
            return self
        return instance.__dict__[self.name]

    def __delete__(self, instance):
        """Remove the stored value from the instance."""
        del instance.__dict__[self.name]

    def __set__(self, instance, value):
        """Type-check and bounds-check ``value`` before storing it.

        Raises:
            TypeError: if ``value`` is not an instance of the annotated type(s).
            ValueError: if ``value`` lies outside the configured bounds.
        """
        # can be removed if you don't need the type validation
        if not isinstance(value, self.type):
            raise TypeError(f"{self.name!r} values must be of type {self.type!r}")
        # validate that the value is within expected bounds
        self.validator(self.name, value)
        # finally, set the value on the instance (only after both checks pass)
        instance.__dict__[self.name] = value
Finally, here's the sample code I came up with to test that it's working as we'd expect:
from dataclasses import dataclass
from typing import Union
@dataclass
class Person:
    """Sample dataclass whose fields are all guarded by BoundsValidator.

    Every assignment, including the ones made by the generated __init__, is
    checked against the bounds given here and against the field's annotation.
    NOTE(review): the descriptor objects are what @dataclass sees as the field
    defaults, so omitting an argument would pass the descriptor itself to
    __set__ and fail its type check -- always pass all three values.
    """
    age: int = BoundsValidator(1)  # let's assume a person must be at least 1 year old (no upper bound)
    num: Union[int, float] = BoundsValidator(-1, 1)  # -1 <= num <= 1; int or float accepted
    gear_level: int = BoundsValidator(0, 5)  # 0 <= gear_level <= 5
def main():
    """Show BoundsValidator rejecting out-of-range values on Person.

    Each failing check prints its ValueError message instead of aborting,
    so all of them run in order.
    """
    person = Person(10, 0.7, 5)
    print(person)

    def step_past_upper_bound():
        # 5 + 1 is above gear_level's upper bound of 5
        person.gear_level += 1

    def step_past_lower_bound():
        # the previous step was rejected, so this is 5 - 7 = -2
        person.gear_level -= 7

    failing_checks = [
        step_past_upper_bound,
        step_past_lower_bound,
        lambda: Person(0, 0, 2),       # age is below its minimum of 1
        lambda: Person(120, -3.1, 2),  # num is below its minimum of -1
    ]
    for check in failing_checks:
        try:
            check()
        except ValueError as err:
            print(err)


if __name__ == '__main__':
    main()
This provides the output below when we run the code:
Person(age=10, num=0.7, gear_level=5)
values for 'gear_level' have to be within the range [0, 5]; got 6
values for 'gear_level' have to be within the range [0, 5]; got -2
values for 'age' have to be within the range [1, inf]; got 0
values for 'num' have to be within the range [-1, 1]; got -3.1
In this case I would simply use a property:
@dataclass
class Gear:
    """Gear whose gear_level is always clamped into [MIN_LEVEL, MAX_LEVEL].

    The clamping lives in a property setter, so it applies to the generated
    __init__, to plain assignment and to augmented assignment such as
    ``g.gear_level += 1``.
    """

    gear_level: int
    # Rest of the class excluded for simplicity

    # Unannotated class attributes, so @dataclass does not turn them into fields.
    MIN_LEVEL = 0
    MAX_LEVEL = 5

    @property
    def gear_level(self) -> int:
        """The current gear, always within [MIN_LEVEL, MAX_LEVEL]."""
        return self._gear_level

    @gear_level.setter
    def gear_level(self, value: int) -> None:
        """Store ``value`` clamped into [MIN_LEVEL, MAX_LEVEL]."""
        if isinstance(value, property):
            # This property replaces the class attribute ``gear_level``, and
            # @dataclass uses that class attribute as the field default. So
            # ``Gear()`` passes the property object itself: default to gear 0.
            value = 0
        self._gear_level = min(max(value, self.MIN_LEVEL), self.MAX_LEVEL)
This way you don't need to write a __post_init__ or remember to call specific methods: every assignment to gear_level, even with +=, keeps it within 0 <= gear_level <= 5.
The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.