
Apply channel shuffle in tensorflow (or keras)

I am trying to implement a channel shuffle function in tensorflow (or keras). I have found this implementation, but it seems to be wrong, because I think it is based on this pytorch implementation.

I have managed to do it with concatenate(), but I would like an implementation using permute_dimensions(). Also, I am not sure whether the concatenate version is slower (if someone can answer this one I would be grateful).

A working tensorflow implementation using concatenate():

import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras import layers
from tensorflow.keras import models
import numpy as np

a = tf.constant([[[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]])
sess = tf.Session()
print('x', sess.run(a))
groups = 2  # split the channels into 2 groups
h, w, in_channel = K.int_shape(a)[1:]
l = K.reshape(a, [-1, h, w, in_channel // groups, groups])
m = K.concatenate((l[..., 1], l[..., 0]))
l = K.reshape(m, [-1, h, w, in_channel])
print('y', sess.run(l))

with output:

x [[[[ 1  2]
   [ 3  4]
   [ 5  6]]
  [[ 7  8]
   [ 9 10]
   [11 12]]]]
y [[[[ 2  1]
   [ 4  3]
   [ 6  5]]
  [[ 8  7]
   [10  9]
   [12 11]]]]

A non-working keras implementation is below:

def channel_shuffle(x):
    g = 2
    b, h, w, c = x.shape.as_list()
    x = K.reshape(x, [-1, h, w, g, c // g])
    x = K.permute_dimensions(x, (0, 1, 2, 4, 3))
    x = K.reshape(x, [-1, h, w, c])
    return x

input_shape = (2, 3, 2)
x = np.array([[[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]])
image_input = layers.Input(input_shape)
permuted_x = channel_shuffle(image_input)
model = models.Model(inputs=[image_input], outputs=permuted_x)
y = model.predict(x)
print('x', x)
print('y', y)

with output:

x [[[[ 1  2]
   [ 3  4]
   [ 5  6]]

  [[ 7  8]
   [ 9 10]
   [11 12]]]]
y [[[[ 1.  2.]
   [ 3.  4.]
   [ 5.  6.]]

  [[ 7.  8.]
   [ 9. 10.]
   [11. 12.]]]]

which obviously does not change the input data at all. So, how can I achieve the desired result? Basically, which axes should I interchange? I have run some experiments, but I cannot seem to find the right ones.

You need to reverse the last axis after permute_dimensions. permute_dimensions is the same as tf.transpose. Here is a solution that works directly on tensors:

import tensorflow as tf 
import numpy as np

def channel_shuffle(x):
    g = 2
    b, h, w, c = x.shape
    x = tf.reshape(x, [-1, h, w, g, c // g])
    x = tf.transpose(x, perm = [0, 1, 2, 4, 3])
    x = tf.reverse(x,[-1])
    x = tf.reshape(x, [-1, h, w, c])
    return x

x = np.ones(shape = (1,2,2,4))
for c in range(4):
    x[:,:,:,c] = c

y = channel_shuffle(x)
print(tf.__version__)
print("start:")
print(x)
print("result:")
print(y)

With output:

2.3.1
start:
[[[[0. 1. 2. 3.]
   [0. 1. 2. 3.]]

  [[0. 1. 2. 3.]
   [0. 1. 2. 3.]]]]
result:
tf.Tensor(
[[[[2. 0. 3. 1.]
   [2. 0. 3. 1.]]

  [[2. 0. 3. 1.]
   [2. 0. 3. 1.]]]], shape=(1, 2, 2, 4), dtype=float64)
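
As a rough cross-check, the same reshape, transpose, reverse, reshape sequence can be replayed in NumPy. This is only a sketch: `shuffle_with_reverse` is a hypothetical helper name, and `np.flip` is assumed here as a stand-in for `tf.reverse(x, [-1])`. It runs the steps on the 4-channel array above and on the 2-channel input from the question.

```python
import numpy as np

def shuffle_with_reverse(x, groups=2):
    """NumPy stand-in for the reshape / transpose / reverse / reshape steps."""
    _, h, w, c = x.shape
    x = x.reshape(-1, h, w, groups, c // groups)
    x = x.transpose(0, 1, 2, 4, 3)   # swap the two trailing axes
    x = np.flip(x, axis=-1)          # assumed equivalent of tf.reverse(x, [-1])
    return x.reshape(-1, h, w, c)

# 4-channel input where every pixel holds [0, 1, 2, 3]
four = np.tile(np.arange(4.0), (1, 2, 2, 1))
print(shuffle_with_reverse(four)[0, 0, 0])

# 2-channel input from the question
two = np.arange(1, 13).reshape(1, 2, 3, 2)
print(shuffle_with_reverse(two)[0])
```

For the 2-channel input, this swaps the two channels of every pixel, which is the same result the concatenate() version in the question prints.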

The logic works, but that only becomes apparent when you pass tensors with more channels than the default of 3 for images (or, in your case, more than 2). With 2 channels and 2 groups, each group holds a single channel, so the transpose has nothing to interleave and the output equals the input. Think of the cases where intermediate tensors, after a few convolution layers, flow through the channel shuffle step.
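
One way to see this is to track channel indices only, ignoring the batch and spatial axes. The sketch below uses a hypothetical helper, `shuffled_channel_order`, that applies the reshape and transpose to a 1-D index array with 2 groups, for a few channel counts.

```python
import numpy as np

def shuffled_channel_order(channels, groups):
    """Return the original channel index found at each output position."""
    idx = np.arange(channels)
    # (groups, channels_per_group) -> transpose -> flatten
    return idx.reshape(groups, channels // groups).T.reshape(-1)

for channels in (2, 4, 6, 8):
    print(channels, shuffled_channel_order(channels, groups=2))
```

With 2 channels the order comes back as [0, 1], so nothing moves. From 4 channels upward the two groups are interleaved.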

I have written demo code in numpy with 6 channels (a multiple of 3). You can convert it to tensorflow.

>>> import numpy as np
>>> a = np.arange(36).reshape(1, 2, 3, 6)
>>> print(a)
[[[[ 0  1  2  3  4  5]
   [ 6  7  8  9 10 11]
   [12 13 14 15 16 17]]

  [[18 19 20 21 22 23]
   [24 25 26 27 28 29]
   [30 31 32 33 34 35]]]]
>>> g = 2
>>> _, h, w, n = a.shape
>>> nb_chn_per_grp = n // g
>>> b = a.reshape(-1, h, w, g, nb_chn_per_grp)
>>> c = b.transpose(0, 1, 2, 4, 3)
>>> res = c.reshape(-1, h, w, n)
>>> print(res)
[[[[ 0  3  1  4  2  5]
   [ 6  9  7 10  8 11]
   [12 15 13 16 14 17]]

  [[18 21 19 22 20 23]
   [24 27 25 28 26 29]
   [30 33 31 34 32 35]]]]

We can clearly see that the channel values have been shuffled as needed.
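
The demo steps can also be wrapped in a small reusable function. This is a minimal sketch, assuming a channels-last (N, H, W, C) layout; the `groups` parameter and the divisibility check are additions for illustration and are not part of the demo above.

```python
import numpy as np

def channel_shuffle(x, groups):
    """Channel shuffle for a channels-last (N, H, W, C) array."""
    n, h, w, c = x.shape
    if c % groups != 0:
        raise ValueError(f"channels ({c}) must be divisible by groups ({groups})")
    x = x.reshape(n, h, w, groups, c // groups)  # split channels into groups
    x = x.transpose(0, 1, 2, 4, 3)               # interleave the groups
    return x.reshape(n, h, w, c)                 # flatten back to C channels

a = np.arange(36).reshape(1, 2, 3, 6)
print(channel_shuffle(a, groups=2)[0, 0, 0])  # first pixel, 2 groups
print(channel_shuffle(a, groups=3)[0, 0, 0])  # first pixel, 3 groups
```

The three steps line up with tf.reshape, tf.transpose and tf.reshape as used in the TensorFlow snippet earlier on this page, so a port should be mostly mechanical, though it has not been run here.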
