简体   繁体   English

将 Model 从 3 通道 (RGB) 重新训练到 4 通道 (RGBA),我可以使用 3 通道权重吗?

[英]Retraining a Model from 3 Channels (RGB) to 4 Channels (RGBA), can I use the 3 channel weights?

I need to expand a model from RGB to RGBA.我需要将 model 从 RGB 扩展为 RGBA。 I can handle the code rewrite on the model, but instead of retraining the entire model from scratch, I would love to start it off with it's 3 channel weights + zeros.我可以处理 model 上的代码重写,但不是从头开始重新训练整个 model,我很想从它的 3 通道权重 + 零开始。

Is there an easy way to change torch's save of 3 channel weights into 4?有没有一种简单的方法可以将手电筒保存的 3 个通道权重更改为 4 个?

Yes, you can do a little bit of "model surgery".是的,你可以做一点“模型手术”。 Assuming the input to the model is only processed directly by a convolutional layer then you can just replace that conv layer with another that has in_channels set to 4 .假设 model 的输入仅由卷积层直接处理,那么您可以将该卷积层替换为另一个将in_channels设置为4的卷积层。 Then you can set weights to zero and copy over the old weights (and biases if applicable) from the original conv layer.然后,您可以将权重设置为零并从原始 conv 层复制旧的权重(和偏差,如果适用)。

For example, say we had a simple model that looked like this例如,假设我们有一个看起来像这样的简单 model

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 10, kernel_size=3, padding=1, bias=True)
        self.conv2 = nn.Conv2d(10, 5, kernel_size=3, padding=1, bias=True)
        self.linear = nn.Linear(125, 1)
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return self.linear(x.flatten(start_dim=1))

model = SimpleModel()

Supposing that the model is trained at this point, we could perform the surgery as follows假设此时训练了 model,我们可以进行如下手术

y_rgb = torch.randn(1, 3, 5, 5)

# get performance on initial z_rgb
z_rgb = model(y_rgb)

# perform model surgery
with torch.no_grad():
    new_conv1 = nn.Conv2d(4, 10, kernel_size=3, padding=1, bias=True)
    new_conv1.weight.zero_()
    new_conv1.weight[:,:3,...]=model.conv1.weight
    new_conv1.bias.copy_(model.conv1.bias)
    model.conv1 = new_conv1

# add a random alpha channel to y_rgba
y_alpha = torch.randn(1,1,5,5)
y_rgba = torch.cat([y_rgb, y_alpha], dim=1)

# get results on rgba model
z_rgba = model(y_rgba)

# compare z_rgb and z_rgba, print mean-square difference
z_err = ((z_rgba-z_rgb)**2).mean().item()
print('Err:', z_err)

# save results to a new file
torch.save(model.state_dict(), 'checkpoint_rgba.pt')

which should give you an error of zero or very close to zero.这应该给你一个零或非常接近于零的错误。

Of course if you don't have a bias term in your first conv layer then you don't need to copy that over.当然,如果你的第一个 conv 层中没有bias项,那么你不需要复制它。

Assuming you've saved the new state dictionary, then you will probably want to update the model class definition so that your input convolution layer takes 4 channel input instead of 3. Then next time you can directly load the new state dictionary without additional steps. Assuming you've saved the new state dictionary, then you will probably want to update the model class definition so that your input convolution layer takes 4 channel input instead of 3. Then next time you can directly load the new state dictionary without additional steps.


Now it's not strictly necessary to do the surgery on the model directly.现在没有必要直接对 model 进行手术。 Though I tend to prefer it as I find it easier to verify correctness.虽然我更喜欢它,因为我发现它更容易验证正确性。

Assuming you saved off the state dictionary for the RGB model, you could also just directly modify the state dictionary.假设您为 RGB model 保存了 state 字典,您也可以直接修改 state 字典。

# assuming you saved RGB model using torch.save(model.state_dict(), 'checkpoint_rgb.pt')
state_dict = torch.load('checkpoint_rgb.pt')
old_weight = state_dict['conv1.weight']
state_dict['conv1.weight'] = torch.zeros(
    old_weight.shape[0],
    old_weight.shape[1]+1,
    old_weight.shape[2],
    old_weight.shape[3]
).type_as(old_weight)
state_dict['conv1.weight'][:,:3,...] = old_weight
torch.save(state_dict, 'checkpoint_rgba.pt')

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM