简体   繁体   English

如何使用矢量化快速解码 one-hot 编码的 NumPy 矩阵?

[英]How do I decode a one-hot encoded NumPy matrix in a fast manner using vectorization?

Given an image matrix of shape (height, width) with values in the uint8 range, which was one-hot encoded (converted to categorical) to a shape of (height, width, n) where n is the number of possible categories, 3 in this instance resulting in a shape of (height, width, 3) , I would like to undo the categorical conversion and get the original shape of (height, width) .给定形状为(height, width)且值在uint8范围内的图像矩阵,它被单热编码(转换为分类)为形状(height, width, n) ,其中 n 是可能类别的数量,3在这种情况下,形状为(height, width, 3) ,我想撤消分类转换并获得(height, width)的原始形状。 The following solution works, but could be made much faster:以下解决方案有效,但可以更快:

def decode(image):
    image = image

    height = image.shape[0]
    width = image.shape[1]

    decoded_image = numpy.ndarray(shape=(height, width), dtype=numpy.uint8)

    for i in range(0, height):
        for j in range(0, width):
            decoded_image[i][j] = numpy.argmax(image[i][j])

    return decoded_image

I would like a solution, using NumPy vectorization , without the need for a slower Python for loop .我想要一个解决方案,使用NumPy vectorization ,而不需要更慢的 Python for loop

Thank you for any suggestions.谢谢你的任何建议。

Looks like you want to do a reduction over the last dimension of your array, in particular a numpy.argmax .看起来您想减少数组的最后一个维度,特别是numpy.argmax Fortunately, this numpy function accepts an axis keyword, so you should be able to do the same in just one call:幸运的是,这个 numpy function 接受一个axis关键字,所以你应该能够在一个调用中做同样的事情:

decoded_image = numpy.argmax(image, axis=2)

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM