
How to disable teacher forcing in an RNN model

I have the following teacher-forcing RNN model, where I implicitly pass the entire input sequence (inputs = ids[:, i:i+seq_length]) to the model at once. What should I modify to disable teacher forcing during training and get a plain, free-running model?

import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_

# Load the data as a (batch_size, num_tokens) tensor of word ids
ids = corpus.get_data('data/train.txt', batch_size)
num_batches = ids.size(1) // seq_length

model = RNNLM(vocab_size, embed_size, hidden_size, num_layers).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Truncated backpropagation: detach the hidden states so gradients
# do not flow across mini-batch boundaries
def detach(states):
    return [state.detach() for state in states]


# Train the model
for epoch in range(num_epochs):
    # Set initial hidden and cell states
    states = (torch.zeros(num_layers, batch_size, hidden_size).to(device),
              torch.zeros(num_layers, batch_size, hidden_size).to(device))
    
    for i in range(0, ids.size(1) - seq_length, seq_length):
        # Get mini-batch inputs and targets
        inputs = ids[:, i:i+seq_length].to(device)
        targets = ids[:, (i+1):(i+1)+seq_length].to(device)
        
        # Forward pass
        states = detach(states)
        outputs, states = model(inputs, states)
        loss = criterion(outputs, targets.reshape(-1))
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        step = (i+1) // seq_length
        if step % 100 == 0:
            print ('Epoch [{}/{}], Step[{}/{}], Loss: {:.4f}, Perplexity: {:5.2f}'
                   .format(epoch+1, num_epochs, step, num_batches, loss.item(), np.exp(loss.item())))

I have tried passing the inputs and targets in different ways, but nothing worked. I am a bit confused about what the inputs and targets should be for the model without teacher forcing.

To disable teacher forcing in the model, you need to modify the code that builds the input and target sequences. Currently, the input sequence is built by taking a contiguous chunk of seq_length tokens from the ids tensor, starting at position i and ending at position i + seq_length. The target sequence is the same chunk shifted by one token: it starts at position i + 1 and ends at position (i + 1) + seq_length.
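Concretely, the two slices are just the same ground-truth tokens shifted by one position. A minimal illustration with a made-up ids tensor (the values are hypothetical; only the slicing matters):

import torch

ids = torch.tensor([[10, 11, 12, 13, 14, 15]])   # (batch_size=1, num_tokens=6)
seq_length = 3
i = 0
inputs  = ids[:, i:i+seq_length]          # tensor([[10, 11, 12]])
targets = ids[:, (i+1):(i+1)+seq_length]  # tensor([[11, 12, 13]])
# At every time step the model is conditioned on the ground-truth
# token from `inputs`, which is exactly what teacher forcing means.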

To disable teacher forcing, you need to construct the input and target sequences differently. Instead of using fixed-length chunks of tokens, use a single token as the input together with its corresponding target token. This means iterating through the ids tensor one token at a time rather than in fixed-length chunks. You can do this by modifying the code as follows:

# Train the model
for epoch in range(num_epochs):
    # Set initial hidden and cell states
    states = (torch.zeros(num_layers, batch_size, hidden_size).to(device),
              torch.zeros(num_layers, batch_size, hidden_size).to(device))

    # Loop through the tokens in the input sequence one at a time
    # (stop one token early so the target index stays in bounds)
    for i in range(0, ids.size(1) - 1):
        # Get the current input token and the next token as its target;
        # keep a (batch_size, 1) shape so the model sees a length-1 sequence
        inputs = ids[:, i:i+1].to(device)
        target = ids[:, i+1].to(device)

        # Forward pass
        states = detach(states)
        outputs, states = model(inputs, states)
        loss = criterion(outputs, target)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        if i % 100 == 0:
            print('Epoch [{}/{}], Step[{}/{}], Loss: {:.4f}, Perplexity: {:5.2f}'
                  .format(epoch+1, num_epochs, i, ids.size(1) - 1, loss.item(), np.exp(loss.item())))
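Note that the loop above still conditions the model on a ground-truth token at every step, just one token at a time. If you want the model to condition on its own predictions instead (a fully free-running setup), you can feed the previous prediction back as the next input. Below is a minimal sketch under that assumption, reusing the same RNNLM, detach, criterion, and optimizer from the question; it assumes the model returns logits of shape (batch_size, vocab_size) for a length-1 input:

# Free-running training sketch: the model's own prediction becomes
# the next input instead of the ground-truth token
for epoch in range(num_epochs):
    states = (torch.zeros(num_layers, batch_size, hidden_size).to(device),
              torch.zeros(num_layers, batch_size, hidden_size).to(device))

    # Seed the model with the first ground-truth token of each sequence
    current = ids[:, 0:1].to(device)                 # (batch_size, 1)

    for i in range(0, ids.size(1) - 1):
        target = ids[:, i+1].to(device)              # (batch_size,)

        states = detach(states)
        outputs, states = model(current, states)     # (batch_size, vocab_size)
        loss = criterion(outputs, target)

        optimizer.zero_grad()
        loss.backward()
        clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        # Feed the predicted token back in; detach it so the graph
        # does not grow across time steps
        current = outputs.argmax(dim=1, keepdim=True).detach()

Training purely free-running from scratch tends to be unstable in practice, which is why the two regimes are often mixed, e.g. with scheduled sampling, where the ground-truth token is used with a probability that decays over the course of training.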
