How can I functionally update a dataclass?

I'd like to copy an instance of a frozen dataclass, changing just one field (a "functional update").

Here's what I tried:

from dataclasses import dataclass, asdict


@dataclass(frozen=True)
class Pos:
    start: int
    end: int


def adjust_start(pos: Pos, delta: int) -> Pos:
    # TypeError: type object got multiple values for keyword argument 'start'
    return Pos(**asdict(pos), start=pos.start + delta)


adjust_start(Pos(1, 2), 4)

What I'm looking for:

  • Is there a more straightforward way than converting to and from dicts?
  • How can I get around the TypeError? If there is a way to functionally update kwargs, that could work.
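On the second point: one way to keep the kwargs approach, sketched here as an illustration rather than a recommendation, is to merge the change into the dict before unpacking it. In a dict display a later key overrides an earlier one, so `{**kwargs, "start": ...}` never passes `start` twice. The sketch builds a shallow dict from `dataclasses.fields()` instead of `asdict()`, because `asdict()` recursively converts nested dataclasses into plain dicts, which would then be passed back to the constructor in that form.

```python
from dataclasses import dataclass, fields


@dataclass(frozen=True)
class Pos:
    start: int
    end: int


def adjust_start(pos: Pos, delta: int) -> Pos:
    # Shallow snapshot of the constructor arguments (fields with init=False are skipped).
    kwargs = {f.name: getattr(pos, f.name) for f in fields(pos) if f.init}
    # Later keys win in a dict display, so "start" is overridden rather than duplicated.
    return Pos(**{**kwargs, "start": pos.start + delta})


print(adjust_start(Pos(1, 2), 4))  # Pos(start=5, end=2)
```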

In Scala, a functional update of a case class (Scala's counterpart to a dataclass) can be done like this: pos.copy(start = pos.start + delta).

dataclasses.replace() to the rescue.

dataclasses.replace(obj, /, **changes) creates a new object of the same type as obj, replacing fields with values from changes.

import dataclasses


@dataclasses.dataclass(frozen=True)
class Pos:
    start: int
    end: int


def adjust_start(pos: Pos, delta: int) -> Pos:
    return dataclasses.replace(pos, start=pos.start + delta)


p = adjust_start(Pos(1, 2), 4)
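A quick check of how replace() behaves, using the same Pos class: it returns a brand-new instance and leaves the original untouched, it accepts several fields in one call, and a name that is not a field is rejected with a TypeError, since replace() ends up passing the changes to the class's __init__.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Pos:
    start: int
    end: int


p = Pos(1, 2)

q = dataclasses.replace(p, start=5)
print(p, q)  # Pos(start=1, end=2) Pos(start=5, end=2)

# Several fields can be changed in one call.
r = dataclasses.replace(p, start=10, end=20)
print(r)  # Pos(start=10, end=20)

# Unknown names become unexpected keyword arguments to __init__.
try:
    dataclasses.replace(p, middle=3)
except TypeError as exc:
    print("TypeError:", exc)
```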

Personally, I might put an adjust method on the dataclass itself:

import dataclasses


@dataclasses.dataclass(frozen=True)
class Pos:
    start: int
    end: int

    # Deltas default to 0 so callers can shift just one end, e.g. adjust(start=4)
    def adjust(self, *, start: int = 0, end: int = 0) -> "Pos":
        return dataclasses.replace(
            self,
            start=self.start + start,
            end=self.end + end,
        )


p = Pos(1, 2).adjust(start=4)
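If the class grows more fields, listing each one in adjust's signature gets tedious. A possible generalization, assuming every field you shift is numeric, takes arbitrary keyword deltas and feeds them through replace(). The method name shift is hypothetical and not part of the dataclasses API.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Pos:
    start: int
    end: int

    def shift(self, **deltas: int) -> "Pos":
        # Hypothetical helper: add each delta to the current value of the named field.
        # Names that are not attributes raise AttributeError before replace() is called.
        changes = {name: getattr(self, name) + delta for name, delta in deltas.items()}
        return dataclasses.replace(self, **changes)


print(Pos(1, 2).shift(start=4))          # Pos(start=5, end=2)
print(Pos(1, 2).shift(start=1, end=-1))  # Pos(start=2, end=1)
```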

You could also combine a small custom @cached_class_property descriptor with dataclasses.fields().

For example:

from dataclasses import fields, dataclass


class cached_class_property(object):
    """
    Descriptor decorator implementing a class-level, read-only property,
    which caches the attribute on-demand on the first use.

    Credits: https://stackoverflow.com/a/4037979/10237506
    """
    def __init__(self, func):
        self.__func__ = func
        self.__attr_name__ = func.__name__

    def __get__(self, instance, cls=None):
        """This method is only called the first time, to cache the value."""
        if cls is None:
            cls = type(instance)

        # Build the attribute.
        attr = self.__func__(cls)

        # Cache the value; hide ourselves.
        setattr(cls, self.__attr_name__, attr)

        return attr


@dataclass(frozen=True)
class Pos:
    start: int
    end: int

    @cached_class_property
    def init_fields(cls):
        return tuple(f.name for f in fields(cls) if f.init)

    def adjust_start(self, delta: int) -> 'Pos':
        attrs = [getattr(self, f) + delta if f == 'start' else getattr(self, f)
                 for f in Pos.init_fields]

        return Pos(*attrs)


p1 = Pos(1, 2)
print(p1)

p2 = Pos(1, 2).adjust_start(4)
print(p2)

Since you are using a frozen dataclass with slots=False (the default), each instance has a __dict__, so you could also simplify this approach, i.e. without the use of @cached_class_property:

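# Replaces Pos.adjust_start above; goes inside the class body.
def adjust_start(self, delta: int) -> 'Pos':
    _dict = self.__dict__.copy()
    _dict['start'] += delta

    # Positional unpacking relies on __dict__ holding exactly the __init__ fields, in order
    return Pos(*_dict.values())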

Output:

Pos(start=1, end=2)
Pos(start=5, end=2)
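One caveat with the __dict__ shortcut: it passes every instance attribute positionally, so it assumes the instance dict holds exactly the __init__ arguments, in order. A sketch of where that breaks, assuming a hypothetical derived field length declared with field(init=False) and filled in by __post_init__: replace() handles it and recomputes length, while the positional call receives one argument too many. The init_fields version above avoids this by filtering on f.init.

```python
import dataclasses
from dataclasses import dataclass, field


@dataclass(frozen=True)
class Pos:
    start: int
    end: int
    length: int = field(init=False)  # hypothetical derived field, not an __init__ argument

    def __post_init__(self):
        # Frozen dataclasses must bypass their own __setattr__ to set derived fields.
        object.__setattr__(self, "length", self.end - self.start)

    def adjust_via_copy(self, delta: int) -> "Pos":
        _dict = self.__dict__.copy()
        _dict["start"] += delta
        return Pos(*_dict.values())  # passes start, end AND length positionally

    def adjust_via_replace(self, delta: int) -> "Pos":
        return dataclasses.replace(self, start=self.start + delta)


p = Pos(1, 5)
print(p.adjust_via_replace(2))  # Pos(start=3, end=5, length=2)

try:
    p.adjust_via_copy(2)
except TypeError as exc:
    print("adjust_via_copy failed:", exc)
```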

In a quick timeit benchmark, both of these approaches come out somewhat faster than dataclasses.replace():

from dataclasses import fields, dataclass, replace
from timeit import timeit


@dataclass(frozen=True)
class Pos:
    start: int
    end: int

    # `cached_class_property` as defined above
    @cached_class_property
    def init_fields(cls):
        return tuple(f.name for f in fields(cls) if f.init)

    def adjust_via_copy(self, delta: int) -> 'Pos':
        _dict = self.__dict__.copy()
        _dict['start'] += delta

        return Pos(*_dict.values())

    def adjust_via_fields(self, delta: int) -> 'Pos':
        attrs = [getattr(self, f) + delta if f == 'start' else getattr(self, f)
                 for f in Pos.init_fields]

        return Pos(*attrs)

    def adjust_via_replace(self, delta: int) -> 'Pos':
        return replace(
            self,
            start=self.start + delta,
        )


p = Pos(1, 2)

print('o.__dict__.copy:      ', round(timeit('p.adjust_via_copy(4)', globals=globals()), 3))
print('dataclasses.fields:   ', round(timeit('p.adjust_via_fields(4)', globals=globals()), 3))
print('dataclasses.replace:  ', round(timeit('p.adjust_via_replace(4)', globals=globals()), 3))

assert Pos(-2, 2) == p.adjust_via_copy(-3) == p.adjust_via_fields(-3) == p.adjust_via_replace(-3)

Output (total seconds for timeit's default of 1,000,000 calls each):
o.__dict__.copy:       0.408
dataclasses.fields:    0.499
dataclasses.replace:   0.659
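The timings are not the whole story: the two faster variants hard-code Pos as the constructor, while replace() builds an object of the same type as its argument. A minimal sketch with a hypothetical empty subclass Span shows the difference:

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Pos:
    start: int
    end: int

    def adjust_via_copy(self, delta: int) -> "Pos":
        _dict = self.__dict__.copy()
        _dict["start"] += delta
        return Pos(*_dict.values())  # always constructs a Pos

    def adjust_via_replace(self, delta: int) -> "Pos":
        return dataclasses.replace(self, start=self.start + delta)  # keeps type(self)


class Span(Pos):  # hypothetical subclass with no extra fields
    pass


s = Span(1, 2)
print(type(s.adjust_via_copy(4)).__name__)     # Pos
print(type(s.adjust_via_replace(4)).__name__)  # Span
```

Writing type(self)(...) instead of Pos(...) in the faster variants would close that particular gap.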
