简体   繁体   English

如何创建一个接受 numpy 数组、可迭代对象或标量的 numpy 函数?

[英]how can I make a numpy function that accepts a numpy array, an iterable, or a scalar?

Suppose I have this:假设我有这个:

def incrementElements(x):
   return x+1

but I want to modify it so that it can take either a numpy array, an iterable, or a scalar, and promote the argument to a numpy array and add 1 to each element.但我想修改它,以便它可以采用 numpy 数组、可迭代或标量,并将参数提升为 numpy 数组并为每个元素添加 1。

How could I do that?我怎么能那样做? I suppose I could test argument class but that seems like a bad idea.我想我可以测试参数类,但这似乎是个坏主意。 If I do this:如果我这样做:

def incrementElements(x):
   return numpy.array(x)+1

it works properly on arrays or iterables but not scalars.它适用于数组或可迭代对象,但不适用于标量。 The problem here is that numpy.array(x) for scalar x produces some weird object that is contained by a numpy array but isn't a "real" array;这里的问题是标量 x 的numpy.array(x)产生了一些奇怪的对象,它包含在一个 numpy 数组中,但不是“真正的”数组; if I add a scalar to it, the result is demoted to a scalar.如果我向它添加标量,结果将降级为标量。

You could try你可以试试

def incrementElements(x):
    x = np.asarray(x)
    return x+1

np.asarray(x) is the equivalent of np.array(x, copy=False) , meaning that a scalar or an iterable will be transformed to a ndarray , but if x is already a ndarray , its data will not be copied. np.asarray(x)等效于np.array(x, copy=False) ,这意味着标量或可迭代对象将转换为ndarray ,但如果x已经是ndarray ,则不会复制其数据。

If you pass a scalar and want a ndarray as output (not a scalar), you can use:如果您传递标量并希望将ndarray作为输出(不是标量),则可以使用:

def incrementElements(x):
    x = np.array(x, copy=False, ndmin=1)
    return x

The ndmin=1 argument will force the array to have at least one dimension. ndmin=1参数将强制数组至少具有一维。 Use ndmin=2 for at least 2 dimensions, and so forth.对至少 2 个维度使用ndmin=2 ,依此类推。 You can also use its equivalent np.atleast_1d (or np.atleast_2d for the 2D version...)您还可以使用其等效的np.atleast_1d (或np.atleast_2d用于 2D 版本...)

Pierre GM's answer is great so long as your function exclusively uses ufuncs (or something similar) to implicitly loop over the input values.只要您的函数专门使用 ufuncs(或类似的东西)来隐式循环输入值,Pierre GM 的回答就很好。 If your function needs to iterate over the inputs, then np.asarray doesn't do enough, because you can't iterate over a NumPy scalar:如果您的函数需要迭代输入,那么np.asarray还不够,因为您无法迭代 NumPy 标量:

import numpy as np

x = np.asarray(1)
for xval in x:
    print(np.exp(xval))

Traceback (most recent call last):
  File "Untitled 2.py", line 4, in <module>
    for xval in x:
TypeError: iteration over a 0-d array

If your function needs to iterate over the input, something like the following will work, using np.atleast_1d and np.squeeze (see Array manipulation routines — NumPy Manual ).如果您的函数需要迭代输入,则使用np.atleast_1dnp.squeeze (请参阅数组操作例程 — NumPy 手册),如下所示。 I included an aaout ("Always Array OUT") arg so you can specify whether you want scalar inputs to produce single-element array outputs;我包含了一个aaout (“Always Array OUT”)参数,因此您可以指定是否希望标量输入生成单元素数组输出; it could be dropped if not needed:如果不需要,它可以被删除:

def scalar_or_iter_in(x, aaout=False):
    """
    Gather function evaluations over scalar or iterable `x` values.

    aaout :: boolean
        "Always array output" flag:  If True, scalar input produces
        a 1-D, single-element array output.  If false, scalar input
        produces scalar output.
    """
    x = np.asarray(x)
    scalar_in = x.ndim==0

    # Could use np.array instead of np.atleast_1d, as follows:
    # xvals = np.array(x, copy=False, ndmin=1)
    xvals = np.atleast_1d(x)
    y = np.empty_like(xvals, dtype=float)  # dtype in case input is ints
    for i, xx in enumerate(xvals):
        y[i] = np.exp(xx)  # YOUR OPERATIONS HERE!

    if scalar_in and not aaout:
        return np.squeeze(y)
    else:
        return y


print(scalar_or_iter_in(1.))
print(scalar_or_iter_in(1., aaout=True))
print(scalar_or_iter_in([1,2,3]))


2.718281828459045
[2.71828183]
[ 2.71828183  7.3890561  20.08553692]

Of course, for exponentiation you should not explicitly iterate as here, but a more complex operation may not be expressible using NumPy ufuncs.当然,对于求幂,您不应该像这里那样显式迭代,但是使用 NumPy ufunc 可能无法表达更复杂的操作。 If you do not need to iterate, but want similar control over whether scalar inputs produce single-element array outputs, the middle of the function could be simpler, but the return has to handle the np.atleast_1d :如果您不需要迭代,但要在标量输入是否会产生单元素数组输出类似的控制,该函数中可能很简单,但来回必须处理np.atleast_1d

def scalar_or_iter_in(x, aaout=False):
    """
    Gather function evaluations over scalar or iterable `x` values.

    aaout :: boolean
        "Always array output" flag:  If True, scalar input produces
        a 1-D, single-element array output.  If false, scalar input
        produces scalar output.
    """
    x = np.asarray(x)
    scalar_in = x.ndim==0

    y = np.exp(x)  # YOUR OPERATIONS HERE!

    if scalar_in and not aaout:
        return np.squeeze(y)
    else:
        return np.atleast_1d(y)

I suspect in most cases the aaout flag is not necessary, and that you'd always want scalar outputs with scalar inputs.我怀疑在大多数情况下aaout标志不是必需的,并且您总是想要带有标量输入的标量输出。 In such cases, the return should just be:在这种情况下,回报应该是:

    if scalar_in:
        return np.squeeze(y)
    else:
        return y

This is an old question, but here are my two cents.这是一个老问题,但这是我的两分钱。

Although Pierre GM's answer works great, it has the---maybe undesirable---side effect of converting scalars to arrays.尽管 Pierre GM 的回答效果很好,但它具有将标量转换为数组的---可能是不受欢迎的---副作用。 If that is what you want/need, then stop reading;如果那是您想要/需要的,请停止阅读; otherwise, carry on.否则,继续。 While this might be okay (and is probably good for lists and other iterables to return a np.array ), it could be argued that for scalars it should return a scalar.虽然这可能没问题(并且可能有利于lists和其他iterables返回np.array ),但可以争论的是,对于标量它应该返回一个标量。 If that's the desired behaviour, why not follow python's EAFP philosophy.如果这是所需的行为,为什么不遵循 python 的EAFP哲学。 This is what I usually do (I changed the example to show what could happen when np.asarray returns a "scalar"):这就是我通常所做的(我更改了示例以显示当np.asarray返回“标量”时会发生什么):

def saturateElements(x):
    x = np.asarray(x)
    try:
        x[x>1] = 1
    except TypeError:
        x = min(x,1)
    return x

I realize it is more verbose than Pierre GM's answer, but as I said, this solution will return a scalar if a scalar is passed, or a np.array is an array or iterable is passed.我意识到它比 Pierre GM 的答案更冗长,但正如我所说,如果传递了标量,或者np.array是数组或可迭代的,则此解决方案将返回标量。

Existing solutions all work.现有的解决方案都有效。 Here's yet another option.这是另一种选择。

def incrementElements(x):
    try:
        iter(x)
    except TypeError:
        x = [x]
    x = np.array(x)
    # code that operators on x

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM