简体   繁体   English

反射填充 Conv2D

[英]Reflection padding Conv2D

I'm using keras to build a convolutional neural network for image segmentation and I want to use "reflection padding" instead of padding "same" but I cannot find a way to to do it in keras.我正在使用 keras 构建一个用于图像分割的卷积神经网络,我想使用“反射填充”而不是“相同”填充,但我找不到在 keras 中做到这一点的方法。

inputs = Input((num_channels, img_rows, img_cols))
conv1=Conv2D(32,3,padding='same',kernel_initializer='he_uniform',data_format='channels_first')(inputs)

Is there a way to implement a reflection layer and insert it in a keras model ?有没有办法实现反射层并将其插入到 keras 模型中?

The accepted answer above is not working in the current Keras version.上面接受的答案在当前的 Keras 版本中不起作用。 Here is the version that's working:这是有效的版本:

class ReflectionPadding2D(Layer):
    def __init__(self, padding=(1, 1), **kwargs):
        self.padding = tuple(padding)
        self.input_spec = [InputSpec(ndim=4)]
        super(ReflectionPadding2D, self).__init__(**kwargs)

    def compute_output_shape(self, s):
        """ If you are using "channels_last" configuration"""
        return (s[0], s[1] + 2 * self.padding[0], s[2] + 2 * self.padding[1], s[3])

    def call(self, x, mask=None):
        w_pad,h_pad = self.padding
        return tf.pad(x, [[0,0], [h_pad,h_pad], [w_pad,w_pad], [0,0] ], 'REFLECT')

Found the solution!找到了解决办法! We have only to create a new class that takes a layer as input and use tensorflow predefined function to do it.我们只需要创建一个以层作为输入的新类,并使用 tensorflow 预定义函数来完成它。

import tensorflow as tf
from keras.engine.topology import Layer
from keras.engine import InputSpec

class ReflectionPadding2D(Layer):
    def __init__(self, padding=(1, 1), **kwargs):
        self.padding = tuple(padding)
        self.input_spec = [InputSpec(ndim=4)]
        super(ReflectionPadding2D, self).__init__(**kwargs)

    def get_output_shape_for(self, s):
        """ If you are using "channels_last" configuration"""
        return (s[0], s[1] + 2 * self.padding[0], s[2] + 2 * self.padding[1], s[3])

    def call(self, x, mask=None):
        w_pad,h_pad = self.padding
        return tf.pad(x, [[0,0], [h_pad,h_pad], [w_pad,w_pad], [0,0] ], 'REFLECT')

# a little Demo
inputs = Input((img_rows, img_cols, num_channels))
padded_inputs= ReflectionPadding2D(padding=(1,1))(inputs)
conv1 = Conv2D(32, 3, padding='valid', kernel_initializer='he_uniform',
               data_format='channels_last')(padded_inputs)
import tensorflow as tf
from keras.layers import Lambda

inp_padded = Lambda(lambda x: tf.pad(x, [[0,0], [27,27], [27,27], [0,0]], 'REFLECT'))(inp)

The solution from Akihiko did not work with the new keras version, so I came up with my own. Akihiko 的解决方案不适用于新的 keras 版本,所以我想出了自己的解决方案。 The snippet pads a batch of 202x202x3 images to 256x256x3该代码段将一批 202x202x3 的图像填充到 256x256x3

As you can check in the documentation there is no such 'reflect' padding.正如您可以在文档中查看的那样,没有这样的“反射”填充。 Only 'same' and 'valid' are implemented in keras.在 keras 中只实现了“相同”和“有效”。

You maybe try to implemented on your own or find if somebody already did it.您可能会尝试自己实施或查找是否有人已经这样做了。 You should base yourself in the Conv2D class and check where self.padding member variable is used.您应该基于Conv2D类并检查使用self.padding成员变量的位置。

The accepted answer does not work if we have undefined dimensions!如果我们有未定义的维度,则接受的答案不起作用! There will be an error when compute_output_shape function is called.调用compute_output_shape函数时会报错。 Here is the simple work around to that.这是解决这个问题的简单方法。

class ReflectionPadding2D(Layer):
    def __init__(self, padding=(1, 1), **kwargs):
        self.padding = tuple(padding)
        self.input_spec = [InputSpec(ndim=4)]
        super(ReflectionPadding2D, self).__init__(**kwargs)

    def compute_output_shape(self, s):
        if s[1] == None:
            return (None, None, None, s[3])
        return (s[0], s[1] + 2 * self.padding[0], s[2] + 2 * self.padding[1], s[3])

    def call(self, x, mask=None):
        w_pad, h_pad = self.padding
        return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')

    def get_config(self):
        config = super(ReflectionPadding2D, self).get_config()
        print(config)
        return config

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM