Pass pretrained weights from a CNN in PyTorch to a CNN in TensorFlow
I have trained this network in PyTorch for 224x224 images and 4 classes:
import torch.nn as nn


class CustomConvNet(nn.Module):
    def __init__(self, num_classes):
        super(CustomConvNet, self).__init__()
        self.layer1 = self.conv_module(3, 64)
        self.layer2 = self.conv_module(64, 128)
        self.layer3 = self.conv_module(128, 256)
        self.layer4 = self.conv_module(256, 256)
        self.layer5 = self.conv_module(256, 512)
        self.gap = self.global_avg_pool(512, num_classes)
        #self.linear = nn.Linear(512, num_classes)
        #self.relu = nn.ReLU()
        #self.softmax = nn.Softmax()

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        out = self.gap(out)
        out = out.view(-1, 4)
        #out = self.linear(out)
        return out

    def conv_module(self, in_num, out_num):
        # 3x3 conv with padding=1 keeps the spatial size; the 2x2 pool halves it
        return nn.Sequential(
            nn.Conv2d(in_num, out_num, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=None))

    def global_avg_pool(self, in_num, out_num):
        return nn.Sequential(
            nn.Conv2d(in_num, out_num, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(out_num),
            #nn.LeakyReLU(),
            nn.ReLU(),
            # softmax over the channel (class) dimension; dim=1 is what the
            # implicit default picks for 4-D input, stated explicitly here
            nn.Softmax(dim=1),
            nn.AdaptiveAvgPool2d((1, 1)))
I got the weights from the first Conv2d; their size is torch.Size([64, 3, 3, 3]).
I have saved them as follows:
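import numpy as np
# Detach from autograd and move to CPU before converting to NumPy
weightsCNN = net.layer1[0].weight.detach().cpu().numpy()
np.save('CNNweights.npy', weightsCNN)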
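A minimal NumPy sketch of converting the saved array to the layout Keras expects, assuming PyTorch's Conv2d order (out_channels, in_channels, kH, kW) and Keras' Conv2D kernel order (kH, kW, in_channels, out_channels). torch_conv_to_keras is a hypothetical helper name, and a random array stands in for the loaded file:

```python
import numpy as np


def torch_conv_to_keras(weight: np.ndarray) -> np.ndarray:
    """Convert a PyTorch Conv2d weight (out, in, kH, kW) to a Keras
    Conv2D kernel (kH, kW, in, out). Hypothetical helper, not a library API."""
    if weight.ndim != 4:
        raise ValueError(f"expected a 4-D conv weight, got shape {weight.shape}")
    return np.transpose(weight, (2, 3, 1, 0))


# Stand-in for np.load('CNNweights.npy'): same shape as layer1's weight
torch_w = np.random.rand(64, 3, 3, 3).astype(np.float32)
keras_w = torch_conv_to_keras(torch_w)
print(keras_w.shape)  # (3, 3, 3, 64)
```

The bias needs no transpose: a PyTorch bias of shape (64,) already matches what Keras' Conv2D stores, so the pair [keras_w, bias] is what that layer's set_weights would receive.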
This is the model I built in TensorFlow. I would like to pass the weights I saved from the PyTorch model into this TensorFlow CNN.
import numpy as np
from tensorflow.keras import layers, models, optimizers

model = models.Sequential()
model.add(layers.Conv2D(64, (3, 3), activation='relu', input_shape=(224, 224, 3)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(128, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(256, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(256, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(512, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(512, (3, 3), activation='relu'))
model.add(layers.GlobalAveragePooling2D())
model.add(layers.Dense(4, activation='softmax'))
model.summary()  # summary() prints by itself and returns None

adam = optimizers.Adam(learning_rate=0.0001, amsgrad=False)
model.compile(loss='categorical_crossentropy',
              optimizer=adam,
              metrics=['accuracy'])

nb_train_samples = 6596
nb_validation_samples = 1290
epochs = 10
batch_size = 256

# train_generator and validation_generator are defined elsewhere (not shown).
# fit_generator is deprecated in TF 2.x; fit accepts generators directly.
history = model.fit(
    train_generator,
    steps_per_epoch=int(np.ceil(nb_train_samples / batch_size)),
    epochs=epochs,
    validation_data=validation_generator,
    validation_steps=int(np.ceil(nb_validation_samples / batch_size))
)
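Something to be aware of before copying weights: the PyTorch convolutions use padding=1 (spatial size preserved), while the Keras Conv2D layers above use the default padding='valid'. Kernel shapes do not depend on padding, so the first five convolutions' weights will load, but the feature maps differ in size and the two networks will not produce the same outputs. Also, the Keras model's sixth Conv2D(512) and its Dense(4) have no direct counterpart in the PyTorch gap block (a 512 to num_classes convolution). The pure-Python sketch below traces the spatial size through the five conv + pool blocks of each model; the helper names are hypothetical:

```python
def conv_out(size: int, kernel: int = 3, padding: int = 0, stride: int = 1) -> int:
    """Output size of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size: int, pool: int = 2) -> int:
    """Output size of a non-overlapping max pooling (floor mode)."""
    return size // pool


def trace_sizes(size: int, n_blocks: int, padding: int) -> list:
    """Spatial size after each conv + 2x2 max-pool block."""
    sizes = []
    for _ in range(n_blocks):
        size = pool_out(conv_out(size, padding=padding))
        sizes.append(size)
    return sizes


# PyTorch model: Conv2d(..., padding=1) keeps the size before each pool
print(trace_sizes(224, 5, padding=1))  # [112, 56, 28, 14, 7]
# Keras model: default padding='valid' behaves like padding=0
print(trace_sizes(224, 5, padding=0))  # [111, 54, 26, 12, 5]
```

Passing padding='same' to each Conv2D would make the Keras stack follow the first trace, which is what the PyTorch weights were trained with.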
How should I actually do that? What shape of weights does TensorFlow require? Thanks!
You can check the shapes of all weights of all Keras layers quite simply:
for layer in model.layers:
    print([tensor.shape for tensor in layer.get_weights()])
This gives you the shapes of all weights (including biases), so you can prepare the loaded NumPy weights accordingly.
To set them, do something similar:
# zip yields (layer, weights) pairs, so unpack them in that order
for layer, torch_weight in zip(model.layers, torch_weights):
    layer.set_weights(torch_weight)
where torch_weights should be a list containing lists of np.array, which you would have to load.
Usually each element of torch_weights contains one np.array for the weights and one for the bias. Layers without weights (such as MaxPooling2D) need an empty list, since zip pairs every Keras layer with one element.
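A minimal sketch of preparing torch_weights, assuming the PyTorch conv weights and biases have already been loaded as NumPy arrays in layer order (for example with np.load). The helper build_torch_weights and the layer-kind strings "conv"/"none" are hypothetical; only the first two blocks of the Keras model are shown:

```python
import numpy as np


def build_torch_weights(conv_params, layer_kinds):
    """Build one weight list per Keras layer (hypothetical helper).

    conv_params: list of (weight, bias) pairs in PyTorch layout, in layer order.
    layer_kinds: one string per Keras layer, "conv" or "none" (no weights).
    """
    params = iter(conv_params)
    torch_weights = []
    for kind in layer_kinds:
        if kind == "conv":
            weight, bias = next(params)
            # PyTorch (out, in, kH, kW) -> Keras (kH, kW, in, out); bias unchanged
            torch_weights.append([np.transpose(weight, (2, 3, 1, 0)), bias])
        else:
            torch_weights.append([])  # e.g. MaxPooling2D has no weights
    return torch_weights


# First two blocks: Conv2D(64), MaxPooling2D, Conv2D(128), MaxPooling2D
conv_params = [
    (np.zeros((64, 3, 3, 3)), np.zeros(64)),
    (np.zeros((128, 64, 3, 3)), np.zeros(128)),
]
kinds = ["conv", "none", "conv", "none"]
for entry in build_torch_weights(conv_params, kinds):
    print([w.shape for w in entry])
# [(3, 3, 3, 64), (64,)]
# []
# [(3, 3, 64, 128), (128,)]
# []
```

Note that zip stops at the shorter sequence, so a torch_weights list that is shorter than model.layers silently leaves the remaining layers untouched.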
Remember that the shapes you get from print have to be exactly the same as the ones you pass to set_weights.
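Because one mismatched shape makes set_weights fail, it can help to compare shapes before assigning. A small sketch, assuming expected holds the shapes printed from layer.get_weights() and provided holds the converted NumPy arrays; check_shapes is a hypothetical helper:

```python
import numpy as np


def check_shapes(expected, provided):
    """Raise ValueError if the provided arrays don't match the expected shapes."""
    if len(expected) != len(provided):
        raise ValueError(f"expected {len(expected)} arrays, got {len(provided)}")
    for i, (shape, array) in enumerate(zip(expected, provided)):
        if tuple(shape) != array.shape:
            raise ValueError(f"array {i}: expected {tuple(shape)}, got {array.shape}")


# Shapes Keras reports for Conv2D(64, (3, 3)) on a 3-channel input
expected = [(3, 3, 3, 64), (64,)]
raw = np.zeros((64, 3, 3, 3))  # PyTorch layout, not yet transposed
try:
    check_shapes(expected, [raw, np.zeros(64)])
except ValueError as err:
    print(err)  # array 0: expected (3, 3, 3, 64), got (64, 3, 3, 3)
check_shapes(expected, [np.transpose(raw, (2, 3, 1, 0)), np.zeros(64)])  # passes
```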
See the documentation for more info.
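By the way, the exact shapes depend on the layers and operations performed by the model, so you may have to transpose some arrays to "fit them in". For example, PyTorch stores a Conv2d weight as (out_channels, in_channels, kernel_h, kernel_w), while Keras expects a Conv2D kernel as (kernel_h, kernel_w, in_channels, out_channels), i.e. np.transpose(w, (2, 3, 1, 0)); an nn.Linear weight (out_features, in_features) needs .T to become a Dense kernel (in_features, out_features).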
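To check that the transpose is the only change needed, a naive NumPy convolution can be run once in PyTorch's layout (NCHW input, OIHW weight) and once in Keras' layout (NHWC input, HWIO kernel). This is a verification sketch with hypothetical helper names (stride 1, no padding, no bias), not how either framework computes convolutions internally:

```python
import numpy as np


def conv_nchw_oihw(x, w):
    """Cross-correlation like torch Conv2d (stride 1, no padding, no bias)."""
    n, c, h, wd = x.shape
    o, _, kh, kw = w.shape
    out = np.zeros((n, o, h - kh + 1, wd - kw + 1))
    for i in range(h - kh + 1):
        for j in range(wd - kw + 1):
            patch = x[:, :, i:i + kh, j:j + kw]  # (n, c, kh, kw)
            out[:, :, i, j] = np.einsum("nchw,ochw->no", patch, w)
    return out


def conv_nhwc_hwio(x, k):
    """Cross-correlation like Keras Conv2D (stride 1, 'valid', no bias)."""
    n, h, wd, c = x.shape
    kh, kw, _, o = k.shape
    out = np.zeros((n, h - kh + 1, wd - kw + 1, o))
    for i in range(h - kh + 1):
        for j in range(wd - kw + 1):
            patch = x[:, i:i + kh, j:j + kw, :]  # (n, kh, kw, c)
            out[:, i, j, :] = np.einsum("nhwc,hwco->no", patch, k)
    return out


rng = np.random.default_rng(0)
x_torch = rng.standard_normal((2, 3, 8, 8))  # NCHW
w_torch = rng.standard_normal((4, 3, 3, 3))  # OIHW

x_keras = np.transpose(x_torch, (0, 2, 3, 1))  # NCHW -> NHWC
k_keras = np.transpose(w_torch, (2, 3, 1, 0))  # OIHW -> HWIO

y_torch = conv_nchw_oihw(x_torch, w_torch)
y_keras = conv_nhwc_hwio(x_keras, k_keras)
print(np.allclose(np.transpose(y_torch, (0, 2, 3, 1)), y_keras))  # True
```

The same reasoning applies to the input images: Keras expects them channels-last (NHWC), so a batch prepared for the PyTorch model would need the same (0, 2, 3, 1) transpose.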