繁体   English   中英

快速计算非空区域的方法

[英]A fast way to count non-empty regions

我正在编写一些代码,该代码在通过原点的5个维度中选择n随机超平面。 然后,它在单位球面上随机地均匀采样no_points点,并计算由超平面创建的区域中有多少个至少具有一个点。 使用下面的Python代码,这相对简单。

import numpy as np

def points_on_sphere(dim, N, norm=np.random.normal):
    """
    http://en.wikipedia.org/wiki/N-sphere#Generating_random_points
    """
    normal_deviates = norm(size=(N, dim))
    radius = np.sqrt((normal_deviates ** 2).sum(axis=0))
    points = normal_deviates / radius
    return points

n = 100
d = 5
hpoints = points_on_sphere(n, d).T
for no_points in xrange(0, 10000000,100000):
    test_points = points_on_sphere(no_points,d).T 
    #The next two lines count how many of the test_points are in different regions created by the hyperplanes
    signs = np.sign(np.inner(test_points, hpoints))
    print no_points, len(set(map(tuple,signs)))

不幸的是,我幼稚的计算不同区域中的点数的方法很慢。 总体而言,该方法花费O(no_points * n * d)时间,实际上,一旦no_points达到大约1000000它就会变得太慢且RAM太饿。 特别是在no_points = 900,000时达到4GB RAM。

是否可以更有效地完成此操作,以便no_points可以相当快地使用不到4GB的内存, no_points 10,000,000(实际上,如果达到10倍,那会很棒)?

存储每个测试点相对于每个超平面的分类方式的数据很多。 我建议在点标签上使用隐式基数排序,例如,

import numpy as np


d = 5
n = 100
N = 100000
is_boundary = np.zeros(N, dtype=bool)
tpoints = np.random.normal(size=(N, d))
tperm = np.arange(N)
for i in range(n):
    hpoint = np.random.normal(size=d)
    region = np.cumsum(is_boundary) * 2 + (np.inner(hpoint, tpoints) < 0.0)[tperm]
    region_order = np.argsort(region)
    is_boundary[1:] = np.diff(region[region_order])
    tperm = tperm[region_order]
print(np.sum(is_boundary))

此代码保留测试点( tperm )的排列,以使同一区域中的所有点都是连续的。 boundary指示每个点是否在排列顺序上与先前的区域不同。 对于每个连续的超平面,我们对每个现有区域进行分区,并有效地丢弃空区域,以避免存储其中的2 ^ 100。

实际上,由于您有很多点而超平面太少,所以不存储这些点更有意义。 以下代码使用二进制编码将区域签名打包为两个双精度型。

import numpy as np


d = 5
hpoints = np.random.normal(size=(100, d))
bits = np.zeros((2, 100))
bits[0, :50] = 2.0 ** np.arange(50)
bits[1, 50:] = 2.0 ** np.arange(50)
N = 100000
uniques = set()
for i in xrange(0, N, 1000):
    tpoints = np.random.normal(size=(1000, d))
    signatures = np.inner(np.inner(tpoints, hpoints) < 0.0, bits)
    uniques.update(map(tuple, signatures))
print(len(uniques))

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM