
TensorFlow (Keras) to PyTorch Conv2D conversion

I'm trying to port a custom U-Net implementation from TensorFlow to PyTorch, and I've run into problems with the Conv2D layers.

I know padding can cause trouble here. I tried this and this, but it didn't help.
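For context on the padding question: with stride 1, "same" padding adds dilation * (kernel_size - 1) pixels in total per spatial dimension, and both Keras (padding="same") and PyTorch (padding="same", which only supports stride 1) put the smaller half before and any leftover pixel after (bottom/right). A minimal sketch, assuming stride 1; `same_padding` is a hypothetical helper, not a library function. For a 3x3 kernel it gives one pixel on every side, so in this setup padding alone should not make the two outputs differ.

```python
def same_padding(kernel_size, dilation=1):
    """Return (before, after) zero padding for one spatial dimension.

    Hypothetical helper. Assumes stride 1, where Keras padding="same" and
    PyTorch padding="same" agree: the leftover pixel (if any) goes after.
    """
    total = dilation * (kernel_size - 1)  # pixels needed to keep the size
    before = total // 2
    return before, total - before


for k in (3, 4, 5):
    print(k, same_padding(k))
# 3 (1, 1)
# 4 (1, 2)
# 5 (2, 2)
```

Only even kernel sizes produce an asymmetric split, and even then both frameworks place the extra pixel on the same side.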

My conversion code looks like this:

from keras.layers import Conv2D
from torch import nn

import torch
import pandas as pd
import numpy as np

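img = np.random.rand(1, 256, 256, 1)  # NHWC: batch, height, width, channels

## TF Init
conv_tf = Conv2D(
    64, 3, activation="relu", padding="same", kernel_initializer="he_normal",
)

conv_tf(img)  # first call builds the layer (creates kernel and bias)
conv_tf.bias = np.random.rand(64)  # replace the zero-initialized bias with random values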

## PT Init + copy weights
conv_torch = nn.Conv2d(
    in_channels=1, out_channels=64, kernel_size=3, padding="same", bias=True
)
# Keras kernel layout is (kh, kw, in, out); PyTorch expects (out, in, kh, kw)
conv_torch.weight = nn.parameter.Parameter(
    torch.Tensor(conv_tf.weights[0].numpy().transpose(3, 2, 0, 1))
)
conv_torch.bias = nn.parameter.Parameter(torch.Tensor(conv_tf.bias))

# Keras applies the ReLU inside Conv2D, so add it explicitly here
conv_torch = nn.Sequential(
    conv_torch,
    nn.ReLU()
)
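To check the kernel conversion independently of either framework, here is a small NumPy sketch of a stride-1 "same" cross-correlation in both layouts: NHWC input with a (kh, kw, in, out) kernel as Keras uses, and NCHW input with an (out, in, kh, kw) kernel as PyTorch uses. It is an illustration, not how either library implements convolution; the function names are hypothetical and odd kernel sizes are assumed. It suggests that the transpose(3, 2, 0, 1) above maps one kernel layout onto the other, provided the input and output are converted with axis transposes as well.

```python
import numpy as np


def conv_nhwc(x, k):
    """Stride-1 "same" cross-correlation in Keras layout.

    x: (N, H, W, Cin), k: (kh, kw, Cin, Cout) -> (N, H, W, Cout).
    Assumes odd kh and kw so the padding is symmetric.
    """
    kh, kw = k.shape[:2]
    ph, pw = kh // 2, kw // 2
    xp = np.pad(x, ((0, 0), (ph, ph), (pw, pw), (0, 0)))  # zero padding
    n, h, w, _ = x.shape
    out = np.zeros((n, h, w, k.shape[3]))
    for i in range(kh):
        for j in range(kw):
            # (N, H, W, Cin) @ (Cin, Cout) -> (N, H, W, Cout)
            out += xp[:, i:i + h, j:j + w, :] @ k[i, j]
    return out


def conv_nchw(x, wt):
    """The same operation in PyTorch layout.

    x: (N, Cin, H, W), wt: (Cout, Cin, kh, kw) -> (N, Cout, H, W).
    """
    cout, _, kh, kw = wt.shape
    ph, pw = kh // 2, kw // 2
    xp = np.pad(x, ((0, 0), (0, 0), (ph, ph), (pw, pw)))
    n, _, h, w = x.shape
    out = np.zeros((n, cout, h, w))
    for i in range(kh):
        for j in range(kw):
            # sum over Cin: (N, Cin, H, W) x (Cout, Cin) -> (N, Cout, H, W)
            out += np.einsum("nchw,oc->nohw", xp[:, :, i:i + h, j:j + w], wt[:, :, i, j])
    return out


rng = np.random.default_rng(0)
x = rng.random((1, 8, 8, 2))   # NHWC input
k = rng.random((3, 5, 2, 4))   # (kh, kw, in, out) kernel, non-square on purpose

ref = conv_nhwc(x, k)
pt = conv_nchw(x.transpose(0, 3, 1, 2), k.transpose(3, 2, 0, 1))
print(np.allclose(ref, pt.transpose(0, 2, 3, 1)))  # True
```

Both Keras Conv2D and PyTorch nn.Conv2d compute cross-correlation, which is why neither function flips the kernel.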

If I run a tensor through both models, the results are close on average but not identical, with large element-wise differences:

pred_tf = conv_tf(img).numpy()
pred_pt = conv_torch(torch.Tensor(img).reshape(1, 1, 256, 256)).detach().numpy().reshape(pred_tf.shape)

pred_tf.mean()
#0.7202551

pred_pt.mean()
#0.7202549
       TF - PT
count  4.1943e+06
mean  -2.2992e-09
std    0.969716
min   -3.85477
25%   -0.641259
50%    0
75%    0.641266
max    3.8742
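Separately from padding and kernel layout, the comparison step itself is worth checking. The line computing pred_pt converts the PyTorch output, which is laid out as (N, C, H, W), with reshape(pred_tf.shape). A reshape keeps the flat memory order instead of moving the channel axis, so with 64 channels the values would land in the wrong positions while their mean stays the same. (On the input side, reshape(1, 1, 256, 256) is harmless because there is only one channel.) A minimal NumPy illustration with small, made-up sizes; in PyTorch the axis-reordering equivalent would be permute(0, 2, 3, 1).

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for a PyTorch output in (N, C, H, W) layout (made-up sizes).
nchw = rng.random((1, 4, 5, 5))

wrong = nchw.reshape(1, 5, 5, 4)     # reinterprets memory, does not move axes
right = nchw.transpose(0, 2, 3, 1)   # actually reorders axes to (N, H, W, C)

print(np.isclose(wrong.mean(), right.mean()))  # True: same values, so same mean
print(np.array_equal(wrong, right))            # False: values in the wrong places
```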

Any idea? Thanks

You suspect padding. This is easy to verify: compare pred_tf and pred_pt only on interior pixels, discarding a band around the image (1 pixel wide in your case).
If the interior pixels are identical, then it is a padding issue.
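The interior-only comparison could look like this minimal NumPy sketch. `interior_matches` is a hypothetical helper; it assumes both outputs are already in the same NHWC layout, and the default tolerance is an assumption chosen with float32 results in mind.

```python
import numpy as np


def interior_matches(a, b, band=1, atol=1e-5):
    """Compare two NHWC arrays while ignoring a `band`-pixel border.

    Hypothetical helper: both arrays must already share the same layout.
    """
    if band < 1:
        raise ValueError("band must be at least 1")
    inner = (slice(None), slice(band, -band), slice(band, -band), slice(None))
    return np.allclose(a[inner], b[inner], atol=atol)


rng = np.random.default_rng(0)
a = rng.random((1, 6, 6, 3))
b = a.copy()
b[:, 0, :, :] += 1.0  # perturb only the top border row

print(np.allclose(a, b), interior_matches(a, b))  # False True
```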

However, given your diff values, I suspect this is not the case.

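Maybe there is an issue with reflecting/transposing the kernel weights between the two methods? Try working with non-square kernels, e.g., 3x5 instead of 3x3: then swapping the height and width axes changes the kernel's shape, so that mistake can no longer hide.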
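A small NumPy sketch of why a non-square kernel helps, assuming the error is a swapped height/width axis: with a 3x3 kernel a wrong transpose still yields the expected shape, while with a 3x5 kernel it does not, so comparing the converted array with the weight shape of an nn.Conv2d built with kernel_size=(3, 5), i.e. (64, 1, 3, 5), would catch it. The helper names are hypothetical. A shape check cannot detect a flipped kernel, though; since Keras and PyTorch both compute cross-correlation, no flip should be needed in the first place.

```python
import numpy as np


def tf_to_torch_kernel(k):
    """(kh, kw, in, out) -> (out, in, kh, kw), the transpose used in the question."""
    return k.transpose(3, 2, 0, 1)


def swapped_hw(k):
    """A deliberately wrong conversion that swaps kernel height and width."""
    return k.transpose(3, 2, 1, 0)


square = np.zeros((3, 3, 1, 64))
wide = np.zeros((3, 5, 1, 64))

print(tf_to_torch_kernel(square).shape, swapped_hw(square).shape)
# (64, 1, 3, 3) (64, 1, 3, 3)  -> the mistake is invisible
print(tf_to_torch_kernel(wide).shape, swapped_hw(wide).shape)
# (64, 1, 3, 5) (64, 1, 5, 3)  -> the mistake shows up as a shape mismatch
```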
