繁体   English   中英

对 numpy 数组的许多随机排列进行采样的最快方法

[英]fastest way to sample many random permutations of a numpy array

与许多其他 numpy/随机函数不同, numpy.random.Generator.permutation()没有提供在单个 function 调用中返回多个结果的明显方法。 给定一个 (1d) numpy 数组x ,我想对xn个排列进行采样(每个排列的长度为 len(x)),并将结果作为形状为(n, len(x))的 numpy 数组。 生成许多排列的一种天真的方法是np.array([rng.permutation(x) for _ in range(n)]) 这并不理想,主要是因为循环在 Python 中,而不是在已编译的 numpy function 中。

import numpy as np

rng = np.random.default_rng(1234)
# x is big enough to not want to enumerate all permutations
x = rng.standard_normal(size=20)
n = 10000
perms = np.array([rng.permutation(x) for _ in range(n)])

我的用例是用于蛮力搜索以找到最小化特定属性的排列(构成“足够好”的搜索解决方案)。 我可以使用 numpy 操作计算每个排列的感兴趣属性,这些操作可以很好地矢量化/广播生成的排列矩阵。 事实证明,天真地生成排列矩阵是我代码中的瓶颈。 有没有更好的办法?

您可以使用rng.permuted而不是rng.permutation并将其与np.tile组合,以便多次重复x并独立地打乱每个复制。 方法如下:

perms = rng.permuted(np.tile(x, n).reshape(n,x.size), axis=1)

这在我的机器上比您的初始代码快 10 倍。

请注意,Jérome 的解决方案提供了一个“n”行数组,但它可能包含重复项。 不同的行可能具有相同的“x”顺序(特别是如果“n”大于“x”)

如果您需要在不重复的情况下进行采样(就像我的情况一样),您可以随时执行set(list(perm))并保留唯一组合“x”值

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM