[英]How to compose permutations in numpy efficiently?
I have bijections/permutations on a 3d numpy array given as two 3 tuples of numpy arrays.我在 3d numpy 数组上有双射/排列,以两个 3 元组的 numpy 数组形式给出。 For example my 3d numpy array could look like this
例如,我的 3d numpy 数组可能如下所示
arr = np.array([[[3, 2, 1], [5, 0, 5], [2, 0, 1]],
[[3, 4, 5], [4, 2, 0], [0, 1, 1]],
[[2, 0, 5], [1, 5, 1], [0, 5, 1]],
[[4, 3, 0], [1, 3, 3], [3, 3, 3]],
[[2, 4, 0], [2, 1, 0], [4, 4, 4]],
[[4, 3, 2], [2, 4, 2], [5, 5, 5]]])
And a permutation on it like this:和它的排列是这样的:
a, b = ((np.array([5, 2, 2, 2, 1, 1, 1, 4, 4, 4, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0]),
np.array([0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 2, 1, 0, 0, 0, 1, 2, 2, 2, 1]),
np.array([2, 2, 1, 0, 0, 0, 0, 0, 1, 2, 2, 2, 0, 1, 2, 2, 2, 1, 0, 0])),
(np.array([4, 5, 5, 5, 2, 2, 2, 1, 1, 1, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0]),
np.array([0, 2, 1, 0, 0, 0, 0, 0, 1, 2, 0, 0, 2, 1, 0, 0, 0, 1, 2, 2]),
np.array([2, 2, 2, 2, 2, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 2, 2, 2, 1])))
I can apply the map by doing arr[a] = arr[b]
.我可以通过执行
arr[a] = arr[b]
来应用地图。
My question is: is there an efficient way to compose two of those bijections?我的问题是:有没有一种有效的方法来组合其中两个双射? For example I want a function
compose
for which the two following statements are equivalent例如,我想要一个函数
compose
,它的以下两个语句是等效的
c, d = compose((a, b), (a, b))
arr[c] = arr[d]
and和
arr[a] = arr[b]
arr[a] = arr[b]
The idea here is very much like any other resampler: for each destination location, find the corresponding source location.这里的想法与任何其他重采样器非常相似:对于每个目标位置,找到相应的源位置。 Let's start with a simplified case.
让我们从一个简化的案例开始。 Say we have a 1-D array and corresponding indices:
假设我们有一个一维数组和相应的索引:
arr1 = np.arange(6)
a1 = np.array([3, 4])
b1 = np.array([2, 3])
It's much easier to visualize what happens when you make the assignments:当您进行分配时,更容易想象会发生什么:
arr[a1] = arr[b1] # [0, 1, 2, 2, 3, 5]
arr[a1] = arr[b1] # [0, 1, 2, 2, 2, 5]
For a destination index that is the entire array (eg, np.arange(arr.size)
), the source index is just the application of the assignment to the destination.对于作为整个数组的目标索引(例如,
np.arange(arr.size)
),源索引只是对目标的赋值的应用。
There are a couple of ways of generalizing this to a 3D array.有几种方法可以将其推广到 3D 数组。 One way would be to make a
(*arr.shapes, arr.ndim)
array containing a meshgrid of all the indices.一种方法是制作一个
(*arr.shapes, arr.ndim)
数组,其中包含所有索引的网格。 Another is simply to convert a
and b
into raveled (linear) indices into the raveled version of arr
.另一个是简单地将
a
和b
转换为arr
的 raveled 版本(线性)索引。 I recommend going with the latter.我建议选择后者。
a_r = np.ravel_multi_index(a, arr.shape)
b_r = np.ravel_multi_index(b, arr.shape)
You can construct a source and destination index pair of the same size as arr.ravel()
, and set the sources and destinations in it:您可以构造一个与
arr.ravel()
大小相同的源和目标索引对,并在其中设置源和目标:
dest = np.arange(arr.size)
src = np.arange(arr.size)
src[a] = src[b]
I wrote out the last assignment "in full", although the first time you can just do src[a] = b
.我写了“完整”的最后一个作业,虽然第一次你可以只做
src[a] = b
。 You can iterate the last assignment as many times as you want.您可以根据需要多次迭代最后一个任务。 In your particular example, do it a second time:
在您的特定示例中,再做一次:
src[a] = src[b]
If you want, you can trim off the elements that are not modified by the assignment:如果需要,您可以修剪掉未被赋值修改的元素:
mask = dest != src
dest = dest[mask]
src = src[mask]
Finally, you can unravel the index back to the original shape:最后,您可以将索引解开回原始形状:
c = np.unravel_index(dest, arr.shape)
d = np.unravel_index(src, arr.shape)
If you want to write a function that accepts an arbitrary number of input indices, it might look something like this:如果您想编写一个接受任意数量输入索引的函数,它可能如下所示:
def compose(*args, shape=None, sparse=True):
if shape is None:
get_max = lambda tup: np.array([arr.max() for arr in tup])
for a, b in args:
if shape is None:
shape = np.maximum(get_max(a), get_max(b))
else:
shape = np.maximum(shape, get_max(a))
shape = np.maximum(shape, get_max(b))
shape = tuple(shape + 1)
size = np.prod(shape)
dest = np.arange(size)
src = np.arange(size)
for a, b in args:
a = np.ravel_multi_index(a, shape)
b = np.ravel_multi_index(b, shape)
src[a] = src[b]
if sparse:
mask = dest != src
dest = dest[mask]
src = src[mask]
return np.unravel_index(dest, shape), np.unravel_index(src, shape)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.