Pass pretrained weights from a PyTorch CNN to a TensorFlow CNN

I have trained this network in PyTorch on 224x224 images with 4 classes.

class CustomConvNet(nn.Module):
    def __init__(self, num_classes):
        super(CustomConvNet, self).__init__()

        self.layer1 = self.conv_module(3, 64)
        self.layer2 = self.conv_module(64, 128)
        self.layer3 = self.conv_module(128, 256)
        self.layer4 = self.conv_module(256, 256)
        self.layer5 = self.conv_module(256, 512)
        self.gap = self.global_avg_pool(512, num_classes)
        #self.linear = nn.Linear(512, num_classes)
        #self.relu = nn.ReLU()
        #self.softmax = nn.Softmax()

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        out = self.gap(out)
        out = out.view(-1, 4)
        #out = self.linear(out)

        return out

    def conv_module(self, in_num, out_num):
        return nn.Sequential(
            nn.Conv2d(in_num, out_num, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=None))

    def global_avg_pool(self, in_num, out_num):
        return nn.Sequential(
            nn.Conv2d(in_num, out_num, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(out_num),
            #nn.LeakyReLU(),

            nn.ReLU(),
            nn.Softmax(),
            nn.AdaptiveAvgPool2d((1, 1)))

I extracted the weights from the first Conv2d layer; their size is torch.Size([64, 3, 3, 3]).

I have saved it as:

weightsCNN = net.layer1[0].weight.data
np.save('CNNweights.npy', weightsCNN)

This is the model I built in TensorFlow. I would like to load the weights I saved from the PyTorch model into this TensorFlow CNN.

    model = models.Sequential()
    model.add(layers.Conv2D(64, (3, 3), activation='relu', input_shape=(224, 224, 3)))
    model.add(layers.MaxPooling2D((2, 2)))

    model.add(layers.Conv2D(128, (3, 3), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))

    model.add(layers.Conv2D(256, (3, 3), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))

    model.add(layers.Conv2D(256, (3, 3), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))

    model.add(layers.Conv2D(512, (3, 3), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))

    model.add(layers.Conv2D(512, (3, 3), activation='relu'))

    model.add(layers.GlobalAveragePooling2D())
    model.add(layers.Dense(4, activation='softmax'))
    print(model.summary())


    adam = optimizers.Adam(learning_rate=0.0001, amsgrad=False)
    model.compile(loss='categorical_crossentropy',
                  optimizer=adam,
                  metrics=['accuracy'])


    nb_train_samples = 6596
    nb_validation_samples = 1290
    epochs = 10
    batch_size = 256


    history = model.fit_generator(
        train_generator,
        steps_per_epoch=np.ceil(nb_train_samples/batch_size),
        epochs=epochs,
        validation_data=validation_generator,
        validation_steps=np.ceil(nb_validation_samples / batch_size)
        )

How should I actually do that? What weight shapes does TensorFlow require? Thanks!

You can check the shapes of all weights in every Keras layer quite simply:

for layer in model.layers:
    print([tensor.shape for tensor in layer.get_weights()])

This gives you the shapes of all weights (including biases), so you can prepare the loaded NumPy arrays accordingly.

To set them, do something similar:

# zip yields (layer, weights) pairs in the order the arguments are given
for layer, torch_weight in zip(model.layers, torch_weights):
    layer.set_weights(torch_weight)

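where torch_weights is a list with one entry per Keras layer, each entry being a list of np.array that you load yourself.

For a layer with parameters, the entry usually holds one np.array for the weights and one for the bias, in that order. Layers without parameters, such as MaxPooling2D, return an empty list from get_weights() and take an empty list here.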

Remember that the shapes printed above have to match exactly the shapes of the arrays you pass to set_weights.
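
The shape requirement can be checked before calling set_weights. The following is a minimal sketch that needs only NumPy. shapes_match is a hypothetical helper, not part of Keras, and the expected shapes are written out by hand for a 3x3 Conv2D with 64 filters on a 3-channel input. In real use they would come from [w.shape for w in layer.get_weights()].

```python
import numpy as np

def shapes_match(expected_shapes, arrays):
    """Return True if `arrays` has the same shapes, in the same order, as `expected_shapes`."""
    if len(expected_shapes) != len(arrays):
        return False
    return all(tuple(exp) == np.shape(arr) for exp, arr in zip(expected_shapes, arrays))

# Assumed shapes for Conv2D(64, (3, 3)) on a 3-channel input: kernel first, then bias.
expected = [(3, 3, 3, 64), (64,)]

# Zero arrays stand in for the real weights; only the shapes matter here.
torch_layout = [np.zeros((64, 3, 3, 3)), np.zeros(64)]  # as saved from PyTorch
keras_layout = [np.zeros((3, 3, 3, 64)), np.zeros(64)]  # as Keras expects

print(shapes_match(expected, torch_layout))
print(shapes_match(expected, keras_layout))
```

Running such a check per layer before set_weights points to the layer whose arrays still need reshaping or transposing.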

See documentation for more info.

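BTW, the exact shapes depend on the layers and operations in the model, so you may have to transpose some arrays to make them fit. For example, PyTorch stores Conv2d weights as (out_channels, in_channels, height, width), which is why you see torch.Size([64, 3, 3, 3]), while a Keras Conv2D kernel is (height, width, in_channels, out_channels), i.e. (3, 3, 3, 64) here.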
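
The transposition can be sketched with NumPy alone. This is an illustration under stated assumptions, not a complete converter: the function names are hypothetical, and a random array stands in for the contents of CNNweights.npy. Both frameworks compute cross-correlation in their convolution layers, so as far as I know only the axis order has to change and no kernel flipping is needed.

```python
import numpy as np

def conv_kernel_torch_to_keras(w):
    """Reorder a Conv2d weight from (out, in, kH, kW) to (kH, kW, in, out)."""
    return np.transpose(w, (2, 3, 1, 0))

def dense_kernel_torch_to_keras(w):
    """Reorder a Linear weight from (out_features, in_features) to (in_features, out_features)."""
    return np.transpose(w)

# Stand-in for np.load('CNNweights.npy'): same shape as the first Conv2d weight.
rng = np.random.default_rng(0)
torch_conv = rng.standard_normal((64, 3, 3, 3)).astype(np.float32)

keras_conv = conv_kernel_torch_to_keras(torch_conv)
print(keras_conv.shape)  # (3, 3, 3, 64)

# Hypothetical Linear(512, 4) weight, shown only to illustrate the 2-D case.
torch_dense = rng.standard_normal((4, 512)).astype(np.float32)
keras_dense = dense_kernel_torch_to_keras(torch_dense)
print(keras_dense.shape)  # (512, 4)
```

Biases are one-dimensional in both frameworks, so they can be passed through unchanged.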
