Is there a way to vectorize two arrays pointing at each other in NumPy?

For all the folks who rock at vectorizing loops: I have two NumPy arrays of shape (N,) that contain indices into each other. Say we have a = np.asarray([0, 1, 2]) and b = np.array([1, 2, np.nan]). The function should first look at a[0] to get 0, then b[0] to get 1, then a[1] to get 1, then b[1] to get 2, and so on until it reaches np.nan. So the function simply evaluates the nested lookup b[a[b[a[b[a[0]]]]]], which is np.nan here. The output should contain two lists of the values obtained from a and b respectively. Indices in b are always greater than those in a, so the process cannot get stuck.

I wrote a simple function that does just this (wrapped with numba, it takes 18.2 µs):

import numpy as np

a = np.array([0, 1, 2, 3, 4])
b = np.array([2., 3., 4., np.nan, np.nan])

lst = []
while True:
    if len(lst) > 0:
        idx = lst[-1]
    else:
        idx = 0
    if len(lst) % 2 == 0:
        if idx < len(a) - 1:
            next_idx = a[idx]
            if np.isnan(next_idx):
                break
            lst.append(int(next_idx))
        else:
            break
    else:
        if idx < len(b) - 1:
            next_idx = b[idx]
            if np.isnan(next_idx):
                break
            lst.append(int(next_idx))
        else:
            break

The first list is lst[::2]:

[0, 2]

The second is lst[1::2]:

[2, 4]

Any way to vectorize this? Both input arrays, as well as both output lists, always have the same shape.

This is not a vectorized solution, but as a Numba solution it should be considerably faster, and simpler. I changed the code slightly to use integers and -1 instead of np.nan. It is trivial to switch to this representation with something like b = np.where(np.isnan(b), -1, b), and it makes the code more efficient. Instead of growing a structure within the Numba function, I preallocate the output array, so the loop can run much faster.

import numba as nb
import numpy as np

def point_each_other(a, b):
    # Convert inputs to array if they are not already
    a = np.asarray(a)
    b = np.asarray(b)
    # Make output array in advance
    out = np.empty(len(a) + len(b), dtype=a.dtype)
    # Call Numba function
    n = point_each_other_nb(a, b, out)
    # Return relevant part of the output
    return out[:n]

@nb.njit
def point_each_other_nb(a, b, out):
    curr = 0
    i = 0
    while curr >= 0:
        # Optional check for bad input that would overrun the output buffer:
        # if i >= len(out):
        #     raise ValueError
        # Save current index
        out[i] = curr
        # Get the next index
        curr = a[curr]
        # Swap arrays
        a, b = b, a
        # Advance counter
        i += 1
    # Return number of stored indices
    return i - 1

# Test
a = np.array([0, 1, 2, 3, 4])
b = np.array([2, 3, 4, -1, -1])
out = point_each_other(a, b)
print(out[::2])
# [0 2 4]
print(out[1::2])
# [0 2]
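
The function above expects integer arrays with -1 as the terminator. A minimal sketch of converting the question's NaN-terminated float array follows. The helper name to_sentinel is hypothetical, and the integer cast is an addition here: np.where on a float array returns floats, which cannot be used as indices.

```python
import numpy as np

def to_sentinel(arr, sentinel=-1):
    # Hypothetical helper (not from the answer): replace NaN with an integer
    # sentinel and cast, so the values can be used as array indices.
    arr = np.asarray(arr, dtype=float)
    return np.where(np.isnan(arr), sentinel, arr).astype(np.int64)

b = np.array([2., 3., 4., np.nan, np.nan])
b_int = to_sentinel(b)
print(b_int.tolist())  # [2, 3, 4, -1, -1]
```

Casting both inputs to the same integer dtype also matters for the swap a, b = b, a, which presumably needs both arrays to have the same type inside the compiled function.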

Not vectorized, but here's a recursive solution:

import numpy as np

a = np.array([0,1,2,3,4])
b = np.array([2,3,4,np.nan, np.nan])

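def rec(i, b, a, a_out, b_out):
    # Stop once the array currently being read holds NaN at index i
    if np.isnan(b[i]):
        return
    # Record the index, then recurse with arrays and output lists swapped
    a_out.append(i)
    rec(int(b[i]), a, b, b_out, a_out)
    return a_out, b_out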

print(rec(0,b,a,[],[]))

Output

([0, 2], [2, 4])
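
For long chains, a recursive version can hit Python's recursion limit (1000 frames by default). The following is a rough iterative sketch of the same traversal, not the answerer's code. The name follow_chain and the start parameter are hypothetical, and it assumes, as in the question, that the chain always ends in NaN.

```python
import numpy as np

def follow_chain(a, b, start=0):
    # Hypothetical iterative equivalent of the recursive answer.
    # Cast to float so np.nan can act as the terminator in either array.
    arrays = (np.asarray(b, dtype=float), np.asarray(a, dtype=float))
    outs = ([], [])  # outs[0]: indices read from b, outs[1]: indices read from a
    i, turn = start, 0  # like the recursion, the first lookup is in b
    while not np.isnan(arrays[turn][i]):
        outs[turn].append(i)
        i = int(arrays[turn][i])
        turn ^= 1  # alternate between the two arrays
    return outs

a = np.array([0, 1, 2, 3, 4])
b = np.array([2, 3, 4, np.nan, np.nan])
print(follow_chain(a, b))  # ([0, 2], [2, 4])
```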
