繁体   English   中英

如何键入提示通用 numpy 数组?

[英]How to type hint a generic numpy array?

有没有办法将 Numpy 数组键入为通用数组?

我目前正在使用 Numpy 1.23.5 和 Python 3.10,但我无法键入提示以下示例。

import numpy as np
import numpy.typing as npt


E = TypeVar("E") # Should be bounded to a numpy type

def double_arr(arr: npt.NDArray[E]) -> npt.NDArray[E]:
    return arr * 2

我期待什么

arr = np.array([1, 2, 3], dtype=np.int8)
double_arr(arr) # npt.NDAarray[np.int8]

arr = np.array([1, 2.3, 3], dtype=np.float32)
double_arr(arr) # npt.NDAarray[np.float32]

但我最终遇到以下错误

arr: npt.NDArray[E]
                ^^^
Could not specialize type "NDArray[ScalarType@NDArray]"
  Type "E@double_arr" cannot be assigned to type "generic"
    "object*" is incompatible with "generic"

如果我将 E 绑定到 numpy 数据类型( np.int8, np.uint8, ... ),类型检查器由于多种数据类型而无法评估乘法。

查看源代码,似乎用于参数化numpy.dtype numpy.typing.NDArray泛型类型变量受numpy.generic限制(并声明协变)。 因此,NDArray 的任何类型参数都必须是NDArray的子numpy.generic ,而您的类型变量是无界的。 应该工作:

from typing import TypeVar

import numpy as np
from numpy.typing import NDArray


E = TypeVar("E", bound=np.generic, covariant=True)


def double_arr(arr: NDArray[E]) -> NDArray[E]:
    return arr * 2

但是还有另一个问题,我认为是 numpy 存根不足。 本期中展示了一个示例。 __mul__这样的重载操作数(魔术)方法以某种方式破坏了类型。 我现在只是粗略地看了一下代码,所以我不知道缺少什么。 但是mypy仍然会抱怨该代码的最后一行:

error: Returning Any from function declared to return "ndarray[Any, dtype[E]]"  [no-any-return]
error: Unsupported operand types for * ("ndarray[Any, dtype[E]]" and "int")  [operator]

现在的解决方法是使用函数而不是操作数(通过 dunder 方法)。 在这种情况下,使用numpy.multiply而不是*可以解决问题:

from typing import TypeVar

import numpy as np
from numpy.typing import NDArray


E = TypeVar("E", bound=np.generic, covariant=True)


def double_arr(arr: NDArray[E]) -> NDArray[E]:
    return np.multiply(arr, 2)


a = np.array([1, 2, 3], dtype=np.int8)
reveal_type(double_arr(a))

不再有mypy投诉,类型显示如下:

numpy.ndarray[Any, numpy.dtype[numpy.signedinteger[numpy._typing._8Bit]]]

值得关注该操作数问题,甚至可能单独报告Unsupported operand types for *的特定错误。 我还没有在问题跟踪器中找到它。


PS :或者,您可以使用*运算符并添加特定type: ignore 这样你会注意到,如果/一旦注释错误最终被 numpy 修复,因为mypy在严格模式下抱怨未使用的忽略指令。

def double_arr(arr: NDArray[E]) -> NDArray[E]:
    return arr * 2  # type: ignore[operator,no-any-return]

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM