
How to replace values in numpy matrix columns with values from another array?

I have 2 numpy arrays:

import numpy as np

a = np.array([[2, 1, 7],
              [7, 7, 3],
              [1, 7, 4]])

b = np.array([9, -1, 17])

I would like to replace the 7s in each column of a with the corresponding value from b: in the first column every 7 is replaced by the first value of b, in the second column every 7 is replaced by the second value of b, and so on.

My simple solution is:

for j in range(len(b)):
    a[a[:, j] == 7, j] = b[j]

array([[ 2,  1, 17],
       [ 9, -1,  3],
       [ 1, -1,  4]])

It works, but it is not fast enough for very large matrices. Is there a faster way to do this?

Assuming the matrices are big, you can implement a fast parallel version using Numba. This implementation is much faster than the initial solution, whose pure-Python loop creates many small temporary arrays and uses an inefficient non-contiguous memory access pattern (e.g. a[:,j]). It is also significantly faster than np.where(a == 7, b, a), because that expression has to fill huge temporary arrays (the boolean mask and the full-size result), which may not fit in RAM and can force the OS to fall back on very slow swap memory. Using multiple threads also provides a big speed-up. Here is the code:

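import numba as nb

# The explicit signature compiles the function eagerly for a C-contiguous 2-D
# int_ array a and a contiguous 1-D int_ array b, so pass arrays of that type.
# parallel=True lets nb.prange split the row loop across threads.
@nb.njit('void(int_[:,::1], int_[::1])', parallel=True)
def compute(a, b):
    n, m = a.shape
    assert b.size == m
    # Each prange iteration handles one whole row, so memory is read contiguously.
    for i in nb.prange(n):
        for j in range(m):
            # Replace a 7 in column j with b[j], modifying a in place.
            if a[i,j] == 7:
                a[i,j] = b[j]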
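For comparison, the np.where(a == 7, b, a) expression mentioned above works because b, with shape (m,), broadcasts against the (n, m) mask row by row, so column j is always paired with b[j]. Below is a minimal sketch on the question's example. The np.copyto variant is my own suggestion, not part of the answer: it writes into the array in place, although it still allocates a temporary boolean mask the size of a. The names result and a_inplace are just illustrative.

```python
import numpy as np

a = np.array([[2, 1, 7],
              [7, 7, 3],
              [1, 7, 4]])
b = np.array([9, -1, 17])

# Out-of-place: b (shape (3,)) broadcasts across rows, so column j gets b[j].
# This allocates both a boolean mask and a full-size result array.
result = np.where(a == 7, b, a)

# In-place alternative (suggestion, not from the answer): copy b, broadcast to
# a_inplace.shape, only where a_inplace == 7. Only the mask is a temporary.
a_inplace = a.copy()
np.copyto(a_inplace, b, where=(a_inplace == 7))
```

Both variants produce the same array as the question's loop; the np.copyto version avoids the second full-size array but, unlike the Numba kernel, still makes a full pass to build the mask first.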

Here are the results on my 6-core machine for a 100000x1000 matrix (random 32-bit integers in 0..10):

For loop:  683 ms
np.where:  169 ms
Numba:      37 ms

The Numba version is 18 times faster than the initial version and needs almost no extra memory (it works in place).
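A comparison like this can be set up with a small timing script. This is a hedged sketch, not the benchmark behind the numbers above: the helper names (loop_replace, where_replace, timed), the seed, the range of b, the reduced 10000x1000 size, and reading "0..10" as inclusive are all my assumptions. The Numba kernel is left out so the script only needs NumPy, and the printed timings will depend on your machine.

```python
import time

import numpy as np


# Hypothetical helpers for illustration.
def loop_replace(a, b):
    """Question's approach: one boolean-masked assignment per column (in place)."""
    for j in range(len(b)):
        a[a[:, j] == 7, j] = b[j]
    return a


def where_replace(a, b):
    """np.where approach: builds a mask and a new full-size array."""
    return np.where(a == 7, b, a)


def timed(fn, *args):
    """Return (result, elapsed seconds) for a single call."""
    start = time.perf_counter()
    out = fn(*args)
    return out, time.perf_counter() - start


rng = np.random.default_rng(0)
# Random 32-bit integers in 0..10 (inclusive, assumed), smaller than the answer's matrix.
a = rng.integers(0, 11, size=(10_000, 1_000), dtype=np.int32)
b = rng.integers(-100, 100, size=a.shape[1], dtype=np.int32)

# The loop modifies its input, so give it a copy.
res_loop, t_loop = timed(loop_replace, a.copy(), b)
res_where, t_where = timed(where_replace, a, b)

print(f"For loop: {t_loop * 1e3:.0f} ms")
print(f"np.where: {t_where * 1e3:.0f} ms")
```

Checking that both results are equal before comparing times guards against benchmarking two functions that do not actually compute the same thing.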
