繁体   English   中英

如何使类通用并正确推断其方法的返回类型

[英]How to make a class generic and correctly infer the return type of its method

我正在尝试使用 python 中的 SOLID 原则重构一些类,我有一个关于如何将 SOLID 与 python 类型混合的问题。

假设我有这些课程:

from asyncpg import Pool


class PGQuery:
    async def execute(self, connection: Pool):
        raise NotImplementedError


class PGQueryExecutor:
    def __init__(self, connection: Pool):
        self._connection = connection

    async def execute(self, query: PGQuery):
        return await query.execute(self._connection)

from pydantic import BaseModel, parse_obj_as


class QualitySummary(BaseModel):
    count: int
    score: float


class PGQueryQualitySummary(PGQuery):
    def __init__(self, node: str):
        self.node = node

    async def execute(self, connection: Pool) -> QualitySummary:
        result = await connection.fetchrow(...)

        return parse_obj_as(QualitySummary, result)

用法示例:

pgqueryexecutor = PGQueryExecutor(...)
result = await pgqueryexecutor.execute(PGQueryQualitySummary(...))

问题是result的推断类型是Any ,它是基类PGQuery 我希望(也许使用泛型?)然后通过PGQueryExecutorexecute方法正确推断PGQuery每个子类,它实现了自己的execute方法和自己的返回类型,因此返回值是:

pgqueryexecutor.execute(PGQueryAnySubclass(...))

正是PGQueryAnySubclass.execute的返回类型。

我怎样才能做到这一点?

您可以通过使用通用协议来实现这一点,您可以从中继承特定的查询类。 下面提供了一个示例。 考虑到该类型变量被协变用作返回类型,因此我们将其标记为covariant=True (有关上面链接的更多详细信息)。

from abc import abstractmethod
from typing import TypeVar, Protocol

from pydantic import BaseModel, parse_obj_as

T = TypeVar('T', covariant=True)


class Pool:
    ...


class PGQuery(Protocol[T]):
    @abstractmethod
    async def execute(self, connection: Pool) -> T:
        raise NotImplementedError


class PGQueryExecutor:
    def __init__(self, connection: Pool):
        self._connection = connection

    async def execute(self, query: PGQuery[T]) -> T:
        return await query.execute(self._connection)


class QualitySummary(BaseModel):
    count: int
    score: float


class PGQueryQualitySummary(PGQuery[QualitySummary]):
    def __init__(self, node: str):
        self.node = node

    async def execute(self, connection: Pool) -> QualitySummary:
        # ...
        return parse_obj_as(QualitySummary, {})


async def main() -> None:
    q = PGQueryQualitySummary("node")
    ex = PGQueryExecutor(Pool())
    reveal_type(await ex.execute(q))  # revealed type QualitySummary


暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM