简体   繁体   中英

Gradient accumulation: training stops at 50%

训练停止在 50%

The original batch_size was 16, but I wanted to use gradient accumulation with accumulation = 2 to get an effect similar to using batch_size = 32.

The original training run took an hour, so I expected training to take about two hours with gradient accumulation.

But training stops at 50% and still takes only an hour, even with gradient accumulation.

I don't know why it stops. Below is my training code:

def train_runner(model, train_dataset, valid_dataset, batch_size, num_train_epochs, learning_rate): device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model.to(device)
model.train()
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size)
valid_dataloader = DataLoader(dataset = valid_dataset, batch_size = batch_size)

lowest_total_valid_loss = 9999.
step = 0
global_total_step = len(train_dataloader) * num_train_epochs
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0)
print("TRAIN START")
with tqdm(total=global_total_step, unit='step') as t:
    total = 0
    total_loss = 0
    for epoch in range(num_train_epochs):
        for iteration,batch in enumerate(train_dataloader):
            #optimizer.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            start_positions = batch['start_positions'].to(device)
            end_positions = batch['end_positions'].to(device)
            outputs = model(input_ids,
                         attention_mask=attention_mask,
                         start_positions=start_positions,
                         end_positions=end_positions)
            loss = outputs.loss
            (loss / ACCUMULATION).backward()

            step += 1
            if step % ACCUMULATION:
                continue

            clip_grad_norm_(model.parameters(), max_norm=1.)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            
            batch_loss = loss.item() * len(input_ids)
            total += len(input_ids)
            total_loss += batch_loss / ACCUMULATION
            global_total_step += 1
            t.set_postfix(loss="{:.6f}".format(total_loss / total), batch_loss="{:.6f}".format(batch_loss))
            t.update(1)
            
            del input_ids
            del attention_mask
            del start_positions
            del end_positions
            del outputs
            del loss

            ## validation ##
            if iteration != 0 and iteration % int(len(train_dataloader) / 10) == 0:
                total_valid_loss = 0
                for batch_val in valid_dataloader:
                    model.eval()
                    optimizer.zero_grad()

                    input_ids = batch_val['input_ids'].to(device)
                    attention_mask = batch_val['attention_mask'].to(device)
                    start_positions = batch_val['start_positions'].to(device)
                    end_positions = batch_val['end_positions'].to(device)
            
                    with torch.no_grad():
                        outputs = model(input_ids,
                                attention_mask=attention_mask,
                                start_positions=start_positions,
                                end_positions=end_positions)
                        loss = outputs.loss
                        total_valid_loss += loss.item()
                
                if total_valid_loss < lowest_total_valid_loss:
                    print(f"lowest_total_valid_loss: {total_valid_loss} epoch : {epoch} iteration : {iteration}")
                    torch.save(model.state_dict(),'./output_model_best')
                    lowest_total_valid_loss = total_valid_loss
            ## validation ##

#model.save_pretrained("./klue_output_model")
print("TRAIN END")
# Excerpt of the suggested fix: advance the tqdm bar on EVERY batch, including
# batches that only accumulate gradients. That way the bar's total
# (len(train_dataloader) * num_train_epochs) is actually reached.
# `...` stands for the elided forward/backward/optimizer-step code.
# NOTE(review): assumes `step` is still incremented before the check below,
# as in the original loop; that line is not shown in this excerpt.
for iteration,batch in enumerate(train_dataloader):    
    if step % ACCUMULATION:
       t.update(1) # also advance the bar for accumulation-only batches.
       continue 
    ...
    t.update(1)

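For half of the batches you never update the tqdm counter. `t.update(1)` is only reached on every `ACCUMULATION`-th batch (every second batch when `ACCUMULATION = 2`), but the bar's total was initialized to `len(train_dataloader) * num_train_epochs`, i.e. one tick per batch. So the bar cannot go higher than 50% (1 / `ACCUMULATION`). There are two fixes: call `t.update(1)` for every batch, as in the snippet above, or initialize the total to the number of optimizer steps (`global_total_step // ACCUMULATION`). Training itself is not stopping early, since every batch is still processed. This is also why it still takes about an hour: gradient accumulation does not add extra passes over the data.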

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM