[英]Counting number of occurrences of an array in array of numpy 2D arrays
我有一個 arrays 的 numpy 二維數組:
samples = np.array([[1,2,3], [2,3,4], [4,5,6], [1,2,3], [2,3,4], [2,3,4]])
我需要計算一個數組在數組內部出現的次數,例如:
counts = [[1,2,3]:2, [2,3,4]:3, [4,5,6]:1]
我不確定如何按上面的方式計算或列出它以了解哪些數組和計數相互連接,感謝您的幫助。 謝謝!
您需要的一切都直接在numpy
中:
import numpy as np
a = np.array([[1,2,3], [2,3,4], [4,5,6], [1,2,3], [2,3,4], [2,3,4]])
print(np.unique(a, axis=0, return_counts=True))
結果:
(array([[1, 2, 3],
[2, 3, 4],
[4, 5, 6]]), array([2, 3, 1], dtype=int64))
結果是一個包含唯一行的數組的元組,以及一個包含這些行計數的數組。
如果您需要成對通過它們 go :
unique_rows, counts = np.unique(a, axis=0, return_counts=True)
for row, c in zip(unique_rows, counts):
print(row, c)
結果:
[1 2 3] 2
[2 3 4] 3
[4 5 6] 1
這是一種不使用大部分 numpy 庫的方法:
import numpy as np
samples = np.array([[1,2,3], [2,3,4], [4,5,6], [1,2,3], [2,3,4], [2,3,4]])
result = {}
for row in samples:
inDictionary = False
for check in range(len(result)):
if np.all(result[str(check)][0] == row):
result[str(check)][1]+= 1
inDictionary = True
else:
pass
if inDictionary == False:
result[str(len(result))] = [row, 1]
print("------------------")
print(result)
此方法創建一個名為 result 的字典,然后遍歷樣本中的各種嵌套列表並檢查它們是否已經在字典中。 如果它們是它出現的次數,則加 1。否則,它會為該數組創建一個新條目。 現在已經保存的計數和值可以使用result["index"]
訪問您想要的索引和result["index"][0]
- 對於數組值 & result["index"][1]
-它出現的次數。
與其他 Python (沒有numpy
)解決方案相比,有一種相對快速的 Python 方法:
from collections import Counter
>>> Counter(map(tuple, samples.tolist())) # convert to dict if you need it
Counter({(1, 2, 3): 2, (2, 3, 4): 3, (4, 5, 6): 1})
Python 也做得很快,因為元組索引的操作被優化得很好
import benchit
%matplotlib inline
benchit.setparams(rep=3)
sizes = [3, 10, 30, 100, 300, 900, 3000, 9000, 30000, 90000, 300000, 900000, 3000000]
arr = np.random.randint(0,10, size=(sizes[-1], 3)).astype(int)
def count_python(samples):
return Counter(map(tuple, samples.tolist()))
def count_numpy(samples):
return np.unique(samples, axis=0, return_counts=True)
fns = [count_python, count_numpy]
in_ = {s: (arr[:s],) for s in sizes}
t = benchit.timings(fns, in_, multivar=True, input_name='Number of items')
t.plot(logx=True, figsize=(12, 6), fontsize=14)
注意arr.tolist()
消耗了大約 0.8sec/3M 的 Python 計算時間。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.