

CNN Pruning Issue: 'weight' has to be pruned before pruning can be removed

I am new to pruning with PyTorch. I get this error when training my model:

ValueError: Parameter 'weight' of module Conv2d(1, 4, kernel_size=(7, 7), stride=(3, 3)) has to be pruned before pruning can be removed

Here is my whole ConvNet class with the pruning method. I do not know how to solve this.

I declared a class containing __init__ and the forward pass; between the two I implemented a method "toggle_pruning", because I want to have only 10 active neurons (just as an experiment).
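For illustration, here is a sketch that applies the same arithmetic as toggle_pruning(True) to a freshly initialized layer shaped like conv1 (random weights, not the asker's trained model) and shows what PyTorch's pruning utilities attach to the module:

```python
import numpy as np
import torch
import torch.nn.utils.prune as prune

# Fresh layer with the same shape as conv1 (random weights, for illustration only)
conv1 = torch.nn.Conv2d(1, 4, kernel_size=7, stride=3)
n_active = 10

s = conv1.weight.shape                        # (4, 1, 7, 7)
fan_out, fan_in = s[0], int(np.prod(s[1:]))   # 4 output channels, 49 inputs each
amount = (fan_in - n_active) * fan_out        # weights to zero: 39 * 4 = 156

prune.l1_unstructured(conv1, "weight", amount)

# Pruning re-parametrizes the layer: the original values move to weight_orig,
# a 0/1 buffer weight_mask is registered, and a forward pre-hook recomputes
# weight = weight_orig * weight_mask before every forward call.
print(sorted(name for name, _ in conv1.named_parameters()))  # ['bias', 'weight_orig']
print(sorted(name for name, _ in conv1.named_buffers()))     # ['weight_mask']
print(int(conv1.weight_mask.sum()))                          # 40 weights kept in total

# The kept weights are the largest in magnitude across the whole tensor,
# so an individual output channel may keep more or fewer than n_active.
print(conv1.weight_mask.sum(dim=(1, 2, 3)))
```

Because l1_unstructured ranks weights over the entire tensor, "10 active inputs per neuron" holds only on average; a per-channel limit would need a different pruning criterion.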



import time

import numpy as np
import torch
import torch.nn.utils.prune as prune

# Build the model
class ConvNet(torch.nn.Module):
    # Architecture
    def __init__(self, hidden=64, output=10):  # hidden = 64 units in the intermediate layer, output = 10 because we have 10 classes to classify
        super(ConvNet, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 4, kernel_size=7, padding=0, stride=3)  # 4 filters of size 7x7, stride 3 ==> 8x8 output for a 28x28 input
        self.conv2 = torch.nn.Conv2d(4, 16, kernel_size=3, padding=0, stride=2)  # ==> 3x3 output
        self.fc1 = torch.nn.Linear(16*3*3, hidden)  # conv2 output = 16*3*3 = 144 features, hidden size = 64
        self.fc2 = torch.nn.Linear(hidden, output)

    def toggle_pruning(self, enable):
        """Enables or removes pruning."""

        # Target number of non-zero weights per output channel (on average,
        # because l1_unstructured prunes across the whole weight tensor)
        n_active = 10

        # Go through all the convolution layers
        for layer in (self.conv1, self.conv2):
            s = layer.weight.shape

            # st[0] = fan-out (number of output channels in the layer),
            # st[1] = fan-in (number of inputs to each output channel)
            st = [s[0], np.prod(s[1:])]

            # The fan-in is the product of
            # inChannels x kernel height x kernel width.
            if st[1] > n_active:
                if enable:
                    # This registers a forward pre-hook that multiplies the weights
                    # by a mask tensor (containing 0s and 1s) before every forward pass
                    prune.l1_unstructured(layer, "weight", (st[1] - n_active) * st[0])
                else:
                    # When disabling pruning, the mask is multiplied with the weights
                    # and the result is stored permanently in the weight parameter
                    prune.remove(layer, "weight")


    # Forward pass: describes how the model processes its input (this is where the activation functions are applied)
    def forward(self, x):
        x = self.conv1(x)
        # our model uses the "square" activation
        x = x * x
        x = self.conv2(x)
        # Flatten while keeping the batch size
        x = x.view(-1, 16*3*3)
        x = self.fc1(x)
        x = x * x
        x = self.fc2(x)
        return x
def train(model, train_loader, criterion, optimizer, n_epochs=10):
    # Train the model
    model.train()
    for epoch in range(1, n_epochs+1):
        start_time = time.time()
        train_loss = 0.0
        for data, target in train_loader:
            optimizer.zero_grad()  # reset the gradients before back-propagation
            output = model(data)
            loss = criterion(output, target)
            loss.backward()  # back-propagation
            optimizer.step()  # update the weights after every batch
            train_loss += loss.item()  # accumulate the loss of each iteration

        # Compute the average loss
        train_loss = train_loss / len(train_loader)

        print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch, train_loss))
        print('Training Time : %.2f' %(time.time() - start_time))
        # Finally, disable pruning (sets the pruned weights to 0)
        model.toggle_pruning(False)
    
    # Evaluation
    model.eval()
    return model


PATH = 'models/ConvNetModel.pth'

# device and train_loader are defined elsewhere in the original script (not shown)
model = ConvNet().to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
model = train(model, train_loader, criterion, optimizer, 10)
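As an aside, the convolution output sizes behind 16*3*3 can be checked directly. This sketch assumes 28x28 single-channel inputs (e.g. MNIST), which the question does not state; conv_out is a hypothetical helper applying the usual formula floor((size + 2*padding - kernel) / stride) + 1:

```python
import torch

# Same layers as in the question; the 28x28 single-channel input is an assumption
conv1 = torch.nn.Conv2d(1, 4, kernel_size=7, padding=0, stride=3)
conv2 = torch.nn.Conv2d(4, 16, kernel_size=3, padding=0, stride=2)

def conv_out(size, kernel, stride, padding=0):
    """Output size of a convolution along one spatial dimension (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

h1 = conv_out(28, 7, 3)   # (28 - 7) // 3 + 1 = 8
h2 = conv_out(h1, 3, 2)   # (8 - 3) // 2 + 1 = 3

# Cross-check the formula against the real layers
x = torch.zeros(1, 1, 28, 28)
y = conv2(conv1(x))
print(h1, h2, tuple(y.shape))  # 8 3 (1, 16, 3, 3), i.e. 16*3*3 = 144 inputs to fc1
```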

The message is clear: you never pruned the model, but you are trying to remove pruning.
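A minimal sketch reproducing this on a standalone layer with the same shape as conv1 (random weights; the 50% pruning amount is arbitrary and only for illustration):

```python
import torch
import torch.nn.utils.prune as prune

# A fresh conv layer with the same shape as conv1 in the question
conv = torch.nn.Conv2d(1, 4, kernel_size=7, stride=3)

# Removing pruning from a layer that was never pruned raises ValueError
try:
    prune.remove(conv, "weight")
except ValueError as err:
    print(err)  # the same message as in the question

# Prune first (here: 50% of the weights, smallest L1 magnitude first) ...
prune.l1_unstructured(conv, "weight", amount=0.5)
# ... then removal succeeds and makes the zeros permanent in conv.weight
prune.remove(conv, "weight")
print(hasattr(conv, "weight_orig"), hasattr(conv, "weight_mask"))  # False False
```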
You need to call model.toggle_pruning(True) first, and only then call model.toggle_pruning(False).
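Applied to the question's code, this means enabling pruning once before the epoch loop and removing it once after the loop. Note also that model.toggle_pruning(False) currently sits inside the for epoch loop: even with pruning enabled beforehand, the second epoch would raise the same ValueError, because prune.remove deletes the pruning hook that the next call looks for. Below is a self-contained sketch; the random TensorDataset is hypothetical stand-in data replacing the question's train_loader, and n_epochs=2 only keeps the run short:

```python
import time

import numpy as np
import torch
import torch.nn.utils.prune as prune
from torch.utils.data import DataLoader, TensorDataset


class ConvNet(torch.nn.Module):
    """Same architecture and toggle_pruning logic as in the question."""

    def __init__(self, hidden=64, output=10):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 4, kernel_size=7, padding=0, stride=3)
        self.conv2 = torch.nn.Conv2d(4, 16, kernel_size=3, padding=0, stride=2)
        self.fc1 = torch.nn.Linear(16 * 3 * 3, hidden)
        self.fc2 = torch.nn.Linear(hidden, output)

    def toggle_pruning(self, enable):
        n_active = 10
        for layer in (self.conv1, self.conv2):
            s = layer.weight.shape
            fan_out, fan_in = s[0], int(np.prod(s[1:]))
            if fan_in > n_active:
                if enable:
                    prune.l1_unstructured(layer, "weight", (fan_in - n_active) * fan_out)
                else:
                    prune.remove(layer, "weight")

    def forward(self, x):
        x = self.conv1(x)
        x = x * x
        x = self.conv2(x)
        x = x.view(-1, 16 * 3 * 3)
        x = self.fc1(x)
        x = x * x
        return self.fc2(x)


def train(model, train_loader, criterion, optimizer, n_epochs=10):
    model.train()
    # Enable pruning ONCE, before the first epoch
    model.toggle_pruning(True)
    for epoch in range(1, n_epochs + 1):
        start_time = time.time()
        train_loss = 0.0
        for data, target in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss /= len(train_loader)
        print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch, train_loss))
        print('Training Time : %.2f' % (time.time() - start_time))
    # Remove pruning ONCE, after the last epoch (outside the loop):
    # the masked weights are written back into `weight` permanently
    model.toggle_pruning(False)
    model.eval()
    return model


# Hypothetical stand-in data: 64 random 28x28 grayscale images, 10 classes
dataset = TensorDataset(torch.randn(64, 1, 28, 28), torch.randint(0, 10, (64,)))
train_loader = DataLoader(dataset, batch_size=16)
model = ConvNet()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
model = train(model, train_loader, torch.nn.CrossEntropyLoss(), optimizer, n_epochs=2)
```

Creating the optimizer before pruning is fine here: pruning re-registers the original weight Parameter object as weight_orig, so the optimizer keeps updating it. If you only want the call to be safe, prune.is_pruned(layer) can guard prune.remove, but that would hide the actual problem, which is that pruning was never enabled.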
