How can I functionally update a dataclass?
I'd like to copy an instance of a frozen dataclass, changing just one field ("functional update").
Here's what I tried:
from dataclasses import dataclass, asdict

@dataclass(frozen=True)
class Pos:
    start: int
    end: int

def adjust_start(pos: Pos, delta: int) -> Pos:
    # TypeError: type object got multiple values for keyword argument 'start'
    return Pos(**asdict(pos), start=pos.start + delta)

adjust_start(Pos(1, 2), 4)
What I'm looking for:
Is there a more direct way than converting to/from dicts?
How can I get around the TypeError? If there is a way to functionally update kwargs, then that could work.

In Scala, a functional update of a case class (Scala's counterpart to a dataclass) can be done like this: pos.copy(start = pos.start + delta).
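The TypeError arises because **asdict(pos) already supplies a start keyword, so passing start= explicitly duplicates it. A minimal sketch of "functionally updating the kwargs" is to merge the dicts first, so the later key wins. This assumes a flat dataclass like Pos: asdict() recursively converts nested dataclasses into plain dicts, which would not round-trip through the constructor. The helper name adjust_start_merged is hypothetical.

```python
from dataclasses import dataclass, asdict


@dataclass(frozen=True)
class Pos:
    start: int
    end: int


def adjust_start_merged(pos: Pos, delta: int) -> Pos:
    # Build a single kwargs dict: the override for "start" comes last, so it
    # replaces the value from asdict() instead of colliding with it.
    kwargs = {**asdict(pos), "start": pos.start + delta}
    return Pos(**kwargs)


print(adjust_start_merged(Pos(1, 2), 4))  # Pos(start=5, end=2)
```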
dataclasses.replace() to the rescue.

dataclasses.replace(obj, /, **changes) creates a new object of the same type as obj, replacing fields with values from changes.
import dataclasses

@dataclasses.dataclass(frozen=True)
class Pos:
    start: int
    end: int

def adjust_start(pos: Pos, delta: int) -> Pos:
    return dataclasses.replace(pos, start=pos.start + delta)

p = adjust_start(Pos(1, 2), 4)
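A short sketch of what replace() guarantees, using the same Pos class: the original frozen instance is left untouched, and keyword names that are not fields of the dataclass raise TypeError rather than being silently ignored (the name middle below is made up purely to trigger that error).

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Pos:
    start: int
    end: int


p = Pos(1, 2)
q = dataclasses.replace(p, start=5)

# replace() returns a new instance; the original is unchanged.
print(p)  # Pos(start=1, end=2)
print(q)  # Pos(start=5, end=2)

# Direct assignment is still rejected on a frozen dataclass.
try:
    p.start = 5
except dataclasses.FrozenInstanceError as exc:
    print("rejected:", type(exc).__name__)

# A keyword that is not a field makes the constructor call fail.
try:
    dataclasses.replace(p, middle=3)
except TypeError as exc:
    print("TypeError:", exc)
```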
Personally, I might put adjust on the dataclass itself:
import dataclasses

@dataclasses.dataclass(frozen=True)
class Pos:
    start: int
    end: int

    def adjust(self, *, start: int = 0, end: int = 0) -> "Pos":
        # Both offsets default to 0, so callers can shift just one field.
        return dataclasses.replace(
            self,
            start=self.start + start,
            end=self.end + end,
        )

p = Pos(1, 2).adjust(start=4)
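Since the question mentions Scala's pos.copy(start = pos.start + delta), another option is a thin wrapper method over dataclasses.replace() that accepts any subset of fields by keyword. This is only a sketch; the method name copy is an illustrative choice, not part of the dataclasses module.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Pos:
    start: int
    end: int

    # Hypothetical helper mirroring Scala's case-class copy(): fields passed
    # by keyword are overridden, all other fields are carried over.
    def copy(self, **changes) -> "Pos":
        return dataclasses.replace(self, **changes)


p = Pos(1, 2)
print(p.copy(start=p.start + 4))  # Pos(start=5, end=2)
print(p.copy())                   # Pos(start=1, end=2)
```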
You can also use a @cached_class_property approach with dataclasses.fields().
For example:
from dataclasses import fields, dataclass

class cached_class_property(object):
    """
    Descriptor decorator implementing a class-level, read-only property,
    which caches the attribute on-demand on the first use.

    Credits: https://stackoverflow.com/a/4037979/10237506
    """
    def __init__(self, func):
        self.__func__ = func
        self.__attr_name__ = func.__name__

    def __get__(self, instance, cls=None):
        """This method is only called the first time, to cache the value."""
        if cls is None:
            cls = type(instance)

        # Build the attribute.
        attr = self.__func__(cls)

        # Cache the value; hide ourselves.
        setattr(cls, self.__attr_name__, attr)

        return attr


@dataclass(frozen=True)
class Pos:
    start: int
    end: int

    @cached_class_property
    def init_fields(cls):
        return tuple(f.name for f in fields(cls) if f.init)

    def adjust_start(self, delta: int) -> 'Pos':
        attrs = [getattr(self, f) + delta if f == 'start' else getattr(self, f)
                 for f in Pos.init_fields]
        return Pos(*attrs)


p1 = Pos(1, 2)
print(p1)

p2 = Pos(1, 2).adjust_start(4)
print(p2)
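To see what the descriptor actually does, the sketch below (repeating a trimmed copy of the definitions above so it runs on its own) inspects Pos.__dict__ before and after the first access. The first lookup calls fields() and then overwrites the descriptor on the class with the resulting tuple, so later lookups are ordinary class-attribute reads.

```python
from dataclasses import dataclass, fields


class cached_class_property:
    """Class-level, read-only property computed once on first access."""

    def __init__(self, func):
        self.__func__ = func
        self.__attr_name__ = func.__name__

    def __get__(self, instance, cls=None):
        if cls is None:
            cls = type(instance)
        attr = self.__func__(cls)
        # Replace this descriptor on the class with the computed value.
        setattr(cls, self.__attr_name__, attr)
        return attr


@dataclass(frozen=True)
class Pos:
    start: int
    end: int

    @cached_class_property
    def init_fields(cls):
        return tuple(f.name for f in fields(cls) if f.init)


# Reading Pos.__dict__ directly does not trigger the descriptor.
print(type(Pos.__dict__["init_fields"]).__name__)  # cached_class_property
print(Pos.init_fields)                             # ('start', 'end')
# After the first access, the class holds the cached tuple instead.
print(type(Pos.__dict__["init_fields"]).__name__)  # tuple
```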
As you are using a frozen=True dataclass with slots=False, you could also simplify this approach, i.e. without the use of @cached_class_property. For Pos, each instance's __dict__ holds exactly start and end, in declaration order, so the values can be passed back to the constructor positionally:
def adjust_start(self, delta: int) -> 'Pos':
    _dict = self.__dict__.copy()
    _dict['start'] += delta
    return Pos(*_dict.values())
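A small variation on the same idea, offered as a sketch rather than a benchmarked alternative: passing the copied dict by keyword and constructing via type(self) avoids depending on the order of __dict__ and on the hard-coded class name. It still assumes every key in the instance __dict__ is an __init__ parameter, which holds for Pos but not necessarily for dataclasses with init=False fields.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Pos:
    start: int
    end: int

    def adjust_start(self, delta: int) -> "Pos":
        # frozen=True blocks assignment on the instance, but this copy is a
        # separate, ordinary dict that can be modified freely.
        _dict = self.__dict__.copy()
        _dict["start"] += delta
        # Keyword arguments do not rely on __dict__ ordering, and type(self)
        # lets subclasses get back an instance of their own type.
        return type(self)(**_dict)


print(Pos(1, 2).adjust_start(4))  # Pos(start=5, end=2)
```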
Output:
Pos(start=1, end=2)
Pos(start=5, end=2)
Results show it's a bit faster than dataclasses.replace() (total seconds for timeit's default of 1,000,000 calls each):
from dataclasses import fields, dataclass, replace
from timeit import timeit

@dataclass(frozen=True)
class Pos:
    start: int
    end: int

    # `cached_class_property` defined from above
    @cached_class_property
    def init_fields(cls):
        return tuple(f.name for f in fields(cls) if f.init)

    def adjust_via_copy(self, delta: int) -> 'Pos':
        _dict = self.__dict__.copy()
        _dict['start'] += delta
        return Pos(*_dict.values())

    def adjust_via_fields(self, delta: int) -> 'Pos':
        attrs = [getattr(self, f) + delta if f == 'start' else getattr(self, f)
                 for f in Pos.init_fields]
        return Pos(*attrs)

    def adjust_via_replace(self, delta: int) -> 'Pos':
        return replace(
            self,
            start=self.start + delta,
        )


p = Pos(1, 2)

print('o.__dict__.copy: ', round(timeit('p.adjust_via_copy(4)', globals=globals()), 3))
print('dataclasses.fields: ', round(timeit('p.adjust_via_fields(4)', globals=globals()), 3))
print('dataclasses.replace: ', round(timeit('p.adjust_via_replace(4)', globals=globals()), 3))

# Sanity check: all three approaches produce the same result.
assert Pos(-2, 2) == p.adjust_via_copy(-3) == p.adjust_via_fields(-3) == p.adjust_via_replace(-3)
o.__dict__.copy: 0.408
dataclasses.fields: 0.499
dataclasses.replace: 0.659
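A single timeit() run like the one above can be noisy, and the ratios may differ across Python versions and machines. A sketch of a somewhat more robust measurement, assuming the same Pos class: take the minimum over several repeats and report the per-call time. The helper names via_copy and via_replace are illustrative only.

```python
import dataclasses
import timeit


@dataclasses.dataclass(frozen=True)
class Pos:
    start: int
    end: int


p = Pos(1, 2)
N = 100_000  # calls per timing batch


def via_copy() -> Pos:
    d = p.__dict__.copy()
    d["start"] += 4
    return Pos(*d.values())


def via_replace() -> Pos:
    return dataclasses.replace(p, start=p.start + 4)


# repeat() times N calls, 5 separate times; the minimum batch time is the
# estimate least affected by other activity on the machine.
results = {}
for name, fn in [("__dict__.copy", via_copy), ("replace", via_replace)]:
    best = min(timeit.repeat(fn, number=N, repeat=5))
    results[name] = best / N  # seconds per call
    print(f"{name:>14}: {results[name] * 1e9:.0f} ns per call")
```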