![](/img/trans.png)
[英]Python typing, mypy infer return type based on class method return type
[英]How to make a class generic and correctly infer the return type of its method
我正在尝试使用 python 中的 SOLID 原则重构一些类,我有一个关于如何将 SOLID 与 python 类型混合的问题。
假设我有这些课程:
from asyncpg import Pool
class PGQuery:
async def execute(self, connection: Pool):
raise NotImplementedError
class PGQueryExecutor:
def __init__(self, connection: Pool):
self._connection = connection
async def execute(self, query: PGQuery):
return await query.execute(self._connection)
和
from pydantic import BaseModel, parse_obj_as
class QualitySummary(BaseModel):
count: int
score: float
class PGQueryQualitySummary(PGQuery):
def __init__(self, node: str):
self.node = node
async def execute(self, connection: Pool) -> QualitySummary:
result = await connection.fetchrow(...)
return parse_obj_as(QualitySummary, result)
用法示例:
pgqueryexecutor = PGQueryExecutor(...)
result = await pgqueryexecutor.execute(PGQueryQualitySummary(...))
问题是result
的推断类型是Any
,它是基类PGQuery
。 我希望(也许使用泛型?)然后通过PGQueryExecutor
的execute
方法正确推断PGQuery
每个子类,它实现了自己的execute
方法和自己的返回类型,因此返回值是:
pgqueryexecutor.execute(PGQueryAnySubclass(...))
正是PGQueryAnySubclass.execute
的返回类型。
我怎样才能做到这一点?
您可以通过使用通用协议来实现这一点,您可以从中继承特定的查询类。 下面提供了一个示例。 考虑到该类型变量被协变用作返回类型,因此我们将其标记为covariant=True
(有关上面链接的更多详细信息)。
from abc import abstractmethod
from typing import TypeVar, Protocol
from pydantic import BaseModel, parse_obj_as
T = TypeVar('T', covariant=True)
class Pool:
...
class PGQuery(Protocol[T]):
@abstractmethod
async def execute(self, connection: Pool) -> T:
raise NotImplementedError
class PGQueryExecutor:
def __init__(self, connection: Pool):
self._connection = connection
async def execute(self, query: PGQuery[T]) -> T:
return await query.execute(self._connection)
class QualitySummary(BaseModel):
count: int
score: float
class PGQueryQualitySummary(PGQuery[QualitySummary]):
def __init__(self, node: str):
self.node = node
async def execute(self, connection: Pool) -> QualitySummary:
# ...
return parse_obj_as(QualitySummary, {})
async def main() -> None:
q = PGQueryQualitySummary("node")
ex = PGQueryExecutor(Pool())
reveal_type(await ex.execute(q)) # revealed type QualitySummary
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.