
Count pixel colors in pixel neighborhood with Tensorflow

I have an image of shape [1024, 1024, 3] in a reduced color space (roughly 100 colors). For every possible pair of colors, I need to count how many times one appears in the neighborhood of the other. This is quite simple in plain Python; the problem is writing it with Tensorflow functions. A simple example should make it easier to understand:

Let's consider a color space with only 4 colors (named 1, 2, 3, 4) and a 4x4 image:

[ [2 4 1 4]  
  [2 4 2 1]  
  [2 4 3 1]  
  [2 1 1 2] ]

From this I have to compute the neighborhood color histogram (NCH), which is exactly what is described above: how many pixels of each color appear in the neighborhoods of the pixels of every other color. The NCH of the image above is:

[ [8 7 4 6]  
  [7 6 2 12]  
  [4 2 0 2]  
  [6 12 2 4] ]

To make it clear, the 8 in the first row means that color 1 appears 8 times in the neighborhoods of pixels of color 1. The 7 indicates that color 2 appears 7 times in the neighborhoods of color 1, and so on.

In this simple example I consider a neighborhood of size 3x3, but the code I am working on should be extendable to a generic NxN size.
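To be concrete, this is roughly the plain-Python version I would like to avoid. It is only a sketch of the intended semantics, assuming the image already holds integer colour labels 1..ncolours and that a pixel does not count as its own neighbour:

import numpy as np

def nch_loops(labels, ncolours, adjacency=1):
    # labels: 2D array of ints in 1..ncolours
    h, w = labels.shape
    nch = np.zeros((ncolours, ncolours), dtype=np.int64)
    for i in range(h):
        for j in range(w):
            c = labels[i, j]
            for di in range(-adjacency, adjacency + 1):
                for dj in range(-adjacency, adjacency + 1):
                    if di == 0 and dj == 0:
                        continue  # the centre pixel is not its own neighbour
                    ni, nj = i + di, j + dj
                    if 0 <= ni < h and 0 <= nj < w:
                        nch[c - 1, labels[ni, nj] - 1] += 1
    return nch

On the 4x4 example above, this reproduces the NCH matrix shown.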

The only step I was able to implement with Tensorflow is to split each image in the batch (I am implementing this as the first step of a loss function, so it has to work with batches) into NxN patches centered on every pixel, using the following function:

patches = tf.image.extract_patches(image, sizes=(1, D_size, D_size, 1), strides=(1, 1, 1, 1),
                                         padding='SAME', rates=[1, 1, 1, 1])
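With 'SAME' padding and unit strides, this returns one flattened D_size x D_size patch per pixel (zero-padded at the borders). As a quick sanity check of the shapes, assuming a single-channel image of colour labels rather than RGB:

import tensorflow as tf

D_size = 3
image = tf.random.uniform((1, 1024, 1024, 1), minval=1, maxval=101, dtype=tf.int32)  # fake label image, colours 1..100
patches = tf.image.extract_patches(image, sizes=(1, D_size, D_size, 1), strides=(1, 1, 1, 1),
                                   padding='SAME', rates=[1, 1, 1, 1])
print(patches.shape)  # (1, 1024, 1024, 9): each pixel's flattened 3x3 neighbourhood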

What I have to do now is to iterate over every patch and, according to the central color of the patch, increment the NCH matrix. This is very simple with loops, but I am struggling to find the right Tensorflow functions to do it in a more parallel way.

I am pretty new to the Tensorflow world, so I understand this may seem an obvious question, but I really do not know what else to try and I would prefer not to use loops. Thank you all in advance.

Here's a way of doing it. I was expecting this to come out at about 5 lines of code, but it turned out a lot longer. (I've left in some diagnostic print statements, which might help in understanding what it is doing.) If anyone knows a more elegant way I'd be very interested.

As Alberto has said, probably don't expect to be able to backprop through this.

import tensorflow as tf

def count_neighbours(image, ncolours=5, adjacency=1):
    '''
    Note colours have to be numbers 1 to ncolours, NOT zero
    Zero values are created during the calc, both in padding and in tensor products,
    and must be discarded
    '''
    f = 2 * adjacency + 1      # filter size
    ic = adjacency * (f + 1)   # Index of central pixel in flattened patch
    print(f"{adjacency=} {ic=} {f=}")
    print(f"{image.shape=}")

    # Get patches. Returns location of patch in dims 1, 2, and (flattened) patch in 3
    patches = tf.image.extract_patches(image, sizes=(1, f, f, 1), strides=(1, 1, 1, 1),
                                         padding='SAME', rates=[1, 1, 1, 1])
    print(f"{patches.shape=}") 
    
    # A sort of outer product with one-hot encoding of the central pixel in a patch
    # So if a patch [i, j, k, :] has central pixel l, it is copied to [i, j, k, :, l]
    # First, create a one-hot encoding of the central pixel in each patch
    oh_central = tf.one_hot(patches[:,:,:,ic], axis=-1, depth=ncolours + 1, dtype=tf.int32)  # one-hot of central pixel in each patch
    print(f"{oh_central.shape=}") 

    # Want oh1[m, h, w, ipixel, colour] = patches[m, h, w, ipixel] * oh_central[m, h, w, colour]
    oh1 = tf.einsum('ijkl,ijkm->ijklm',patches, oh_central)
    print(f"{oh1.shape=}")
    
    # One-hot encode the patches
    oh2 = tf.one_hot(oh1, axis=-1, depth=ncolours + 1, dtype=tf.int32)
    print(f"{oh2.shape=}")
    
    # Set central pixels to zero, since a pixel is not counted as being adjacent to itself
    mask = tf.concat([
            tf.ones((oh2.shape[0], oh2.shape[1], oh2.shape[2], oh2.shape[3] //2, oh2.shape[4], oh2.shape[5]), dtype=tf.int32),
            tf.zeros((oh2.shape[0], oh2.shape[1], oh2.shape[2], 1, oh2.shape[4], oh2.shape[5]), dtype=tf.int32),
            tf.ones((oh2.shape[0], oh2.shape[1], oh2.shape[2], oh2.shape[3] //2, oh2.shape[4], oh2.shape[5]), dtype=tf.int32),
                ], axis=3)
    assert mask.shape == oh2.shape
    oh3 = oh2 * mask

    # finally, reduce sum, discard counts of zeros and return result
    result = tf.reduce_sum(oh3, axis=[1,2,3])[:,1:,1:]
    return result

Then it can be called like this, and gives the result you posted:

x = tf.constant([[2, 4, 1, 4],
                 [2, 4, 2, 1],
                 [2, 4, 3, 1],
                 [2, 1, 1, 2]], dtype=tf.int32)
x = tf.expand_dims(tf.expand_dims(x,0),3)  # add dummy batch and channel dimensions
count_neighbours(x, ncolours=4)
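One practical note: count_neighbours expects a single-channel image of integer colour labels starting at 1, not the original [1024, 1024, 3] RGB values. Assuming the reduced palette of ~100 colours is known (the palette lookup below is an assumption of mine, not part of the question), a possible way to build those labels is:

def pack_rgb(rgb):
    # pack each RGB triple into a single integer so colours can be compared as scalars
    rgb = tf.cast(rgb, tf.int32)
    return rgb[..., 0] * 65536 + rgb[..., 1] * 256 + rgb[..., 2]

def rgb_to_labels(images, palette):
    # images: [batch, H, W, 3], palette: [ncolours, 3]; assumes every pixel is exactly one palette colour
    codes = pack_rgb(images)                 # [batch, H, W]
    pal_codes = pack_rgb(palette)            # [ncolours]
    eq = tf.cast(tf.equal(codes[..., tf.newaxis], pal_codes), tf.int32)   # [batch, H, W, ncolours]
    return tf.argmax(eq, axis=-1, output_type=tf.int32) + 1               # labels in 1..ncolours

# labels = rgb_to_labels(batch_images, palette)[..., tf.newaxis]  # add the channel dimension
# nch = count_neighbours(labels, ncolours=palette.shape[0], adjacency=1)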
