簡體   English   中英

按其中某些元素的頻率過濾 numpy 數組

[英]Filtering a numpy array by frequencies of certain elements in it

我有一個 numpy 數組和一個類似於下面的字典:

arr1 = np.array([['a1','x'],['a2','x'],['a3','y'],['a4','y'],['a5','z']])
d = {'x':2,'z':1,'y':1,'w':2}

對於d每個鍵值對(k,v)k在其第二列的arr1中應該恰好出現v次。 顯然,這不會發生在這里。

所以我想要做的是,從arr1 ,我想創建另一個數組,其中第二列中的每個元素都准確地出現根據d應該出現的次數。 換句話說,我想要的結果是:

np.array([['a1','x'],['a2','x'],['a5','z']])

我可以使用列表理解獲得我想要的結果:

ans = [[x1,x2] for x1,x2 in arr1 if np.count_nonzero(arr1==x2)==d[x2]]

但我想知道是否可以僅使用 numpy 來做到這一點。

這做你想要的:

import numpy as np

arr1 = np.array([['a1', 'x'], ['a2', 'x'], ['a3', 'y'], ['a4', 'y'], ['a5', 'z']])
d = {'x': 2, 'z': 1, 'y': 1, 'w': 2}

# get the actual counts of values in arr1
counts = dict(zip(*np.unique(arr1[:, 1], return_counts=True)))
# determine what values to keep, as their count matches the desired count
keep = [x for x in d if x in counts and d[x] == counts[x]]
# filter down the array
result = arr1[list(map(lambda x: x[1] in keep, arr1))]

在 numpy 中很可能有一種更優化的方法來做到這一點,但我不知道你申請的集合有多大,或者你需要多久這樣做一次,以說尋找它是否值得。

編輯:請注意,您需要擴大規模以決定什么是好的解決方案。 您的原始解決方案非常適合玩具示例,它的表現優於這兩個答案。 但是,如果您擴展到可能更現實的工作負載,@NewbieAF 提供的 numpy 解決方案可以輕松擊敗其他解決方案:

from random import randint
from timeit import timeit
import numpy as np


def original(arr1, d):
    return [[x1, x2] for x1, x2 in arr1 if np.count_nonzero(arr1 == x2) == d[x2]]


def f1(arr1, d):
    # get the actual counts of values in arr1
    counts = dict(zip(*np.unique(arr1[:, 1], return_counts=True)))
    # determine what values to keep, as their count matches the desired count
    keep = [x for x in d if x in counts and d[x] == counts[x]]
    # filter down the array
    return arr1[list(map(lambda x: x[1] in keep, arr1))]


def f2(arr1, d):
    # create arrays from d
    keys, vals = np.array(list(d.keys())), np.array(list(d.values()))
    # count the unique elements in arr1[:,1]
    unqs, cts = np.unique(arr1[:,1], return_counts=True)

    # only keep track of elements that appear in arr1
    mask = np.isin(keys,unqs)
    keys, vals = keys[mask], vals[mask]

    # sort the unique values and corresponding counts according to keys
    idx1 = np.argsort(np.argsort(keys))
    idx2 = np.argsort(unqs)
    unqs, cts = unqs[idx2][idx1], cts[idx2][idx1]

    # filter values by whether the counts match
    correct = unqs[vals==cts]

    return arr1[np.isin(arr1[:,1],correct)]


def main():
    arr1 = np.array([['a1', 'x'], ['a2', 'x'], ['a3', 'y'], ['a4', 'y'], ['a5', 'z']])
    d = {'x': 2, 'z': 1, 'y': 1, 'w': 2}

    print(timeit(lambda: original(arr1, d), number=10000))
    print(timeit(lambda: f1(arr1, d), number=10000))
    print(timeit(lambda: f2(arr1, d), number=10000))

    counts = [randint(1, 3) for _ in range(10000)]
    arr1 = np.array([['x', f'{n}'] for n in range(10000) for _ in range(counts[n])])
    d = {f'{n}': randint(1, 3) for n in range(10000)}

    print(timeit(lambda: original(arr1, d), number=10))
    print(timeit(lambda: f1(arr1, d), number=10))
    print(timeit(lambda: f2(arr1, d), number=10))

main()

結果:

0.14045359999999998
0.2402685
0.5027185999999999
46.7569239
5.893172499999999
0.08729539999999503

numpy解決方案在玩具示例上很慢,但在大輸入上要快numpy數量級。 您的解決方案看起來不錯,但是在擴展時輸給了非 numpy 解決方案,避免了額外的調用。

考慮問題的大小。 如果問題很小,您應該選擇自己的解決方案,以提高可讀性。 如果問題是中等規模的,您可能會選擇我的來提高性能。 如果問題很大(無論是大小還是使用頻率),您應該選擇全 numpy 解決方案,犧牲可讀性來提高速度。

np.argsort() ,我找到了一個純粹的 numpy 解決方案。 只需要根據相同元素在d.values()的數組版本中的位置對arr1的第二行進行排序。

arr1 = np.array([['a1','x'],['a2','x'],['a3','y'],['a4','y'],['a5','z']])
d = {'x':2,'z':1,'y':1,'w':2}

# create arrays from d
keys, vals = np.array(list(d.keys())), np.array(list(d.values()))
# count the unique elements in arr1[:,1]
unqs, cts = np.unique(arr1[:,1], return_counts=True)

# only keep track of elements that appear in arr1
mask = np.isin(keys,unqs)
keys, vals = keys[mask], vals[mask]

# sort the unique values and corresponding counts according to keys
idx1 = np.argsort(np.argsort(keys))
idx2 = np.argsort(unqs)
unqs, cts = unqs[idx2][idx1], cts[idx2][idx1]

# filter values by whether the counts match
correct = unqs[vals==cts]

# keep subarray where the counts match
ans = arr1[np.isin(arr1[:,1],correct)]

print(ans)
# [['a1' 'x']
#  ['a2' 'x']
#  ['a5' 'z']]

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM