繁体   English   中英

numpy arrays 的逐元素比较(Python)

[英]Element-wise comparison of numpy arrays (Python)

我想问一个关于下面 numpy 数组的问题。

我有一个数据集,它有50 rows and 15 columns ,我创建了一个 numpy 数组,如下所示:

x=x.to_numpy()

我的目标是将每一行与其他行进行比较(按元素和自身除外),并找出是否有所有值都小于该行的行。

样品表:

a b c         
1 6 2
2 6 8
4 7 12
7 9 13

例如,第 1 行和第 2 行没有这样的一行。 但是第 3,4 行有一行,其中第 1 行和第 2 行的所有值都小于所有这些值。 所以算法应该返回计数 2(表示第 3 行和第 4 行)。

应该执行哪个 Python 代码来获得这个特定的回报。

我尝试了一堆代码,但无法找到合适的解决方案。 因此,如果有人对此有任何想法,我将不胜感激。

只需使用两个循环并进行比较

import numpy as np

def f(x):
    count = 0

    for i in range(x.shape[0]):
        for j in range(x.shape[0]):
            if i == j:
                continue
            if np.all(x[i] > x[j]):
                count += 1
                break

    return count

x = np.array([[1, 6, 2], [2, 6, 8], [4, 7, 12], [7, 9, 13]])
print(f(x))

编辑:纯 numpy 解决方案

(x.reshape(-1, 1, 3) > x.reshape(1, -1, 3)).all(axis=2).any(axis=1).sum()

解释

困难的部分是在 3d 中思考,所以我从 2d 开始,简单地比较数字。 假设您有x=np.array([1,2,3,4])并且您想要将 x 的所有元素与 x 的所有其他元素进行比较,从而生成一个 4x4 布尔矩阵。

您要做的是在一侧将 x 重塑为一列值,在另一侧将其重塑为一条线。 所以有两个二维数组:一个 4x1,另一个 1x4。

然后,在这两个数组之间执行操作时,广播将创建一个 4x4 数组。

只是为了形象化,而不是比较,让我们这样做

x=np.array([1,2,3,4])
x.reshape(-1,1) #is
#[[1],
# [2],
# [3],
# [4]]
x.reshape(1,-1) #is
# [ [1,2,3,4] ]
x.reshape(-1,1)*10+x.reshape(1,-1) #is therefore
# [[11, 12, 13, 14],
#  [21, 22, 23, 24],
#  [31, 32, 33, 34],
#  [41, 42, 43, 44]]

# Likewise 
x.reshape(-1,1)<x.reshape(1,-1) # is
#array([[False,  True,  True,  True],
#       [False, False,  True,  True],
#       [False, False, False,  True],
#       [False, False, False, False]])

所以,我们所要做的就是完全一样的事情。 但是值是长度为 3 的一维数组而不是标量:
x.reshape(-1, 1, 3) > x.reshape(1, -1, 3)

与前面的示例一样,广播将使它成为所有x[i]>x[j]的二维数组,除了x[i]x[j]和因此x[i]>x[j]不是值, 但 1d 长度 3 阵列。 所以我们的结果是一个长度为 3 的 1d 数组的 2d 数组,也就是 3d 数组。

现在我们只需要做我们所有的,任何,总和。 要将x[i]视为x[j] ,我们需要x[i]的所有值>x[j]的所有值。 因此, all在轴 2(长度 3 的轴)上。 现在我们有一个二维矩阵告诉每个 i,j 如果x[i]>x[j]

为了使x[j]具有较小的对应项,即x[j]大于至少一个x[i] ,我们需要在x[j]列上至少有一个 True。 因此any(axis=1)

最后,此时我们拥有的是一维布尔数组,如果它至少存在一个较小的值,则为 True。 我们只需要计算它们。 因此.sum()

复合迭代

单线(带一个环。不理想,但比 2 个环好)

sum((r>x).all(axis=1).any() for r in x)

r>x是一个布尔数组,将行r的每个元素与x的每个元素进行比较。 因此,例如,当r是行x[2]时,则r>x

array([[ True,  True,  True],
       [ True,  True,  True],
       [False, False, False],
       [False, False, False]])

所以(r>x).all(axis=1)是一个形状(4,)的布尔值数组,告诉每行中的所有布尔值(因为.all仅遍历列, axis=1 )是否为真。 在前面的示例中,那将是[True, True, False, False] (x[1]>x).all(axis=1)将是[False, False, False, False]x[1]>x的第一行包含 2 个True ,但这对于.all来说还不够)

所以(r>x).all(axis=1).any()告诉你想知道的:是否有任何一行的所有列都是True 那就是如果前面的数组中有任何 True 。

((r>x).all(axis=1).any() for r in x)是针对 x 的所有行r的此计算的迭代器。 如果您将外部( )替换为[ , ] ,您将得到一个TrueFalse列表(False,False,True,True,准确地说,正如您已经说过的那样:第一行为 False,第二行为 True其他)。 但是不需要在这里建立一个列表,因为我们只是想数数。 复合迭代器只会在调用者需要时产生结果,在这里,调用者是sum

sum((r>x).all(axis=1).any() for r in x)计算我们在之前的计算中得到True的次数。

(在这种情况下,因为列表中只有 4 个元素,所以我并没有通过使用复合迭代器而不是复合列表来节省大量内存。但是当我们不使用复合迭代器时,尝试使用复合迭代器是一个好习惯'真的需要在内存中构建所有中间结果的列表)

时序

对于您的示例,纯 numpy 的计算需要 19 微秒,前一个答案需要 48 微秒,di.bezrukov 的需要 115 微秒。

但是差异(以及没有差异)会显示行数何时增加。 那么对于 10000×3 的数据,我的两个答案的计算都需要 3.9 秒,而 di.bezrukov 的方法需要 353 秒。

这两个事实背后的原因:

  • di.bezrukov 的差异变大的事实是因为我避免的内部 for 循环的数量变大了,而且它们很重要
  • 我的两个版本之间的差异消失的事实是因为我的第二个版本(按时间顺序,在这条消息中首先,也就是我的纯 numpy 版本)只保留了外循环。 如果行数不是那么大,那是不可忽略的。 但是当它很大时......好吧,外循环本身(不计算其内容,由内循环优化)只是 O(n),在 O(n²) 结果中。 所以,如果 n 足够大,我们就不关心这个外循环的效率如何。
  • 更糟糕的是:在内存方面,纯 numpy 版本做了我为在我的第一个版本中没有做而感到自豪的事情:计算结果的完整列表。 那没什么。 它还计算一个完整的 3d 布尔矩阵。 那只是中间结果。 因此,对于足够大的 n(例如 100000,除非您有 50Gb 的 RAM),中间结果不适合内存。 即使你有 50Gb 的 RAM,它也不会更快)

尽管如此,所有 3 种方法都是 O(n²)。 O(n²×m) 偶数,如果我们称m为列数

都有 3 个嵌套循环。 Di.bezrukov 在.all中有两个显式的 python for循环和一个隐式循环(仍然是一个 for 循环,即使它是在 numpy 的内部代码中完成的)。 我的复合版本有 1 个 python compound for循环和 2 个隐式循环.all.any
我的纯 numpy 版本没有显式循环,但有 3 个隐式 numpy 的嵌套循环(在 3d 数组的构建中)

所以同时结构。 只有 numpy 的循环更快。

我为我的纯 numpy 版本感到自豪,因为我一开始没有找到它。 但实际上,我的第一个版本(复合)更好。 仅当无关紧要时(对于非常小的阵列),它才会变慢。 它不消耗任何内存。 并且它只对外循环进行 numize,在内循环之前可以忽略不计。

长话短说:

sum((r>x).all(axis=1).any() for r in x)

除非你真的只有 4 行并且 μs 很重要,或者你正在参与谁可以在最纯粹的 numpy 3d-chess:D 中思考的竞赛,在这种情况下

(x.reshape(-1, 1, 3) > x.reshape(1, -1, 3)).all(axis=2).any(axis=1).sum()

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM