
How to make a class generic and correctly infer the return type of its method

I'm trying to refactor some classes in Python using SOLID principles, and I have a question about how to combine SOLID with Python's type hints.

Suppose I have these classes:

from asyncpg import Pool


class PGQuery:
    async def execute(self, connection: Pool):
        raise NotImplementedError


class PGQueryExecutor:
    def __init__(self, connection: Pool):
        self._connection = connection

    async def execute(self, query: PGQuery):
        return await query.execute(self._connection)

and

from pydantic import BaseModel, parse_obj_as


class QualitySummary(BaseModel):
    count: int
    score: float


class PGQueryQualitySummary(PGQuery):
    def __init__(self, node: str):
        self.node = node

    async def execute(self, connection: Pool) -> QualitySummary:
        result = await connection.fetchrow(...)

        return parse_obj_as(QualitySummary, result)

Example usage:

pgqueryexecutor = PGQueryExecutor(...)
result = await pgqueryexecutor.execute(PGQueryQualitySummary(...))

The problem is that the inferred type of result is Any. PGQueryExecutor.execute only knows about the base class PGQuery, and that class's execute method has no return annotation. I would like (maybe using generics?) each subclass of PGQuery that implements its own execute method with its own return type to be inferred correctly through PGQueryExecutor.execute, so that the return type of:

pgqueryexecutor.execute(PGQueryAnySubclass(...))

is exactly the return type of PGQueryAnySubclass.execute.

How can I achieve that?

You can achieve this with a generic protocol, which your specific query classes then inherit from. An example is provided below. Because the type variable appears only as a return type, it is used covariantly, so it is declared with covariant=True (see the typing documentation on generic protocols for details).
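Before the full example, here is a minimal stdlib-only sketch of why covariant=True matters. All names in it (Query, Animal, Dog, DogQuery, run_animal_query) are hypothetical. Because T_co is covariant, a query that returns a Dog is accepted wherever a Query[Animal] is expected. Because Query is a protocol, DogQuery satisfies it structurally without inheriting from it. If an invariant type variable is used in this return-only position, mypy reports an error asking for a covariant one.

```python
import asyncio
from typing import Protocol, TypeVar

# Covariant: Query[Dog] is treated as a subtype of Query[Animal].
T_co = TypeVar("T_co", covariant=True)


class Query(Protocol[T_co]):
    async def execute(self, connection: object) -> T_co: ...


class Animal: ...


class Dog(Animal): ...


class DogQuery:
    # Hypothetical query; matches Query[Dog] structurally (no inheritance needed).
    async def execute(self, connection: object) -> Dog:
        return Dog()


async def run_animal_query(query: Query[Animal]) -> Animal:
    # A type checker accepts DogQuery here only because T_co is covariant.
    return await query.execute(None)


result = asyncio.run(run_animal_query(DogQuery()))
```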

from abc import abstractmethod
from typing import TypeVar, Protocol

from pydantic import BaseModel, parse_obj_as

T = TypeVar('T', covariant=True)


class Pool:
    ...


class PGQuery(Protocol[T]):
    @abstractmethod
    async def execute(self, connection: Pool) -> T:
        raise NotImplementedError


class PGQueryExecutor:
    def __init__(self, connection: Pool):
        self._connection = connection

    async def execute(self, query: PGQuery[T]) -> T:
        return await query.execute(self._connection)


class QualitySummary(BaseModel):
    count: int
    score: float


class PGQueryQualitySummary(PGQuery[QualitySummary]):
    def __init__(self, node: str):
        self.node = node

    async def execute(self, connection: Pool) -> QualitySummary:
        # ...
        return parse_obj_as(QualitySummary, {})


async def main() -> None:
    q = PGQueryQualitySummary("node")
    ex = PGQueryExecutor(Pool())
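    # reveal_type is understood by type checkers such as mypy (it only exists at runtime from Python 3.11)
    reveal_type(await ex.execute(q))  # revealed type: QualitySummary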
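If asyncpg and pydantic are not at hand, you can try the same pattern with the standard library alone. This sketch rests on two assumptions. FakePool is a hypothetical stand-in for asyncpg.Pool that returns a canned row. QualitySummary is a dataclass rather than a pydantic BaseModel. The typing behaviour is unchanged: a type checker infers summary as QualitySummary, and at runtime the executor returns whatever the concrete query produced.

```python
import asyncio
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Protocol, TypeVar

T = TypeVar("T", covariant=True)


class FakePool:
    """Hypothetical stand-in for asyncpg.Pool that always returns one canned row."""

    async def fetchrow(self, query: str, *args: Any) -> Dict[str, Any]:
        return {"count": 3, "score": 0.75}


class PGQuery(Protocol[T]):
    @abstractmethod
    async def execute(self, connection: FakePool) -> T:
        raise NotImplementedError


class PGQueryExecutor:
    def __init__(self, connection: FakePool) -> None:
        self._connection = connection

    async def execute(self, query: PGQuery[T]) -> T:
        # T is solved from the concrete query type, e.g. QualitySummary.
        return await query.execute(self._connection)


@dataclass
class QualitySummary:
    count: int
    score: float


class PGQueryQualitySummary(PGQuery[QualitySummary]):
    def __init__(self, node: str) -> None:
        self.node = node

    async def execute(self, connection: FakePool) -> QualitySummary:
        # "SELECT ..." is a placeholder; FakePool ignores the query text.
        row = await connection.fetchrow("SELECT ...", self.node)
        return QualitySummary(**row)


async def main() -> QualitySummary:
    ex = PGQueryExecutor(FakePool())
    summary = await ex.execute(PGQueryQualitySummary("node"))
    # A type checker infers `summary` as QualitySummary here, not Any.
    return summary


summary = asyncio.run(main())
```

With a protocol, explicitly inheriting from PGQuery[QualitySummary] is optional. It mainly documents intent and lets the type checker verify that the subclass's execute signature is compatible.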

