簡體   English   中英

RuntimeError: Error(s) in loading state_dict for Generator: 使用 Pytorch 的權重和偏差大小不匹配

[英]RuntimeError: Error(s) in loading state_dict for Generator: size mismatch for weights and biases using Pytorch

我正在訓練一個 3D-GAN 來生成 MRI 體積。 我定義我的模型如下:

###### Definition of the generator ######

class Generator(nn.Module):
  def __init__(self, ngpu):
    #super() makes Generator a subclass of nn.Module, so that it inherites all the methods of nn.Module
    super(Generator, self).__init__()
    self.ngpu = ngpu
    #we can use Sequential() since the output of one layer is the input of the next one
    self.main = nn.Sequential(   
        # input is latent vector z, going into a convolution 
        nn.ConvTranspose3d(nz, ngf * 8, 4, stride=2, padding=0, bias=True), # try to put kernel = (batch_size,4,4,4,512)
        nn.BatchNorm3d(ngf * 8),
        nn.ReLU(True), #True means that it does the operation inplace, default is False

        nn.ConvTranspose3d(ngf * 8, ngf * 4, 4, stride=2, padding=1, bias=True), # try to put kernel = (batch_size,8,8,8,256)
        nn.BatchNorm3d(ngf * 4),
        nn.ReLU(True),

        nn.ConvTranspose3d(ngf * 4, ngf * 2, 4, stride=2, padding=1, bias=True), # try to put kernel = (batch_size,16,16,16,128)
        nn.BatchNorm3d(ngf * 2),
        nn.ReLU(True),

        nn.ConvTranspose3d( ngf * 2, ngf, 4, stride=2, padding=1, bias=True), # try to put kernel = (batch_size,32,32,32,64)
        nn.BatchNorm3d(ngf),
        nn.ReLU(True),

        nn.ConvTranspose3d(ngf, nc, 4, stride=2, padding=1, bias=True), # try to put kernel = (batch_size,64,64,64,1)
        nn.Sigmoid()

        )

  def forward(self, x):
    return self.main(x)


###### Definition of the Discriminator ######

class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            nn.Conv3d(nc, ndf, 4, stride=2, padding=1, bias=True),
            nn.BatchNorm3d(ndf),
            nn.LeakyReLU(leak_value, inplace=True),

            nn.Conv3d(ndf, ndf * 2, 4, stride=2, padding=1, bias=True),
            nn.BatchNorm3d(ndf * 2),
            nn.LeakyReLU(leak_value, inplace=True),

            nn.Conv3d(ndf * 2, ndf * 4, 4, stride=2, padding=1, bias=True),
            nn.BatchNorm3d(ndf * 4),
            nn.LeakyReLU(leak_value, inplace=True),

            nn.Conv3d(ndf * 4, ndf * 8, 4, stride=2, padding=1, bias=True),
            nn.BatchNorm3d(ndf * 8),
            nn.LeakyReLU(leak_value, inplace=True),

            nn.Conv3d(ndf * 8, nc, 4, stride=1, padding=0, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.main(x)

然后我訓練模型並保存它。 加載模型進行評估和測試時,出現以下錯誤:

RuntimeError: Error(s) in loading state_dict for Generator: size mismatch for main.0.weight: copying a param with shape torch.Size([64, 1, 4, 4, 4]) from checkpoint, the shape in current model是torch.Size([200, 512, 4, 4, 4])。 main.0.bias 的尺寸不匹配:從檢查點復制形狀為 torch.Size([64]) 的參數,當前模型中的形狀為 torch.Size([512])。 main.1.weight 的尺寸不匹配:從檢查點復制形狀為 torch.Size([64]) 的參數,當前模型中的形狀為 torch.Size([512])。 main.1.running_mean 的大小不匹配:從檢查點復制形狀為 torch.Size([64]) 的參數,當前模型中的形狀為 torch.Size([512])。 main.1.bias 的尺寸不匹配:從檢查點復制形狀為 torch.Size([64]) 的參數,當前模型中的形狀為 torch.Size([512])。 main.1.running_var 的大小不匹配:從檢查點復制形狀為 torch.Size([64]) 的參數,當前模型中的形狀為 torch.Size([512])。 main.3.weight 的尺寸不匹配:從檢查點復制形狀為 torch.Size([128, 64, 4, 4, 4]) 的參數,當前模型中的形狀為 torch.Size([512, 256, 4, 4, 4])。 main.3.bias 的尺寸不匹配:從檢查點復制形狀為 torch.Size([128]) 的參數,當前模型中的形狀為 torch.Size([256])。 main.4.weight 的尺寸不匹配:從檢查點復制形狀為 torch.Size([128]) 的參數,當前模型中的形狀為 torch.Size([256])。 main.4.running_mean 的大小不匹配:從檢查點復制形狀為 torch.Size([128]) 的參數,當前模型中的形狀為 torch.Size([256])。 main.4.bias 的尺寸不匹配:從檢查點復制形狀為 torch.Size([128]) 的參數,當前模型中的形狀為 torch.Size([256])。 main.4.running_var 的尺寸不匹配:從檢查點復制形狀為 torch.Size([128]) 的參數,當前模型中的形狀為 torch.Size([256])。 main.6.bias 的尺寸不匹配:從檢查點復制形狀為 torch.Size([256]) 的參數,當前模型中的形狀為 torch.Size([128])。 main.7.weight 的尺寸不匹配:從檢查點復制形狀為 torch.Size([256]) 的參數,當前模型中的形狀為 torch.Size([128])。 main.7.running_mean 的大小不匹配:從檢查點復制形狀為 torch.Size([256]) 的參數,當前模型中的形狀為 torch.Size([128])。 main.7.bias 的尺寸不匹配:從檢查點復制形狀為 torch.Size([256]) 的參數,當前模型中的形狀為 torch.Size([128])。 main.7.running_var 的尺寸不匹配:從檢查點復制形狀為 torch.Size([256]) 的參數,當前模型中的形狀為 torch.Size([128])。 main.9.weight 的尺寸不匹配:從檢查點復制形狀為 torch.Size([512, 256, 4, 4, 4]) 的參數,當前模型中的形狀為 torch.Size([128, 64, 4, 4, 4])。 main.9.bias 的尺寸不匹配:從檢查點復制形狀為 torch.Size([512]) 的參數,當前模型中的形狀為 torch.Size([64])。 main.10.weight 的尺寸不匹配:從檢查點復制形狀為 torch.Size([512]) 的參數,當前模型中的形狀為 torch.Size([64])。 main.10.running_mean 的尺寸不匹配:從檢查點復制形狀為 torch.Size([512]) 的參數,當前模型中的形狀為 torch.Size([64])。 main.10.bias 的尺寸不匹配:從檢查點復制形狀為 torch.Size([512]) 的參數,當前模型中的形狀為 torch.Size([64])。 main.10.running_var 的尺寸不匹配:從檢查點復制形狀為 torch.Size([512]) 的參數,當前模型中的形狀為 torch.Size([64])。 main.12.weight 的大小不匹配:從檢查點復制形狀為 torch.Size([1, 512, 4, 4, 4]) 的參數,當前模型中的形狀為 torch.Size([64, 1, 4, 4, 4])。

我做錯了什么?

提前致謝!

您加載的模型和目標模型不相同,因此會引發錯誤以告知尺寸、層數不匹配,請再次檢查您的代碼,或者您保存的模型可能無法正確保存

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM