简体   繁体   English

歧视联盟 Python

[英]Discriminated union in Python

Imagine I have a base class and two derived classes.假设我有一个基类 class 和两个派生类。 I also have a factory method, that returns an object of one of the classes.我还有一个工厂方法,它返回其中一个类的 object。 The problem is, mypy or IntelliJ can't figure out which type the object is.问题是,mypy 或 IntelliJ 无法确定 object 是哪种类型。 They know it can be both, but not which one exactly.他们知道两者都可以,但不知道究竟是哪一个。 Is there any way I can help mypy/IntelliJ to figure this out WITHOUT putting a type hint next to the conn variable name?有什么方法可以帮助 mypy/IntelliJ 解决这个问题,而无需在conn变量名称旁边放置类型提示?

import abc
import enum
import typing


class BaseConnection(abc.ABC):
    @abc.abstractmethod
    def sql(self, query: str) -> typing.List[typing.Any]:
        ...


class PostgresConnection(BaseConnection):

    def sql(self, query: str) -> typing.List[typing.Any]:
        return "This is a postgres result".split()

    def only_postgres_things(self):
        pass


class MySQLConnection(BaseConnection):

    def sql(self, query: str) -> typing.List[typing.Any]:
        return "This is a mysql result".split()

    def only_mysql_things(self):
        pass


class ConnectionType(enum.Enum):
    POSTGRES = 1
    MYSQL = 2


def connect(conn_type: ConnectionType) -> typing.Union[PostgresConnection, MySQLConnection]:
    if conn_type is ConnectionType.POSTGRES:
        return PostgresConnection()
    if conn_type is ConnectionType.MYSQL:
        return MySQLConnection()


conn = connect(ConnectionType.POSTGRES)
conn.only_postgres_things()

Look at how IntelliJ handles this:看看 IntelliJ 是如何处理这个的: 在此处输入图像描述

As you can see both methods: only_postgres_things and only_mysql_things are suggested when I'd like IntelliJ/mypy to figure it out out of the type I'm passing to the connect function.正如您所看到的两种方法:当我希望 IntelliJ/mypy 从我传递给connect function 的类型中找出它时,建议使用only_postgres_thingsonly_mysql_things

You could try using typing.overload combined with typing.Literal , like so:您可以尝试将typing.overloadtyping.Literal结合使用,如下所示:


@typing.overload
def connect(type_: typing.Literal[ConnectionType.POSTGRES]) -> PostgresConnection:
    ...

@typing.overload
def connect(type_: typing.Literal[ConnectionType.MYSQL]) -> MySQLConnection:
    ...

def connect(type_):
    if type_ is ConnectionType.POSTGRES:
        return PostgresConnection()
    if type_ is ConnectionType.MYSQL:
        return MySQLConnection()

I replaced type with type_ so you don't shadow the builtin, and it's idiomatic to compare enum values using is instead of == .我用type_替换了type ,这样你就不会隐藏内置函数,并且使用is而不是==来比较枚举值是惯用的。

Since the purpose of your ConnectionType class is apparently to make your API more readable and user-friendly rather than to use any specific features of Enum , you don't really have to make it an Enum class.由于您的ConnectionType class 的目的显然是使您的 API 更具可读性和用户友好性,而不是使用Enum的任何特定功能,因此您实际上不必将其设为Enum class。

Instead, you can create a regular class with each connection type assigned to a class variable of a user-friendly name, so that you can type the return value of the connect function with a type variable and type the parameter with the type of the type variable.相反,您可以创建一个常规的 class,并将每个连接类型分配给一个用户友好名称的 class 变量,这样您就可以使用类型变量键入connect function 的返回值,并使用该类型的类型键入参数多变的。 Use a type alias to make the type of the type variable even more readable:使用类型别名使类型变量的类型更具可读性:

class ConnectionTypes:
    POSTGRES = PostgresConnection
    MYSQL = MySQLConnection

Connection = typing.TypeVar('Connection', PostgresConnection, MySQLConnection)
# or make it bound to the base class:
# Connection = typing.TypeVar('Connection', bound=BaseConnection)
ConnectionType: typing.TypeAlias = type[Connection]

def connect(type_: ConnectionType) -> Connection:
    if type_ is ConnectionType.POSTGRES:
        return PostgresConnection()
    if type_ is ConnectionType.MYSQL:
        return MySQLConnection()

将常规类与类型变量结合使用。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM