
Wav2Vec pytorch element 0 of tensors does not require grad and does not have a grad_fn

I am retraining a wav2vec model from Hugging Face for a classification problem. I have 5 classes, and the input is a list of [1, 400] tensors. Here is how I load the model:

from transformers import AutoConfig, Wav2Vec2CTCTokenizer, Wav2Vec2ForCTC

num_labels = 5
model_name = "Zaid/wav2vec2-large-xlsr-53-arabic-egyptian"
model_config = AutoConfig.from_pretrained(model_name, num_labels=num_labels)  # needed for the visualizations
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name, config=model_config)
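
For illustration (hypothetical values, not from the original post), one training example with the shapes described above would look like this:

import torch

# Hypothetical example: a [1, 400] input tensor and an integer
# class label in {0, ..., 4}, matching the shapes described above.
example_input = torch.randn(1, 400)
example_label = torch.tensor(3)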

Here are the updated model settings:

import torch
import torch.nn as nn
import torch.nn.functional as F  # used as F.softmax in the training loop below
from torch.optim import AdamW  # torch.optim's AdamW assumed

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # assumed; not shown in the post

# Freeze the pre-trained parameters
for param in model.parameters():
    param.requires_grad = False
criterion = nn.MSELoss().to(device)
optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-6)

# Add three new layers at the end of the network
model.classifier = nn.Sequential(
    nn.Linear(768, 256),
    nn.Dropout(0.25),
    nn.ReLU(),
    nn.Linear(256, 64),
    nn.Dropout(0.25),
    nn.ReLU(),
    nn.Linear(64, 2),
    nn.Dropout(0.25),
    nn.Softmax(dim=1)
)
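
As a quick check (a hypothetical diagnostic, not part of the original post), you can list which parameters still require grad after the freeze. The classifier layers assigned above were created after the freeze loop, so they default to requires_grad=True:

# Hypothetical diagnostic: print only the parameters that will
# receive gradients; the frozen base model contributes none.
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, tuple(param.shape))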

Then the training loop:

print_every = 300

total_loss = 0
all_losses = []
model.train()
for epoch in range(2):
    print("Epoch number: ", epoch)
    for row in range(16918):
        Input = torch.tensor(trn_ivectors[row]).double()
        label = torch.tensor(trn_labels[row]).long().to(device)
        label = torch.unsqueeze(label,0).to(device)
        #print("Label", label.shape)
        Input = torch.unsqueeze(Input,1).to(device)
        #print(Input.shape)
        optimizer.zero_grad()
        
        #Input.requires_grad = True
        Input = F.softmax(Input[0], dim=-1)
        
        if label == 0:
            label = torch.tensor([1.0, 0.0]).float().to(device)
        elif label == 1:
            label = torch.tensor([0.0, 1.0]).float().to(device)

        # print(overall_output, label)

        loss = criterion(Input, label)
        total_loss += loss.item()

        loss.backward()
        optimizer.step()

        if row % print_every == 0 and row > 0:
            average_loss = total_loss / print_every
            print("{}/{}. Average loss: {}".format(row, 16918, average_loss))
            all_losses.append(average_loss)
            total_loss = 0

torch.save(model.state_dict(), "model_after_train.pt")

Unfortunately, when I try to train, the program gives me the following error:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

I would appreciate it if you could tell me how to fix this error. I have been searching a lot for a way to fix it, but haven't managed to.

Thanks

Please try adding

requires_grad = True
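
For example (a minimal sketch, assuming the setup in the question): the loss is computed directly on Input, which does not require grad, and every model parameter was frozen, so the loss has no grad_fn for backward() to follow. One place the flag could go is on the input itself, matching the line already commented out in the training loop:

# Minimal sketch: mark the input as requiring grad before the loss is
# computed, so loss.backward() has a graph to traverse.
Input = torch.unsqueeze(Input, 1).to(device)
Input.requires_grad = True  # the flag the answer suggests adding

Note that with the whole model frozen and never called inside the loop, this only silences the error: the optimizer still has nothing trainable to update, so any layers you actually want to train would still need requires_grad=True and would need to contribute to the loss.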
