繁体   English   中英

确定多热编码的有效性

[英]Determining the validity of a multi-hot encoding

假设我有N个项目和一个值{0, 1}的多热向量,表示在结果中包含这些项目:

N = 4

# items 1 and 3 will be included in the result
vector = [0, 1, 0, 1]

# item 2 will be included in the result
vector = [0, 0, 1, 0]

我还提供了一个冲突矩阵,指示哪些项目不能同时包含在结果中:

conflicts = [
  [0, 1, 1, 0], # any result that contains items 1 AND 2 is invalid
  [0, 1, 1, 1], # any result that contains AT LEAST 2 items from {1, 2, 3} is invalid
]

给定这个冲突矩阵,我们可以确定早期vector s 的有效性:

# invalid as it triggers conflict 1: [0, 1, 1, 1]
vector = [0, 1, 0, 1]

# valid as it triggers no conflicts
vector = [0, 0, 1, 0]

检测给vector是否“有效”(即不触发任何冲突)的简单解决方案可以通过 numpy 中的点积和求和运算来完成:

violation = np.dot(conflicts, vector)
is_valid = np.max(violation) <= 1

是否有更有效的方法来执行此操作,也许是通过np.einsum或绕过 numpy arrays 完全支持位操作?

我们假设被检查的向量的数量可能非常大(例如,如果我们评估所有可能性,则最多为2^N )但一次可能只检查一个向量(以避免生成形状高达(2^N, N)作为输入)。

TL;DR :您可以使用Numba优化np.dot以仅对二进制值进行操作。 更具体地说,您可以使用64 位视图一次对 8 个字节执行类似 SIMD 的操作




将列表转换为 arrays

首先,可以使用这种方法将列表有效地转换为相对紧凑的 arrays:

vector = np.fromiter(vector, np.uint8)
conflicts = np.array([np.fromiter(conflicts[i], np.uint8) for i in range(len(conflicts))])

这比使用自动 Numpy 转换或np.array (在内部 Numpy 代码和 Numpy 中执行的检查更少,Numpy 知道要构建什么类型的数组,并且生成的数组在 memory 中更小,因此填充速度更快) . 此步骤可用于加速基于np.dot的解决方案。

如果输入已经是 Numpy 数组,则检查它们的类型是np.uint8还是np.int8 否则,请使用conflits = conflits.astype(np.uint8)将它们转换为此类类型。


第一次尝试

然后,一种解决方案可能是使用np.packbits将输入的二进制值尽可能多地打包在 memory 中的位数组中,然后执行逻辑与操作。 但事实证明np.packbits非常慢。 因此,这个解决方案最终不是一个好主意。 事实上,任何创建形状类似于conflicts的临时 arrays 的解决方案都会很慢,因为在 memory 中写入这样的数组通常比np.dot (它从 memory 读取一次conflicts )慢。


使用 Numba

由于np.dot优化得很好,唯一的解决办法就是使用优化的本机代码。 得益于即时编译器,Numba 可用于在运行时从基于 Numpy 的 Python 代码生成本机可执行代码。 这个想法是在vector和每个块的conflicts行之间执行逻辑与。 对每个块进行冲突检查,以便尽早停止计算。 通过比较两个 arrays 的 uint64 视图(以 SIMD 友好的方式),可以按 8 个八位字节为一组有效地比较块。

import numba as nb

@nb.njit('bool_(uint8[::1], uint8[:,::1])')
def check_valid(vector, conflicts):
    n, m = conflicts.shape
    assert vector.size == m

    for i in range(n):
        block_size = 128 # In the range: 8,16,...,248
        conflicts_row = conflicts[i,:]
        gsum = 0 # Global sum of conflicts
        m_limit = m // block_size * block_size

        for j in range(0, m_limit, block_size):
            vector_block = vector[j:j+block_size].view(np.uint64)
            conflicts_block = conflicts_row[j:j+block_size].view(np.uint64)

            # Matching
            lsum = np.uint64(0) # 8 local sums of conflicts
            for k in range(block_size//8):
                lsum += vector_block[k] & conflicts_block[k]

            # Trick to perform the reduction of all the bytes in lsum
            lsum += lsum >> 32
            lsum += lsum >> 16
            lsum += lsum >> 8
            gsum += lsum & 0xFF

            # Check if there is a conflict
            if gsum >= 2:
                return False

        # Remaining part
        for j in range(m_limit, m):
            gsum += vector[j] & conflicts_row[j]

        if gsum >= 2:
            return False

    return True

结果

对于形状为(16, 65536)的大型conflicts数组(无冲突),这比我机器上的np.dot快 9 倍 两种情况都不包括转换列表的时间。 当存在冲突时,提供的解决方案要快得多,因为它可以提前停止计算。

理论上,计算应该更快,但 Numba JIT 不能成功地使用 SIMD 指令对循环进行矢量化。 话虽如此, np.dot似乎也出现了同样的问题。 如果 arrays 更大,您可以并行计算块(如果 function 返回 False,则计算速度会变慢)。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM