简体   繁体   中英

Why does requires_grad turns from true to false when doing torch.nn.conv2d operation?

I have Unet network which takes in MRI images of the brain, where the goal is to segment white substance in the brain. The images has the shape 256x256x183 (reshaped to 183x256x256) (FLAIR and T1 images). The problem I am having is that before sending the input to the Unet network, I have requires_grad=True on my pytorch tensor, but after one torch.nn.conv2d operation the requires_grad=False. This is a huge problem since the gradient will not update and learn.

from collections import OrderedDict

import torch
import torch.nn as nn


class UNet(nn.Module):
    
    def __init__(self, in_channels=3, out_channels=1, init_features=32):
        super(UNet, self).__init__()

        features = init_features
        self.encoder1 = UNet._block(in_channels, features, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = UNet._block(features, features * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = UNet._block(features * 2, features * 4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = UNet._block(features * 4, features * 8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck")

        self.upconv4 = nn.ConvTranspose2d(
            features * 16, features * 8, kernel_size=2, stride=2
        )
        self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(
            features * 8, features * 4, kernel_size=2, stride=2
        )
        self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(
            features * 4, features * 2, kernel_size=2, stride=2
        )
        self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(
            features * 2, features, kernel_size=2, stride=2
        )
        self.decoder1 = UNet._block(features * 2, features, name="dec1")

        self.conv = nn.Conv2d(
            in_channels=features, out_channels=out_channels, kernel_size=1
        )

    def forward(self, x):

        print(x.requires_grad) #<---- here it is true
        enc1 = self.encoder1(x)#<---- where the problem happens
       
        print(enc1.requires_grad) #<---- here it is false
        enc2 = self.encoder2(self.pool1(enc1))
        print(enc2.requires_grad)
        enc3 = self.encoder3(self.pool2(enc2))
        print(enc3.requires_grad)
        enc4 = self.encoder4(self.pool3(enc3))
        print(enc4.requires_grad)

        bottleneck = self.bottleneck(self.pool4(enc4))
        print(bottleneck.requires_grad)

        dec4 = self.upconv4(bottleneck)
        print(dec4.requires_grad)
        dec4 = torch.cat((dec4, enc4), dim=1)
        print(dec4.requires_grad)
        dec4 = self.decoder4(dec4)
        print(dec4.requires_grad)
        dec3 = self.upconv3(dec4)
        print(dec3.requires_grad)
        dec3 = torch.cat((dec3, enc3), dim=1)
        print(dec3.requires_grad)
        dec3 = self.decoder3(dec3)
        print(dec3.requires_grad)
        dec2 = self.upconv2(dec3)
        print(dec2.requires_grad)
        dec2 = torch.cat((dec2, enc2), dim=1)
        print(dec2.requires_grad)
        dec2 = self.decoder2(dec2)
        print(dec2.requires_grad)
        dec1 = self.upconv1(dec2)
        print(dec1.requires_grad)
        dec1 = torch.cat((dec1, enc1), dim=1)
        print(dec1.requires_grad)
        dec1 = self.decoder1(dec1)
        print(dec1.requires_grad)
        print("going out")
        return torch.sigmoid(self.conv(dec1))

    @staticmethod
    def _block(in_channels, features, name):
        return nn.Sequential(
            OrderedDict(
                [
                    (
                        name + "conv1",
                        nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm1", nn.BatchNorm2d(num_features=features)),
                    (name + "relu1", nn.ReLU(inplace=True)),
                    (
                        name + "conv2",
                        nn.Conv2d(
                            in_channels=features,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm2", nn.BatchNorm2d(num_features=features)),
                    (name + "relu2", nn.ReLU(inplace=True)),
                ]
            )
        )

Edit: This is the training code

class run_network:
def __init__(self, eta, epoch, batch_size, train_file_path, validation_file_path, shuffle_after_epoch = True):
    self.eta = eta
    self.epoch = epoch
    self.batch_size = batch_size
    self.train_file_path = train_file_path
    self.validation_file_path = validation_file_path
    self.shuffle_after_epoch = shuffle_after_epoch

def __call__(self, is_train = False):
    
    device = torch.device("cpu" if not torch.cuda.is_available() else torch.cuda())
    unet = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
    in_channels=3, out_channels=1, init_features=32, pretrained=True)
    unet.to(device)
    unet = unet.double()
    
    

    
    
    optimizer = optim.Adam(unet.parameters(), lr=self.eta)
    dsc_loss = DiceLoss()
    

    Load_training   = NiftiLoader(self.train_file_path)
    Load_validation = NiftiLoader(self.validation_file_path)
    
    mean_flair, mean_t1, std_flair, std_t1 = Load_training.average_mean_and_std(20, 79,99)

    total_mean = [mean_flair, mean_t1]
    total_std = [std_flair, std_t1]

    loss_train = []
    loss_validation = []


    

    for current_epoch in tqdm(range(self.epoch)):
        for phase in ["train", "validation"]:
            
            
            if phase == "train":
                mini_batch = Load_training.create_batch(self.batch_size, self.shuffle_after_epoch)
                unet.train()
                print("her22")

            if phase == "validation":
                print("her")
                mini_batch = Load_validation.create_batch(self.batch_size, self.shuffle_after_epoch)
                unet.eval()
            
            
            dim1, dim2, dim3 = mini_batch.shape
        
            for iteration in range(1):
                if phase == "train":
                    current_batch = Load_training.Load_Image_batch(mini_batch, iteration)
                    image_batch = Load_training.image_zero_mean_normalizer(current_batch)
                if phase == "validation":
                    current_batch = Load_validation.Load_Image_batch(mini_batch, iteration)
                    image_batch = Load_training.image_zero_mean_normalizer(current_batch, False, mean_list, std_list)


                image_dim0, image_dim1, image_dim2, image_dim3, image_dim4 = image_batch.shape
                image_batch = image_batch.reshape((
                                                    image_dim0, 
                                                    image_dim1*image_dim2, 
                                                    image_dim3, 
                                                    image_dim4
                                                    ))

                
                image_batch = np.swapaxes(image_batch, 0,1)
                image_batch = torch.as_tensor(image_batch)#.requires_grad_(True) #, requires_grad=True)
                image_batch = image_batch.to(device)
                print(image_batch.requires_grad)
                optimizer.zero_grad()
                
            
                with torch.set_grad_enabled(is_train == "train"):
                    for j in range(0, 10, 1): 
                        # [183*5, 3, 256, 256] -> [12, 3, 256, 256]  
                        # ANTALL ITERASJONER: (183*5/12) -> en chunk  
                            
                        input_image = image_batch[j:j+2,0:3,:,:]
                        print(input_image.requires_grad)
                        print("går inn")
                        y_predicted = unet(input_image)
                    
                        print(y_predicted.requires_grad)
                        print(image_batch[j:j+2,3,:,:].requires_grad)
                        loss = dsc_loss(y_predicted.squeeze(1), image_batch[j:j+2,3,:,:])
                        print(loss.requires_grad)
                       
                        if phase == "train":
                            loss_train.append(loss.item())
                            
                            loss.backward()
                            print(loss.item())
                            exit()
                            optimizer.step()
                            print(loss.item())
                            exit()
                        if phase == "validation":
                            loss_validation.append(loss.item())

Number of iteration and print statement are for experimenting what the cause could be.

It works fine to me.

'''
# I changed your code a little bit to catch up the problem.
def forward(self, x):

        print("encoder1", x.requires_grad) #<---- here it is true
        enc1 = self.encoder1(x)#<---- where the problem happens
       
        print("encoder2", enc1.requires_grad) #<---- here it is false
'''
a = torch.randn(32, 3, 255, 255, requires_grad=True)
# a.requires_grads = True
print(a)
UNet()(a)

# This is the result:
encoder1 True
encoder2 True
True
True
True
True
True

Can you show me your training source? I guess it's the problem. And why do you need to update the input data?

The training code is fine and the input doesn't need a gradient at all, if you just want to train and update the weights.

The real problem is this line here

 with torch.set_grad_enabled(is_train == "train"):

So you want to disable the gradients if you are not training. The thing is is_train is a bool (judging form this: def __call__(self, is_train=False): ), so the comparisons will be always false and no gradients will bet set. Just change it to

with torch.set_grad_enabled(is_train):

and you will be fine.

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM