![](/img/trans.png)
[英]Python typing, mypy infer return type based on class method return type
[英]How to make a class generic and correctly infer the return type of its method
我正在嘗試使用 python 中的 SOLID 原則重構一些類,我有一個關於如何將 SOLID 與 python 類型混合的問題。
假設我有這些課程:
from asyncpg import Pool
class PGQuery:
async def execute(self, connection: Pool):
raise NotImplementedError
class PGQueryExecutor:
def __init__(self, connection: Pool):
self._connection = connection
async def execute(self, query: PGQuery):
return await query.execute(self._connection)
和
from pydantic import BaseModel, parse_obj_as
class QualitySummary(BaseModel):
count: int
score: float
class PGQueryQualitySummary(PGQuery):
def __init__(self, node: str):
self.node = node
async def execute(self, connection: Pool) -> QualitySummary:
result = await connection.fetchrow(...)
return parse_obj_as(QualitySummary, result)
用法示例:
pgqueryexecutor = PGQueryExecutor(...)
result = await pgqueryexecutor.execute(PGQueryQualitySummary(...))
問題是result
的推斷類型是Any
,它是基類PGQuery
。 我希望(也許使用泛型?)然后通過PGQueryExecutor
的execute
方法正確推斷PGQuery
每個子類,它實現了自己的execute
方法和自己的返回類型,因此返回值是:
pgqueryexecutor.execute(PGQueryAnySubclass(...))
正是PGQueryAnySubclass.execute
的返回類型。
我怎樣才能做到這一點?
您可以通過使用通用協議來實現這一點,您可以從中繼承特定的查詢類。 下面提供了一個示例。 考慮到該類型變量被協變用作返回類型,因此我們將其標記為covariant=True
(有關上面鏈接的更多詳細信息)。
from abc import abstractmethod
from typing import TypeVar, Protocol
from pydantic import BaseModel, parse_obj_as
T = TypeVar('T', covariant=True)
class Pool:
...
class PGQuery(Protocol[T]):
@abstractmethod
async def execute(self, connection: Pool) -> T:
raise NotImplementedError
class PGQueryExecutor:
def __init__(self, connection: Pool):
self._connection = connection
async def execute(self, query: PGQuery[T]) -> T:
return await query.execute(self._connection)
class QualitySummary(BaseModel):
count: int
score: float
class PGQueryQualitySummary(PGQuery[QualitySummary]):
def __init__(self, node: str):
self.node = node
async def execute(self, connection: Pool) -> QualitySummary:
# ...
return parse_obj_as(QualitySummary, {})
async def main() -> None:
q = PGQueryQualitySummary("node")
ex = PGQueryExecutor(Pool())
reveal_type(await ex.execute(q)) # revealed type QualitySummary
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.