簡體   English   中英

轉換為 numpy 中的索引數組

[英]Convert to indices array in numpy

類似於retrun_inverse中的numpy.unique

如果我有一個 numpy 數組 a: [['a' 'b'] ['b' 'c'] ['c' 'c'] ['c' 'b']]

我想將數組 b: [['b' 'c'] ['a' 'b'] ['c' 'c'] ['a' 'b'] ['c' 'c']]轉換為[1 0 2 0 2]

有沒有聰明的方法來轉換它?

也許使用普通的list更容易做到這一點(您可以使用.tolist()方法從 NumPy arrays 獲得):

a = [['a', 'b'], ['b', 'c'], ['c', 'c'], ['c', 'b']]
b = [['b', 'c'], ['a', 'b'], ['c', 'c'], ['a', 'b'], ['c', 'c']]

print([a.index(x) for x in b])
# [1, 0, 2, 0, 2]

或者,將其寫為 function 並假設 NumPy arrays 輸入和輸出並處理針不在大海撈針的情況:

import numpy as np


def find_by_list(haystack, needles):
    haystack = haystack.tolist()
    result = []
    for needle in needles.tolist():
        try:
            result.append(haystack.index(needle))
        except ValueError:
            result.append(-1)
    return np.array(result)

這大約與基於np.where()的 NumPy 感知解決方案一樣快(假設 np.all np.all() ) 操作的減少可以在第一個軸上完成),例如:

import numpy as np


def find_by_np(haystack, needles, haystack_axis=-1, needles_axis=-1, keepdims=False):
    if haystack_axis:
        haystack = haystack.swapaxes(0, haystack_axis)
    if needles_axis:
        needles = needles.swapaxes(0, needles_axis)
    n = needles.shape[0]
    m = haystack.ndim - 1
    shape = haystack.shape[1:]
    result = np.full((m,) + needles.shape[1:], -1)
    haystack = haystack.reshape(n, -1)
    needles = needles.reshape(n, -1)
    _, match, index = np.nonzero(np.all(
        haystack[:, None, :] == needles[:, :, None],
        axis=0, keepdims=True))
    result.reshape(m, -1)[:, match] = np.unravel_index(index, shape)
    if not keepdims and result.shape[0] == 1:
        result = np.squeeze(result, 0)
    return result

但兩者都比使用 Numba JIT 加速的簡單循環慢,例如:

import numpy as np
import numba as nb


def find_by_loop(haystack, needles):
    n, m = haystack.shape
    l, m_ = needles.shape
    result = np.full(l, -1)
    if m != m_:
        return result
    for i in range(l):
        for j in range(n):
            is_equal = True
            for k in range(m):
                if haystack[j, k] != needles[i, k]:
                    is_equal = False
                    break
            if is_equal:
                break
        if is_equal:
            result[i] = j
    return result


find_by_nb = nb.jit(find_by_loop)
find_by_nb.__name__ = 'find_by_nb'

雖然它們都給出相同的結果:

funcs = find_by_list, find_by_np, find_by_loop, find_by_nb


a = np.array([['a', 'b'], ['b', 'c'], ['c', 'c'], ['c', 'b']])
b = np.array([['b', 'c'], ['a', 'b'], ['c', 'c'], ['a', 'b'], ['c', 'c']])
print(a.shape, b.shape)
for func in funcs:
    print(f'{func.__name__:>12s}(a, b) = {func(a, b)}')
# find_by_list(a, b) = [1 0 2 0 2]
#   find_by_np(a, b) = [1 0 2 0 2]
# find_by_loop(a, b) = [1 0 2 0 2]
#   find_by_nb(a, b) = [1 0 2 0 2]

時間安排如下:

print(f'({"n":<4s}, {"m":<4s}, {"k":<4s})', end='  ')
for func in funcs:
    print(f'{func.__name__:>15s}', end='    ')
print()
for n, m, k in itertools.product((5, 50, 500), repeat=3):
    a = np.random.randint(0, 100, (k, n))
    b = np.random.randint(0, 100, (m, n))
    print(f'({n:<4d}, {m:<4d}, {k:<4d})', end='  ')
    for func in funcs:
        result = %timeit -n3 -r10 -q -o func(a, b)
        print(f'{result.best * 1e3:12.3f} ms', end='    ')
    print()
# (n   , m   , k   )     find_by_list         find_by_np       find_by_loop         find_by_nb    
# (5   , 5   , 5   )         0.008 ms           0.048 ms           0.021 ms           0.001 ms    
# (5   , 5   , 50  )         0.018 ms           0.031 ms           0.176 ms           0.001 ms    
# (5   , 5   , 500 )         0.132 ms           0.092 ms           1.754 ms           0.006 ms    
# (5   , 50  , 5   )         0.065 ms           0.031 ms           0.184 ms           0.001 ms    
# (5   , 50  , 50  )         0.139 ms           0.093 ms           1.756 ms           0.006 ms    
# (5   , 50  , 500 )         1.096 ms           0.684 ms          17.546 ms           0.049 ms    
# (5   , 500 , 5   )         0.658 ms           0.093 ms           1.871 ms           0.006 ms    
# (5   , 500 , 50  )         1.383 ms           0.699 ms          17.504 ms           0.051 ms    
# (5   , 500 , 500 )         9.102 ms           7.752 ms         177.754 ms           0.491 ms    
# (50  , 5   , 5   )         0.026 ms           0.061 ms           0.022 ms           0.001 ms    
# (50  , 5   , 50  )         0.054 ms           0.042 ms           0.174 ms           0.002 ms    
# (50  , 5   , 500 )         0.356 ms           0.203 ms           1.759 ms           0.006 ms    
# (50  , 50  , 5   )         0.232 ms           0.042 ms           0.185 ms           0.001 ms    
# (50  , 50  , 50  )         0.331 ms           0.205 ms           1.744 ms           0.006 ms    
# (50  , 50  , 500 )         1.332 ms           2.422 ms          17.492 ms           0.051 ms    
# (50  , 500 , 5   )         2.328 ms           0.197 ms           1.882 ms           0.006 ms    
# (50  , 500 , 50  )         3.092 ms           2.405 ms          17.618 ms           0.052 ms    
# (50  , 500 , 500 )        11.088 ms          18.989 ms         175.568 ms           0.479 ms    
# (500 , 5   , 5   )         0.205 ms           0.035 ms           0.023 ms           0.001 ms    
# (500 , 5   , 50  )         0.410 ms           0.137 ms           0.187 ms           0.001 ms    
# (500 , 5   , 500 )         2.800 ms           1.914 ms           1.894 ms           0.006 ms    
# (500 , 50  , 5   )         1.868 ms           0.138 ms           0.201 ms           0.001 ms    
# (500 , 50  , 50  )         2.154 ms           1.814 ms           1.902 ms           0.006 ms    
# (500 , 50  , 500 )         6.352 ms          16.343 ms          19.108 ms           0.050 ms    
# (500 , 500 , 5   )        19.798 ms           1.957 ms           2.020 ms           0.006 ms    
# (500 , 500 , 50  )        20.922 ms          13.571 ms          18.850 ms           0.052 ms    
# (500 , 500 , 500 )        35.947 ms         139.923 ms         189.747 ms           0.481 ms    

表明 Numba 提供了最快(並且 memory 效率最高)的解決方案,而其非 JIT 加速版本提供了最慢的解決方案。 基於 NumPy 的一種和基於list的一種以不同的速度出現在兩者之間。 但是對於較大的輸入,基於list的輸入平均應該更快,因為它提供了更好的短路。

不是最優雅的解決方案,但它有效:

設置(將來,顯示代碼以生成您的示例,它將使其更快地回答):

import numpy as np
a = np.array([['a', 'b'], ['b', 'c'], ['c', 'c'], ['c', 'b']])
b = np.array([['b', 'c'], ['a', 'b'], ['c', 'c'], ['a', 'b'], ['c', 'c']])
desired_output = [1, 0, 2, 0, 2]

Using thenumpy.where function (as in this related question: Is there a NumPy function to return the first index of something in an array? )

我們對每一行中的每個項目使用np.where ,將 boolean 結果相乘,然后使用列表推導逐行傳遞:

output = [np.where((x[0]==a[:,0]) * (x[1]==a[:,1]))[0][0] for x in b]

它會返回您想要的結果。

也許是一種有趣的做事方式?

a.append(None)
aa = np.array(a)[:-1]                # Note 1

b.append(None)
bb = np.array(b)[:-1]

ind_arr = bb[:, None] == aa          # Note 2
np.nonzero(ind_arr)[1]

注 1 :第一步更像是獲取object類型一維數組的開銷。 否則, numpy強制使用二維str類型的數組,這對這個應用程序沒有幫助。 這個答案中閱讀更多相關信息。 它還說明了一些替代方案。

注意 2 :這將創建一個二維 boolean 掩碼,其中aa的每個元素與bb的每個元素進行比較以獲得相等性,如下所示: ind_arr[i, j] = (bb[i] == aa[j])
下一行使用此掩碼並沿軸 1提取True值(比較已評估為True )。 這是因為比較掩碼中的aa值沿軸 1。
另一個討論以更好地理解這一點。

但是,如果您正在尋找速度,對於lists ,norok2 的答案要快得多。 這或許,可以有創新的應用。 干杯!

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM