繁体   English   中英

Python numpy:有效获取包含其他3列的每个唯一元组的列的最小值的行

[英]Python numpy: Efficiently get rows containing min value of column for each unique tuple of 3 other columns

我有一些数据存储在列表列表中(约200,000行x 6列的列表)。

我需要获取以下数据子集:对于列[1,2,4]中的每个唯一值集,我需要找到具有列0最小值的行并仅保留该行。

我必须在旧的numpy 1.10中执行此操作(不要问...),因此np.unique()中没有“ axis = 0”选项。

下面的示例运行并产生正确的输出,但是非常慢。 这似乎很基础,所以我觉得(缺乏)速度一定是我的错。

# S-L-O-W way to get the desired output:
import numpy as np

# Example dataset
data = [[1, 1, 1, 'a', 1],
        [0, 1, 1, 'b', 1],
        [0, 3, 1, 'c', 4],
        [3, 1, 1, 'd', 1],
        [4, 3, 1, 'e', 4]]

desired_output = [[0, 1, 1, 'b', 1],
                  [0, 3, 1, 'c', 4]]

# Currently coding on a geriatric machine with numpy pre-version 1.13 and no ability to upgrade,
# so np.unique() won't take an axis argument. The next few hack lines of code get around this with strings...
tuples_str = []
tuples_raw = [[datarow[jj] for jj in [1,2,4]]  for datarow in data ]
for datarow in data:
    one_tuple = [datarow[jj] for jj in [1,2,4]]
    tuples_str.append( '_'.join([str(ww) for ww in one_tuple]) )

# Numpy unique on this data subset with just columns [1,2,4] of original data
unq, unq_inv, unq_cnt = np.unique(tuples_str, return_inverse=True, return_counts=True)

# Storage
output = []

# Here's the painfully slow part:
# Iterate over each subset of data where rows take the value in one unique tuple (i.e. columns [1,2,4] are identical)
for ii, idx in enumerate(np.unique(unq_inv)):

    # Get the rows that have the same values in columns [1,2,4]
    all_matches_thistuple = [row for ii, row in enumerate(data) if unq_inv[ii]==idx]

    # Find the index of the row with the minimum value for column 0
    first_line_min_idx = np.argmin([int(row1[0]) for row1 in all_matches_thistuple])

    # Save only that row
    output.append(all_matches_thistuple[first_line_min_idx])
print(output)

如果您是列表列表开始的 ,则可以使用普通的Python轻松完成此操作,它将很有效。 确实,您正在将numpy与dtype object一起使用,所以我怀疑您使用内置例程会获得很少的性能,因为您丢失了数据局部性(并且实际上留下了of脚的Python list对象)。 相反,您可以在线性时间内完成此操作(不计算初始数据类型,该数据将为O(n * logN),但它将使用Python的timsort,因此实际上将非常快),只需执行几次传递数据:

In [1]: data = [[1, 1, 1, 'a', 1],
   ...:         [0, 1, 1, 'b', 1],
   ...:         [0, 3, 1, 'c', 4],
   ...:         [3, 1, 1, 'd', 1],
   ...:         [4, 3, 1, 'e', 4]]
   ...:

In [2]: from operator import itemgetter

In [3]: group_key = itemgetter(1,2,4)

In [4]: data.sort(key=group_key)

然后简单地:

In [6]: first = itemgetter(0)

In [7]: result = []

In [8]: from itertools import groupby
   ...: for _, g in groupby(data, group_key):
   ...:     result.append(min(g, key=first))
   ...:

In [9]: result
Out[9]: [[0, 1, 1, 'b', 1], [0, 3, 1, 'c', 4]]

另一种方法是使用defaultdict构建辅助数据结构。 这是对未排序数据进行分组的惯用方式。 如果您希望能够将这些值分组,这可能会很有用:

In [10]: from collections import defaultdict

In [11]: grouper = defaultdict(list)

In [12]: data = [[1, 1, 1, 'a', 1],
    ...:         [0, 1, 1, 'b', 1],
    ...:         [0, 3, 1, 'c', 4],
    ...:         [3, 1, 1, 'd', 1],
    ...:         [4, 3, 1, 'e', 4]]

In [13]: for row in data:
    ...:     _,x,y,_, z = row
    ...:     grouper[(x,y,z)].append(row)
    ...:

In [14]: grouper
Out[14]:
defaultdict(list,
            {(1, 1, 1): [[1, 1, 1, 'a', 1],
              [0, 1, 1, 'b', 1],
              [3, 1, 1, 'd', 1]],
             (3, 1, 4): [[0, 3, 1, 'c', 4], [4, 3, 1, 'e', 4]]})

In [15]: first = itemgetter(0)

In [16]: [min(group, key=first) for group in grouper.values()]
Out[16]: [[0, 1, 1, 'b', 1], [0, 3, 1, 'c', 4]]

如果您可以使用Pandas,这是一种方法:

df = pd.DataFrame(data).sort_values(0).drop_duplicates([1, 2, 4]).values

结果

[[0 1 1 'b' 1]
 [0 3 1 'c' 4]]

说明

您的问题可以简化为:

  • 按列0排序,默认为ascending=True
  • 按列[1、2、4]删除重复的行。
  • pd.DataFrame.values提取基础的numpy数组。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM