繁体   English   中英

在python中扩展列表(效率)的最佳实践

[英]Best practice to expand a list (efficiency) in python

我正在处理大型数据集。 我正在尝试使用 NumPy 库或 python 功能以有效的方式(例如 LC)处理数据集。

首先我找到相关的索引:

dt_temp_idx = np.where(dt_diff > dt_temp_th)

然后我想为每个索引创建一个包含从索引到停止值的序列的掩码,我试过:

mask_dt_temp = [np.arange(idx, idx+dt_temp_step) for idx in dt_temp_idx]

和:

  mask_dt_temp = [idxs for idx in dt_temp_idx for idxs in np.arange(idx, idx+dt_temp_step)]

但它给了我一个例外:

The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

示例输入:

indexes = [0, 100, 1000]

每个索引的 10 个整数后带有停止值的示例输出:

list = [0, 1, ..., 10, 100, 101, ..., 110, 1000, 1001, ..., 1010]

1)我该如何解决? 2)这是最好的做法吗?

使用掩码(布尔数组)是高效的,内存效率和性能也是如此。 我们将利用SciPy's binary-dilation来扩展阈值掩码。

这是一步一步的设置和解决方案运行 -

In [42]: # Random data setup
    ...: np.random.seed(0)
    ...: dt_diff = np.random.rand(20)
    ...: dt_temp_th = 0.9

In [43]: # Get mask of threshold crossings
    ...: mask = dt_diff > dt_temp_th

In [44]: mask
Out[44]: 
array([False, False, False, False, False, False, False, False,  True,
       False, False, False, False,  True, False, False, False, False,
       False, False])

In [45]: W = 3 # window size for extension (edit it according to your use-case)

In [46]: from scipy.ndimage.morphology import binary_dilation

In [47]: extm = binary_dilation(mask, np.ones(W, dtype=bool), origin=-(W//2))

In [48]: mask
Out[48]: 
array([False, False, False, False, False, False, False, False,  True,
       False, False, False, False,  True, False, False, False, False,
       False, False])

In [49]: extm
Out[49]: 
array([False, False, False, False, False, False, False, False,  True,
        True,  True, False, False,  True,  True,  True, False, False,
       False, False])

maskextm进行比较以了解扩展是如何发生的。

因为,我们可以看到阈值mask在右侧扩展了窗口大小W ,预期的输出掩码extm 这可用于屏蔽输入数组中的那些: dt_diff[~extm]以模拟从输入中删除/删除以下boolean-indexing或相反的dt_diff[extm]以模拟选择这些dt_diff[extm]

基于 NumPy 的函数的替代方案

替代方案#1

extm = np.convolve(mask, np.ones(W, dtype=int))[:len(dt_diff)]>0

替代方案#2

idx = np.flatnonzero(mask)
ext_idx = (idx[:,None]+ np.arange(W)).ravel()

ext_mask = np.ones(len(dt_diff), dtype=bool)
ext_mask[ext_idx[ext_idx<len(dt_diff)]] = False
 
# Get filtered o/p
out = dt_diff[ext_mask]

dt_temp_idx是一个 numpy 数组,但仍然是一个 Python 可迭代的,所以你可以使用一个很好的旧 Python 列表理解:

lst = [ i for j in dt_temp_idx for i in range(j, j+11)]

如果您想处理序列重叠并将其恢复为 np.array,只需执行以下操作:

result = np.array({i for j in dt_temp_idx for i in range(j, j+11)})

但是要注意集合的使用是健壮的并且保证不会重复,但它可能比简单的列表更昂贵。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM