
Type annotation for parsing a string in a JSON file to a range using Pydantic in Python

I've set up a Pydantic class that's intended to parse JSON files. The range attribute is parsed from a string of the form "11-34" (or, more precisely, from any string matching the regex shown):

    import re
    from typing import Annotated

    from pydantic import BaseModel, Field, validator

    RANGE_STRING_REGEX = r"^(?P<first>[1-6]+)(-(?P<last>[1-6]+))?$"

    class RandomTableEvent(BaseModel):
        name: str
        range: Annotated[str, Field(regex=RANGE_STRING_REGEX)]
    
        @validator("range", allow_reuse=True)
        def convert_range_string_to_range(cls, r) -> "range":
            match_groups = re.fullmatch(RANGE_STRING_REGEX, r).groupdict()
            first = int(match_groups["first"])
            last = int(match_groups["last"]) if match_groups["last"] else first
            return range(first, last + 1)

The generated schema works and the validation passes.

However, the type annotation for the range attribute is, strictly speaking, incorrect: the field is annotated as str, but the validator function converts it to a range object.

What would be the correct way to annotate this while still keeping the schema generation? Is there another way of dealing with this implicit type conversion (e.g. strings are automatically converted to int in Pydantic; is there something similar for custom types)?

range is not a type supported by pydantic, and using it as the type of a field will cause an error when trying to create a JSON schema. However, pydantic supports Custom Data Types:

You can also define your own custom data types. There are several ways to achieve it.

Classes with __get_validators__

You use a custom class with a classmethod __get_validators__. It will be called to get validators to parse and validate the input data.

However, this custom data type cannot inherit from range, because range is final (it cannot be subclassed). Instead, you can create a custom data type that uses a range internally and exposes the range methods: it will work like a range but it will not be a range (isinstance(..., range) will be False).
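That range cannot be subclassed is easy to check in plain Python. The following is a minimal sketch using only the standard library; the class name MyRange is hypothetical and exists only for this check.

```python
# Attempting to subclass the built-in range fails when the class is created.
try:
    class MyRange(range):  # hypothetical name, for illustration only
        pass
except TypeError as exc:
    subclass_error = str(exc)
else:
    subclass_error = None

print("subclassing range failed:", subclass_error is not None)
print("error message:", subclass_error)
```

This is why the example below wraps a range instead of inheriting from it.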

The same pydantic documentation also shows how to use a __modify_schema__ classmethod to customize the JSON schema of a custom data type. Note that __get_validators__ and __modify_schema__ are part of the pydantic v1 API.
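One caveat worth considering when setting "pattern" in the schema: JSON Schema recommends regular expressions in the ECMA 262 dialect, and the Python-specific named-group syntax (?P<name>...) is not part of that dialect, so non-Python schema consumers may reject the pattern. A possible workaround, sketched below, is to publish an equivalent pattern without named groups and keep the named-group version for parsing. The name SCHEMA_PATTERN is an assumption made for this sketch, and the check uses only the standard library.

```python
import re

# Pattern used for parsing (Python-specific named groups), as in the question.
RANGE_STRING_REGEX = r"^(?P<first>[1-6]+)(-(?P<last>[1-6]+))?$"

# Hypothetical schema-facing pattern: same accepted strings, no named groups.
SCHEMA_PATTERN = r"^[1-6]+(-[1-6]+)?$"

samples = ["11-34", "5", "66-11", "", "7", "11-", "-34", "1-2-3", "12a"]

# Both patterns should accept and reject exactly the same sample strings.
agreement = all(
    bool(re.fullmatch(RANGE_STRING_REGEX, s)) == bool(re.fullmatch(SCHEMA_PATTERN, s))
    for s in samples
)
print("patterns agree on all samples:", agreement)
```

Inside __modify_schema__, field_schema["pattern"] could then be set to SCHEMA_PATTERN rather than to the parsing regex.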

Full example:

import re
from typing import Any, Callable, Dict, Iterator, SupportsIndex, Union

from pydantic import BaseModel


class Range:
    _RANGE_STRING_REGEX = r"^(?P<first>[1-6]+)(-(?P<last>[1-6]+))?$"

    @classmethod
    def __get_validators__(cls) -> Iterator[Callable[[Any], Any]]:
        yield cls.validate

    @classmethod
    def validate(cls, v: Any) -> "Range":
        if not isinstance(v, str):
            raise ValueError("expected string")

        match = re.fullmatch(cls._RANGE_STRING_REGEX, v)
        if not match:
            raise ValueError("invalid string")

        match_groups = match.groupdict()
        first = int(match_groups["first"])
        last = int(match_groups["last"]) if match_groups["last"] else first

        return cls(range(first, last + 1))

    def __init__(self, r: range) -> None:
        self._range = r

    @classmethod
    def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
        # Customize the JSON schema as you want
        field_schema["pattern"] = cls._RANGE_STRING_REGEX
        field_schema["type"] = "string"

    # Implement the range methods and use self._range

    @property
    def start(self) -> int:
        return self._range.start

    @property
    def stop(self) -> int:
        return self._range.stop

    @property
    def step(self) -> int:
        return self._range.step

    def count(self, value: int) -> int:
        return self._range.count(value)

    def index(self, value: int) -> int:
        return self._range.index(value)

    def __len__(self) -> int:
        return self._range.__len__()

    def __contains__(self, o: object) -> bool:
        return self._range.__contains__(o)

    def __iter__(self) -> Iterator[int]:
        return self._range.__iter__()

    # An integer index returns an int; a slice returns a range.
    def __getitem__(self, key: Union[SupportsIndex, slice]) -> Union[int, range]:
        return self._range.__getitem__(key)

    def __reversed__(self) -> Iterator[int]:
        return self._range.__reversed__()

    def __repr__(self) -> str:
        return self._range.__repr__()


class RandomTableEvent(BaseModel):
    name: str
    range: Range


event = RandomTableEvent(name="foo", range="11-34")

print("event:", event)
print("event.range:", event.range)
print("schema:", event.schema_json(indent=2))
print("is instance of range:", isinstance(event.range, range))
print("event.range.start:", event.range.start)
print("event.range.stop:", event.range.stop)
print("event.range[0:5]", event.range[0:5])
print("last 3 elements:", list(event.range[-3:]))

Output:

event: name='foo' range=range(11, 35)
event.range: range(11, 35)
schema: {
  "title": "RandomTableEvent",
  "type": "object",
  "properties": {
    "name": {
      "title": "Name",
      "type": "string"
    },
    "range": {
      "title": "Range",
      "pattern": "^(?P<first>[1-6]+)(-(?P<last>[1-6]+))?$",
      "type": "string"
    }
  },
  "required": [
    "name",
    "range"
  ]
}
is instance of range: False
event.range.start: 11
event.range.stop: 35
event.range[0:5] range(11, 16)
last 3 elements: [32, 33, 34]
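As a shorter variant of the wrapper, most of the delegating methods could be inherited from collections.abc.Sequence, which supplies __iter__, __reversed__, index and count once __getitem__ and __len__ are defined. The sketch below is pydantic-free and uses a hypothetical class name, RangeSequence; the __get_validators__, validate and __modify_schema__ classmethods from the full example would be added to it unchanged. The mixin's own __contains__ scans the elements one by one, so it is overridden here to delegate to the wrapped range.

```python
from collections.abc import Sequence


class RangeSequence(Sequence):  # hypothetical name, not part of pydantic
    """Wraps a range; Sequence provides __iter__, __reversed__, index, count."""

    def __init__(self, r: range) -> None:
        self._range = r

    def __len__(self) -> int:
        return len(self._range)

    def __getitem__(self, key):
        # An integer index yields an int; a slice yields a range.
        return self._range[key]

    def __contains__(self, value: object) -> bool:
        # Delegate to range instead of using the mixin's linear scan.
        return value in self._range

    def __repr__(self) -> str:
        return repr(self._range)


r = RangeSequence(range(11, 35))
print("length:", len(r))
print("first element:", r[0])
print("last 3 elements:", list(r[-3:]))
print("contains 20:", 20 in r)
print("index of 34:", r.index(34))
```

As with the full example, instances of this class behave like a range for reading but are not instances of range.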
