I'm trying to refactor some classes using SOLID principles in Python, and I have a question about how to combine SOLID with Python type hints.
Suppose I have these classes:
```python
from asyncpg import Pool


class PGQuery:
    async def execute(self, connection: Pool):
        raise NotImplementedError


class PGQueryExecutor:
    def __init__(self, connection: Pool):
        self._connection = connection

    async def execute(self, query: PGQuery):
        return await query.execute(self._connection)
```
and
```python
from pydantic import BaseModel, parse_obj_as


class QualitySummary(BaseModel):
    count: int
    score: float


class PGQueryQualitySummary(PGQuery):
    def __init__(self, node: str):
        self.node = node

    async def execute(self, connection: Pool) -> QualitySummary:
        result = await connection.fetchrow(...)
        return parse_obj_as(QualitySummary, result)
```
Example usage:
```python
pgqueryexecutor = PGQueryExecutor(...)
result = await pgqueryexecutor.execute(PGQueryQualitySummary(...))
```
The problem is that the inferred type of `result` is `Any`, which comes from the unannotated return type of the base class `PGQuery`. I would like (maybe with generics?) every subclass of `PGQuery` that implements its own `execute` method with its own return type to have that type correctly inferred through the `execute` method of `PGQueryExecutor`, so that the awaited return value of:

```python
pgqueryexecutor.execute(PGQueryAnySubclass(...))
```

is exactly the return type of `PGQueryAnySubclass.execute`.
How can I achieve that?
You can achieve this with a generic protocol, from which you can inherit the specific query classes. An example is provided below. Note that the type variable only appears in return position (it is used covariantly), so we mark it as `covariant=True` (see the `typing` documentation on generic protocols and variance for more details).
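Before the full example, here is a minimal stdlib-only sketch of what `covariant=True` buys. It assumes hypothetical names (`Query`, `DetailedQualitySummary`, `DetailedQuery`, `summarize`) and uses dataclasses instead of pydantic. Because the type variable is covariant, a `Query[DetailedQualitySummary]` is accepted where a `Query[QualitySummary]` is expected; with an invariant type variable, mypy would reject the `summarize(detailed)` call (and also warns when an invariant type variable is used only in return position in a protocol).

```python
import asyncio
from dataclasses import dataclass
from typing import Protocol, TypeVar

# Covariant: T_co only appears as a return type of the protocol's method.
T_co = TypeVar("T_co", covariant=True)


class Query(Protocol[T_co]):
    # Hypothetical minimal query protocol.
    async def execute(self, connection: object) -> T_co: ...


@dataclass
class QualitySummary:
    count: int
    score: float


@dataclass
class DetailedQualitySummary(QualitySummary):
    # Hypothetical subclass, used only to illustrate covariance.
    node: str


class DetailedQuery:
    # Structurally matches Query[DetailedQualitySummary]; no inheritance needed.
    async def execute(self, connection: object) -> DetailedQualitySummary:
        return DetailedQualitySummary(count=3, score=0.5, node="n1")


async def summarize(query: Query[QualitySummary]) -> QualitySummary:
    return await query.execute(connection=None)


# Declared with the narrower type parameter on purpose: covariance is what lets
# a type checker accept it as Query[QualitySummary] in the call below.
detailed: Query[DetailedQualitySummary] = DetailedQuery()
result = asyncio.run(summarize(detailed))
```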
```python
from abc import abstractmethod
from typing import TypeVar, Protocol

from pydantic import BaseModel, parse_obj_as

# Covariant because T only appears in return position.
T = TypeVar('T', covariant=True)


class Pool:
    # Stand-in for asyncpg.Pool to keep the example self-contained.
    ...


class PGQuery(Protocol[T]):
    @abstractmethod
    async def execute(self, connection: Pool) -> T:
        raise NotImplementedError


class PGQueryExecutor:
    def __init__(self, connection: Pool):
        self._connection = connection

    async def execute(self, query: PGQuery[T]) -> T:
        # T is solved from the query argument, so callers get its concrete result type.
        return await query.execute(self._connection)


class QualitySummary(BaseModel):
    count: int
    score: float


class PGQueryQualitySummary(PGQuery[QualitySummary]):
    def __init__(self, node: str):
        self.node = node

    async def execute(self, connection: Pool) -> QualitySummary:
        # ... (actual query elided; {} is only a placeholder)
        return parse_obj_as(QualitySummary, {})


async def main() -> None:
    q = PGQueryQualitySummary("node")
    ex = PGQueryExecutor(Pool())
    # reveal_type is a type-checker helper (e.g. mypy).
    reveal_type(await ex.execute(q))  # revealed type QualitySummary
```
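For a version that also runs end to end, here is a sketch under a few assumptions: `FakePool` is a hypothetical in-memory stand-in whose `fetchrow(node)` signature is invented for illustration (the real asyncpg `fetchrow` takes a SQL query), a dataclass replaces the pydantic model, and `PGQueryNodeCount` is a hypothetical second query added to show that each subclass's own return type flows through the same executor.

```python
import asyncio
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Protocol, TypeVar

T = TypeVar("T", covariant=True)


class FakePool:
    """Hypothetical in-memory stand-in for asyncpg.Pool."""

    def __init__(self, rows: dict[str, dict[str, Any]]) -> None:
        self._rows = rows

    async def fetchrow(self, node: str) -> dict[str, Any]:
        return self._rows[node]


class PGQuery(Protocol[T]):
    @abstractmethod
    async def execute(self, connection: FakePool) -> T:
        raise NotImplementedError


class PGQueryExecutor:
    def __init__(self, connection: FakePool) -> None:
        self._connection = connection

    async def execute(self, query: PGQuery[T]) -> T:
        # T is solved per call from the query argument.
        return await query.execute(self._connection)


@dataclass
class QualitySummary:  # dataclass used instead of pydantic.BaseModel
    count: int
    score: float


class PGQueryQualitySummary(PGQuery[QualitySummary]):
    def __init__(self, node: str) -> None:
        self.node = node

    async def execute(self, connection: FakePool) -> QualitySummary:
        row = await connection.fetchrow(self.node)
        return QualitySummary(**row)


class PGQueryNodeCount(PGQuery[int]):
    # Hypothetical second query with a different return type.
    async def execute(self, connection: FakePool) -> int:
        row = await connection.fetchrow("total")
        return row["count"]


async def main() -> tuple[QualitySummary, int]:
    pool = FakePool({"node": {"count": 2, "score": 0.75}, "total": {"count": 10}})
    ex = PGQueryExecutor(pool)
    summary = await ex.execute(PGQueryQualitySummary("node"))  # inferred: QualitySummary
    total = await ex.execute(PGQueryNodeCount())  # inferred: int
    return summary, total


summary, total = asyncio.run(main())
```

The executor itself never mentions `QualitySummary` or `int`; the concrete types come entirely from the type parameter each query class passes to `PGQuery`.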