繁体   English   中英

将 Model 从 3 通道 (RGB) 重新训练到 4 通道 (RGBA),我可以使用 3 通道权重吗?

[英]Retraining a Model from 3 Channels (RGB) to 4 Channels (RGBA), can I use the 3 channel weights?

我需要将 model 从 RGB 扩展为 RGBA。 我可以处理 model 上的代码重写,但不是从头开始重新训练整个 model,我很想从它的 3 通道权重 + 零开始。

有没有一种简单的方法可以将手电筒保存的 3 个通道权重更改为 4 个?

是的,你可以做一点“模型手术”。 假设 model 的输入仅由卷积层直接处理,那么您可以将该卷积层替换为另一个将in_channels设置为4的卷积层。 然后,您可以将权重设置为零并从原始 conv 层复制旧的权重(和偏差,如果适用)。

例如,假设我们有一个看起来像这样的简单 model

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 10, kernel_size=3, padding=1, bias=True)
        self.conv2 = nn.Conv2d(10, 5, kernel_size=3, padding=1, bias=True)
        self.linear = nn.Linear(125, 1)
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return self.linear(x.flatten(start_dim=1))

model = SimpleModel()

假设此时训练了 model,我们可以进行如下手术

y_rgb = torch.randn(1, 3, 5, 5)

# get performance on initial z_rgb
z_rgb = model(y_rgb)

# perform model surgery
with torch.no_grad():
    new_conv1 = nn.Conv2d(4, 10, kernel_size=3, padding=1, bias=True)
    new_conv1.weight.zero_()
    new_conv1.weight[:,:3,...]=model.conv1.weight
    new_conv1.bias.copy_(model.conv1.bias)
    model.conv1 = new_conv1

# add a random alpha channel to y_rgba
y_alpha = torch.randn(1,1,5,5)
y_rgba = torch.cat([y_rgb, y_alpha], dim=1)

# get results on rgba model
z_rgba = model(y_rgba)

# compare z_rgb and z_rgba, print mean-square difference
z_err = ((z_rgba-z_rgb)**2).mean().item()
print('Err:', z_err)

# save results to a new file
torch.save(model.state_dict(), 'checkpoint_rgba.pt')

这应该给你一个零或非常接近于零的错误。

当然,如果你的第一个 conv 层中没有bias项,那么你不需要复制它。

Assuming you've saved the new state dictionary, then you will probably want to update the model class definition so that your input convolution layer takes 4 channel input instead of 3. Then next time you can directly load the new state dictionary without additional steps.


现在没有必要直接对 model 进行手术。 虽然我更喜欢它,因为我发现它更容易验证正确性。

假设您为 RGB model 保存了 state 字典,您也可以直接修改 state 字典。

# assuming you saved RGB model using torch.save(model.state_dict(), 'checkpoint_rgb.pt')
state_dict = torch.load('checkpoint_rgb.pt')
old_weight = state_dict['conv1.weight']
state_dict['conv1.weight'] = torch.zeros(
    old_weight.shape[0],
    old_weight.shape[1]+1,
    old_weight.shape[2],
    old_weight.shape[3]
).type_as(old_weight)
state_dict['conv1.weight'][:,:3,...] = old_weight
torch.save(state_dict, 'checkpoint_rgba.pt')

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM