[英]Best practice to expand a list (efficiency) in python
我正在处理大型数据集。 我正在尝试使用 NumPy 库或 python 功能以有效的方式(例如 LC)处理数据集。
首先我找到相关的索引:
dt_temp_idx = np.where(dt_diff > dt_temp_th)
然后我想为每个索引创建一个包含从索引到停止值的序列的掩码,我试过:
mask_dt_temp = [np.arange(idx, idx+dt_temp_step) for idx in dt_temp_idx]
和:
mask_dt_temp = [idxs for idx in dt_temp_idx for idxs in np.arange(idx, idx+dt_temp_step)]
但它给了我一个例外:
The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
示例输入:
indexes = [0, 100, 1000]
每个索引的 10 个整数后带有停止值的示例输出:
list = [0, 1, ..., 10, 100, 101, ..., 110, 1000, 1001, ..., 1010]
1)我该如何解决? 2)这是最好的做法吗?
使用掩码(布尔数组)是高效的,内存效率和性能也是如此。 我们将利用SciPy's binary-dilation
来扩展阈值掩码。
这是一步一步的设置和解决方案运行 -
In [42]: # Random data setup
...: np.random.seed(0)
...: dt_diff = np.random.rand(20)
...: dt_temp_th = 0.9
In [43]: # Get mask of threshold crossings
...: mask = dt_diff > dt_temp_th
In [44]: mask
Out[44]:
array([False, False, False, False, False, False, False, False, True,
False, False, False, False, True, False, False, False, False,
False, False])
In [45]: W = 3 # window size for extension (edit it according to your use-case)
In [46]: from scipy.ndimage.morphology import binary_dilation
In [47]: extm = binary_dilation(mask, np.ones(W, dtype=bool), origin=-(W//2))
In [48]: mask
Out[48]:
array([False, False, False, False, False, False, False, False, True,
False, False, False, False, True, False, False, False, False,
False, False])
In [49]: extm
Out[49]:
array([False, False, False, False, False, False, False, False, True,
True, True, False, False, True, True, True, False, False,
False, False])
将mask
与extm
进行比较以了解扩展是如何发生的。
因为,我们可以看到阈值mask
在右侧扩展了窗口大小W
,预期的输出掩码extm
。 这可用于屏蔽输入数组中的那些: dt_diff[~extm]
以模拟从输入中删除/删除以下boolean-indexing
或相反的dt_diff[extm]
以模拟选择这些dt_diff[extm]
。
替代方案#1
extm = np.convolve(mask, np.ones(W, dtype=int))[:len(dt_diff)]>0
替代方案#2
idx = np.flatnonzero(mask)
ext_idx = (idx[:,None]+ np.arange(W)).ravel()
ext_mask = np.ones(len(dt_diff), dtype=bool)
ext_mask[ext_idx[ext_idx<len(dt_diff)]] = False
# Get filtered o/p
out = dt_diff[ext_mask]
dt_temp_idx
是一个 numpy 数组,但仍然是一个 Python 可迭代的,所以你可以使用一个很好的旧 Python 列表理解:
lst = [ i for j in dt_temp_idx for i in range(j, j+11)]
如果您想处理序列重叠并将其恢复为 np.array,只需执行以下操作:
result = np.array({i for j in dt_temp_idx for i in range(j, j+11)})
但是要注意集合的使用是健壮的并且保证不会重复,但它可能比简单的列表更昂贵。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.