[英]Counting number of occurrences of an array in array of numpy 2D arrays
我有一个 arrays 的 numpy 二维数组:
samples = np.array([[1,2,3], [2,3,4], [4,5,6], [1,2,3], [2,3,4], [2,3,4]])
我需要计算一个数组在数组内部出现的次数,例如:
counts = [[1,2,3]:2, [2,3,4]:3, [4,5,6]:1]
我不确定如何按上面的方式计算或列出它以了解哪些数组和计数相互连接,感谢您的帮助。 谢谢!
您需要的一切都直接在numpy
中:
import numpy as np
a = np.array([[1,2,3], [2,3,4], [4,5,6], [1,2,3], [2,3,4], [2,3,4]])
print(np.unique(a, axis=0, return_counts=True))
结果:
(array([[1, 2, 3],
[2, 3, 4],
[4, 5, 6]]), array([2, 3, 1], dtype=int64))
结果是一个包含唯一行的数组的元组,以及一个包含这些行计数的数组。
如果您需要成对通过它们 go :
unique_rows, counts = np.unique(a, axis=0, return_counts=True)
for row, c in zip(unique_rows, counts):
print(row, c)
结果:
[1 2 3] 2
[2 3 4] 3
[4 5 6] 1
这是一种不使用大部分 numpy 库的方法:
import numpy as np
samples = np.array([[1,2,3], [2,3,4], [4,5,6], [1,2,3], [2,3,4], [2,3,4]])
result = {}
for row in samples:
inDictionary = False
for check in range(len(result)):
if np.all(result[str(check)][0] == row):
result[str(check)][1]+= 1
inDictionary = True
else:
pass
if inDictionary == False:
result[str(len(result))] = [row, 1]
print("------------------")
print(result)
此方法创建一个名为 result 的字典,然后遍历样本中的各种嵌套列表并检查它们是否已经在字典中。 如果它们是它出现的次数,则加 1。否则,它会为该数组创建一个新条目。 现在已经保存的计数和值可以使用result["index"]
访问您想要的索引和result["index"][0]
- 对于数组值 & result["index"][1]
-它出现的次数。
与其他 Python (没有numpy
)解决方案相比,有一种相对快速的 Python 方法:
from collections import Counter
>>> Counter(map(tuple, samples.tolist())) # convert to dict if you need it
Counter({(1, 2, 3): 2, (2, 3, 4): 3, (4, 5, 6): 1})
Python 也做得很快,因为元组索引的操作被优化得很好
import benchit
%matplotlib inline
benchit.setparams(rep=3)
sizes = [3, 10, 30, 100, 300, 900, 3000, 9000, 30000, 90000, 300000, 900000, 3000000]
arr = np.random.randint(0,10, size=(sizes[-1], 3)).astype(int)
def count_python(samples):
return Counter(map(tuple, samples.tolist()))
def count_numpy(samples):
return np.unique(samples, axis=0, return_counts=True)
fns = [count_python, count_numpy]
in_ = {s: (arr[:s],) for s in sizes}
t = benchit.timings(fns, in_, multivar=True, input_name='Number of items')
t.plot(logx=True, figsize=(12, 6), fontsize=14)
注意arr.tolist()
消耗了大约 0.8sec/3M 的 Python 计算时间。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.