[英]Discriminated union in Python
假设我有一个基类 class 和两个派生类。 我还有一个工厂方法,它返回其中一个类的 object。 问题是,mypy 或 IntelliJ 无法确定 object 是哪种类型。 他们知道两者都可以,但不知道究竟是哪一个。 有什么方法可以帮助 mypy/IntelliJ 解决这个问题,而无需在conn
变量名称旁边放置类型提示?
import abc
import enum
import typing
class BaseConnection(abc.ABC):
@abc.abstractmethod
def sql(self, query: str) -> typing.List[typing.Any]:
...
class PostgresConnection(BaseConnection):
def sql(self, query: str) -> typing.List[typing.Any]:
return "This is a postgres result".split()
def only_postgres_things(self):
pass
class MySQLConnection(BaseConnection):
def sql(self, query: str) -> typing.List[typing.Any]:
return "This is a mysql result".split()
def only_mysql_things(self):
pass
class ConnectionType(enum.Enum):
POSTGRES = 1
MYSQL = 2
def connect(conn_type: ConnectionType) -> typing.Union[PostgresConnection, MySQLConnection]:
if conn_type is ConnectionType.POSTGRES:
return PostgresConnection()
if conn_type is ConnectionType.MYSQL:
return MySQLConnection()
conn = connect(ConnectionType.POSTGRES)
conn.only_postgres_things()
正如您所看到的两种方法:当我希望 IntelliJ/mypy 从我传递给connect
function 的类型中找出它时,建议使用only_postgres_things
和only_mysql_things
。
您可以尝试将typing.overload
与typing.Literal
结合使用,如下所示:
@typing.overload
def connect(type_: typing.Literal[ConnectionType.POSTGRES]) -> PostgresConnection:
...
@typing.overload
def connect(type_: typing.Literal[ConnectionType.MYSQL]) -> MySQLConnection:
...
def connect(type_):
if type_ is ConnectionType.POSTGRES:
return PostgresConnection()
if type_ is ConnectionType.MYSQL:
return MySQLConnection()
我用type_
替换了type
,这样你就不会隐藏内置函数,并且使用is
而不是==
来比较枚举值是惯用的。
由于您的ConnectionType
class 的目的显然是使您的 API 更具可读性和用户友好性,而不是使用Enum
的任何特定功能,因此您实际上不必将其设为Enum
class。
相反,您可以创建一个常规的 class,并将每个连接类型分配给一个用户友好名称的 class 变量,这样您就可以使用类型变量键入connect
function 的返回值,并使用该类型的类型键入参数多变的。 使用类型别名使类型变量的类型更具可读性:
class ConnectionTypes:
POSTGRES = PostgresConnection
MYSQL = MySQLConnection
Connection = typing.TypeVar('Connection', PostgresConnection, MySQLConnection)
# or make it bound to the base class:
# Connection = typing.TypeVar('Connection', bound=BaseConnection)
ConnectionType: typing.TypeAlias = type[Connection]
def connect(type_: ConnectionType) -> Connection:
if type_ is ConnectionType.POSTGRES:
return PostgresConnection()
if type_ is ConnectionType.MYSQL:
return MySQLConnection()
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.