简体   繁体   English

将 pytorch 2d 填充转换为 tensorflow keras

[英]converting pytorch 2d padding to tensorflow keras

What would be the equivalent of this:这相当于什么:

nn.ReflectionPad2d(1)

in TensorFlow 2?在 TensorFlow 2 中? The above line is from PyTorch.上面的行来自 PyTorch。

In TensorFlow 2 Keras, I'm currently looking into using tf.pad() as a TF version of that, but it seems PyTorch is able to handle varying dimensions with that single integer 1. For instance, if it gets an input of shape [batch size, 1, 1, 100], nn.ReflectionPad2D will handle that well, but in TensorFlow, I get an error if I try to use In TensorFlow 2 Keras, I'm currently looking into using tf.pad() as a TF version of that, but it seems PyTorch is able to handle varying dimensions with that single integer 1. For instance, if it gets an input of shape [batch size, 1, 1, 100], nn.ReflectionPad2D 会处理得很好,但是在 TensorFlow 中,如果我尝试使用会出错

tf.pad(t, tf.constant([0,0], [1,1], [1,1], [0,0]]), 'REFLECT')

Any suggestions on how to replicate nn.ReflectinPad2d in TensorFlow 2 keras?关于如何在 TensorFlow 2 keras 中复制 nn.ReflectinPad2d 的任何建议? Thanks!谢谢!

When I was training CycleGan on TF2, I created this custom layer for myself:当我在 TF2 上训练 CycleGan 时,我为自己创建了这个自定义层:

class ReflectionPad2D(tf.keras.layers.Layer):
  def __init__(self, paddings=(1,1,1,1)):
    super(ReflectionPad2D, self).__init__()
    self.paddings = paddings

  def call(self, input):
    l, r, t, b = self.paddings

    return tf.pad(input, paddings=[[0,0], [t,b], [l,r], [0,0]], mode='REFLECT')

You can just use it as is by putting it in a Model/Sequential, eg:您可以通过将其放入模型/序列中来按原样使用它,例如:

model = Sequential([
    ReflectionPad2D((3, 3, 3, 3)),
    Conv2D(64, kernel_size=7, strides=1, padding='valid', use_bias=False),
    BatchNormalization()
])

model.build(input_shape=(8, 128, 128, 3))
model.summary()

Example output:示例 output:

Model: "sequential_13"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
reflection_pad2d_198 (Reflec (8, 134, 134, 3)          0         
_________________________________________________________________
conv2d_474 (Conv2D)          (8, 128, 128, 64)         9408      
_________________________________________________________________
batch_normalization_41 (Batc (8, 128, 128, 64)         256       
=================================================================
Total params: 9,664
Trainable params: 9,536
Non-trainable params: 128
_________________________________________________________________

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM