Apply channel shuffle in tensorflow (or keras)
I am trying to implement a channel shuffle function in tensorflow (or keras). I have found this implementation, but it seems to be wrong, because I think it is based on this pytorch implementation.
I have managed to do it with concatenate(), but I would like an implementation using permute_dimensions(). Also, I am not sure whether the concatenate version is slower (if someone can answer this one I would be grateful).
A working tensorflow implementation using concatenate():
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras import layers
from tensorflow.keras import models
import numpy as np
a = tf.constant([[[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]])
sess = tf.Session()
print('x', sess.run(a))
groups = 2  # separate into 2 groups
h, w, in_channel = K.int_shape(a)[1:]
l = K.reshape(a, [-1, h, w, in_channel // groups, groups])
m = K.concatenate((l[..., 1], l[..., 0]))
l = K.reshape(m, [-1, h, w, in_channel])
print('y', sess.run(l))
with output:
x [[[[ 1 2]
[ 3 4]
[ 5 6]]
[[ 7 8]
[ 9 10]
[11 12]]]]
y [[[[ 2 1]
[ 4 3]
[ 6 5]]
[[ 8 7]
[10 9]
[12 11]]]]
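The snippet above targets TensorFlow 1.x; `tf.Session` is not available at the top level in TensorFlow 2.x. A minimal sketch of the same reshape-and-concatenate steps in eager mode, assuming TensorFlow 2.x and the same 2-channel input (the variable names `grouped`, `swapped` and `y` are only illustrative):

```python
import tensorflow as tf

# Same NHWC input as above: batch 1, height 2, width 3, 2 channels.
a = tf.constant([[[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]])

groups = 2
_, h, w, in_channel = a.shape  # static shape, plain ints in TF 2.x

# Split the channel axis into (in_channel // groups, groups).
grouped = tf.reshape(a, [-1, h, w, in_channel // groups, groups])
# Put the second slice before the first one along the last axis.
swapped = tf.concat([grouped[..., 1], grouped[..., 0]], axis=-1)
y = tf.reshape(swapped, [-1, h, w, in_channel])

print('y', y.numpy())
```

In eager mode the result is available directly through `.numpy()`, so no session is needed.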
A non-working keras implementation is below:
def channel_shuffle(x):
    g = 2
    b, h, w, c = x.shape.as_list()
    x = K.reshape(x, [-1, h, w, g, c // g])
    x = K.permute_dimensions(x, (0, 1, 2, 4, 3))
    x = K.reshape(x, [-1, h, w, c])
    return x
input_shape = (2, 3, 2)
x = np.array([[[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]])
image_input = layers.Input(input_shape)
permuted_x = channel_shuffle(image_input)
model = models.Model(inputs=[image_input], outputs=permuted_x)
y = model.predict(x)
print('x', x)
print('y', y)
with output:
x [[[[ 1 2]
[ 3 4]
[ 5 6]]
[[ 7 8]
[ 9 10]
[11 12]]]]
y [[[[ 1. 2.]
[ 3. 4.]
[ 5. 6.]]
[[ 7. 8.]
[ 9. 10.]
[11. 12.]]]]
which obviously does not change the input data at all. So, how can I achieve the desired result? Basically, which axes should I interchange? I have tried some experiments, but I cannot seem to find the right combination.
You need to reverse the last axis after permute_dimensions. permute_dimensions is the same as tf.transpose. Here is a solution that works directly on tensors:
import tensorflow as tf
import numpy as np
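def channel_shuffle(x):
    g = 2  # number of groups
    b, h, w, c = x.shape
    x = tf.reshape(x, [-1, h, w, g, c // g])  # split channels into g groups
    x = tf.transpose(x, perm=[0, 1, 2, 4, 3])  # swap the group axis with the per-group axis
    x = tf.reverse(x, [-1])  # reverse along the last axis
    x = tf.reshape(x, [-1, h, w, c])  # merge back into c channels
    return x

# Build a test input in which every pixel holds the channel values 0, 1, 2, 3.
x = np.ones(shape=(1, 2, 2, 4))
for c in range(4):
    x[:, :, :, c] = c
y = channel_shuffle(x)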
print(tf.__version__)
print("start:")
print(x)
print("result:")
print(y)
With output:
2.3.1
start:
[[[[0. 1. 2. 3.]
[0. 1. 2. 3.]]
[[0. 1. 2. 3.]
[0. 1. 2. 3.]]]]
result:
tf.Tensor(
[[[[2. 0. 3. 1.]
[2. 0. 3. 1.]]
[[2. 0. 3. 1.]
[2. 0. 3. 1.]]]], shape=(1, 2, 2, 4), dtype=float64)
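To see what the extra `tf.reverse` contributes, the same steps can be traced in NumPy on the four channels used above. This is only an illustrative sketch: the helper name `shuffle_steps` and its `reverse` flag are hypothetical and not part of the answer. Without the reverse, the reshape, transpose, reshape sequence already interleaves the two groups; the reverse additionally flips the order inside each interleaved pair.

```python
import numpy as np

def shuffle_steps(x, groups=2, reverse=False):
    """Reshape -> transpose -> (optional reverse) -> reshape on an NHWC array."""
    _, h, w, c = x.shape
    x = x.reshape(-1, h, w, groups, c // groups)  # split channels into groups
    x = x.transpose(0, 1, 2, 4, 3)                # swap group axis and per-group axis
    if reverse:
        x = x[..., ::-1]                          # same effect as tf.reverse(x, [-1])
    return x.reshape(-1, h, w, c)

# Every pixel holds the channel values 0, 1, 2, 3.
x = np.tile(np.arange(4.0), (1, 2, 2, 1))

print(shuffle_steps(x)[0, 0, 0])                # [0. 2. 1. 3.]
print(shuffle_steps(x, reverse=True)[0, 0, 0])  # [2. 0. 3. 1.]
```

The second line reproduces the result printed above; the first shows the ordering obtained when the reverse step is left out.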
The logic works, but the shuffle only becomes apparent when you pass tensors with more channels than the usual 3 for images (or, in your case, more than 2). With 2 channels split into 2 groups, each group holds a single channel, so swapping the group axis with the per-group axis leaves the data order unchanged. Think of intermediate tensors that flow through the channel shuffle after a few convolution layers.
I have written the demo code in numpy with 6 channels (a multiple of 3). You can convert it to tensorflow.
>>> import numpy as np
>>> a = np.arange(36).reshape(1, 2, 3, 6)
>>> a
array([[[[ 0,  1,  2,  3,  4,  5],
         [ 6,  7,  8,  9, 10, 11],
         [12, 13, 14, 15, 16, 17]],

        [[18, 19, 20, 21, 22, 23],
         [24, 25, 26, 27, 28, 29],
         [30, 31, 32, 33, 34, 35]]]])
>>> g = 2
>>> _, h, w, n = a.shape
>>> nb_chn_per_grp = n // g
>>> b = a.reshape(-1, h, w, g, nb_chn_per_grp)
>>> c = b.transpose(0, 1, 2, 4, 3)
>>> res = c.reshape(-1, h, w, n)
>>> res
array([[[[ 0,  3,  1,  4,  2,  5],
         [ 6,  9,  7, 10,  8, 11],
         [12, 15, 13, 16, 14, 17]],

        [[18, 21, 19, 22, 20, 23],
         [24, 27, 25, 28, 26, 29],
         [30, 33, 31, 34, 32, 35]]]])
We can clearly see the channel values have been shuffled as needed.
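Converted to TensorFlow, the NumPy demo might look like the sketch below. It assumes TensorFlow 2.x in eager mode and a channels-last (NHWC) tensor whose height, width and channel dimensions are statically known; the `groups` parameter and the divisibility check are additions for illustration and are not part of the original answers.

```python
import tensorflow as tf

def channel_shuffle(x, groups=2):
    """Interleave channel groups of an NHWC tensor: reshape -> transpose -> reshape."""
    _, h, w, c = x.shape  # static shape; batch size may be unknown
    if c % groups != 0:
        raise ValueError("number of channels must be divisible by groups")
    x = tf.reshape(x, [-1, h, w, groups, c // groups])  # split channels into groups
    x = tf.transpose(x, perm=[0, 1, 2, 4, 3])           # swap group and per-group axes
    return tf.reshape(x, [-1, h, w, c])                 # merge back into c channels

# Same 6-channel input as the NumPy demo.
a = tf.reshape(tf.range(36), [1, 2, 3, 6])
res = channel_shuffle(a)
print(res[0, 0, 0].numpy())  # [0 3 1 4 2 5]

# With 2 channels and 2 groups each group holds one channel, so nothing moves.
b = tf.reshape(tf.range(12), [1, 2, 3, 2])
print(bool(tf.reduce_all(channel_shuffle(b) == b)))  # True
```

Using `-1` for the batch dimension in both reshapes keeps the function usable when the batch size is not known in advance.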