
Configuring a progress bar while training for Deep Learning

I have this tiny training function upcycled from a tutorial.

def train(epoch, tokenizer, model, device, loader, optimizer):
    model.train()
    with tqdm.tqdm(loader, unit="batch") as tepoch:
        for _, data in enumerate(loader, 0):
            y = data['target_ids'].to(device, dtype=torch.long)
            y_ids = y[:, :-1].contiguous()
            lm_labels = y[:, 1:].clone().detach()
            lm_labels[y[:, 1:] == tokenizer.pad_token_id] = -100
            ids = data['source_ids'].to(device, dtype=torch.long)
            mask = data['source_mask'].to(device, dtype=torch.long)

            outputs = model(input_ids=ids, attention_mask=mask,
                            decoder_input_ids=y_ids, labels=lm_labels)
            loss = outputs[0]

            tepoch.set_description(f"Epoch {epoch}")
            tepoch.set_postfix(loss=loss.item())

            if _ % 10 == 0:
                wandb.log({"Training Loss": loss.item()})

            if _ % 1000 == 0:
                print(f'Epoch: {epoch}, Loss: {loss.item()}')

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # xm.optimizer_step(optimizer)
            # xm.mark_step()

The function trains fine; the problem is that I can't seem to make the progress bar work correctly. I've played around with it, but I haven't found a configuration that correctly updates the loss and tells me how much time is left. Does anyone have any pointers on what I might be doing wrong? Thanks in advance!

In case anyone else runs into the same issue: thanks to the previous response, I was able to configure the progress bar as I wanted with just a little tweak of what I was doing before:

def train(epoch, tokenizer, model, device, loader, optimizer):
    model.train()
    for _, data in tqdm(enumerate(loader, 0), unit="batch", total=len(loader)):

Everything else stays the same, and now I have a progress bar showing the percentage complete and the loss. I prefer this solution because it allows me to keep the other logging functions I had without further changes.
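
For completeness, here is a sketch of what the full tweaked loop could look like. The named handle tepoch is my addition (not in the snippet above), so that set_description and set_postfix still have a bar object to call; everything else is the code from the question:

def train(epoch, tokenizer, model, device, loader, optimizer):
    model.train()
    # Keep a handle to the bar so the description/postfix calls below work.
    tepoch = tqdm(enumerate(loader, 0), unit="batch", total=len(loader))
    for step, data in tepoch:
        y = data['target_ids'].to(device, dtype=torch.long)
        y_ids = y[:, :-1].contiguous()
        lm_labels = y[:, 1:].clone().detach()
        lm_labels[y[:, 1:] == tokenizer.pad_token_id] = -100
        ids = data['source_ids'].to(device, dtype=torch.long)
        mask = data['source_mask'].to(device, dtype=torch.long)

        outputs = model(input_ids=ids, attention_mask=mask,
                        decoder_input_ids=y_ids, labels=lm_labels)
        loss = outputs[0]

        # The bar updates in place each batch with the current loss.
        tepoch.set_description(f"Epoch {epoch}")
        tepoch.set_postfix(loss=loss.item())

        if step % 10 == 0:
            wandb.log({"Training Loss": loss.item()})

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()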

preliminaries

Let's import in the conventional way:

from tqdm import tqdm

iterable

A tqdm progress bar is useful when used with an iterable, and you don't appear to be doing that. Or rather, you gave it an iterable, but then you never iterated over it: you never really gave tqdm a chance to repeatedly call next(...).

generic example

We usually add a progress bar by replacing

for i in my_iterable:
    sleep(1)

with

for i in tqdm(my_iterable):
    sleep(1)

where the sleep could be any time-consuming I/O or computation.

The progress bar has an opportunity to update each time through the loop.

your specific code

Roughly, you wrote:

with tqdm(loader) as tepoch:
    for _, data in enumerate(loader):

I recommend you simplify this twice. First, there's no need for enumerate:

    for data in loader:

Second, and more importantly, remove the with:

for data in tqdm(loader):

This is the "plain vanilla" approach to using tqdm.


Now, I'll grant you, there are some fancy details farther down. You're attempting to report progress by setting a description and a postfix, and I imagine one might set additional attributes on tepoch. But it appears fancier than appropriate for your needs at the moment, so I recommend deleting that to arrive at a simpler solution.
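
That said, if you do want the with-statement form, a minimal sketch of using it correctly looks like this. The key is to iterate over tepoch itself rather than over the raw loader (training_step here is a hypothetical stand-in for the forward/backward work, not part of the question's code):

with tqdm(loader, unit="batch") as tepoch:
    tepoch.set_description(f"Epoch {epoch}")
    for data in tepoch:
        # Iterating over tepoch (not loader) is what advances the bar.
        loss = training_step(data)  # hypothetical helper
        tepoch.set_postfix(loss=loss)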


container

Tqdm works nicely with iterables, and even better with a certain kind of iterable: a container. Or, more generally, with iterables that offer len(...), which includes range(...).

Tqdm defaults to asking its argument for its length. If that's available, then tqdm knows how close we are to the end, so rather than just reporting iterations per second it will also report the fraction completed and estimate the time to completion. If you offer a generator with no len(...), but you know the total number of items it will generate, then it is definitely worth specifying it, e.g. tqdm(my_gen, total=50). The resulting progress bar will be much more informative. An alternative is to wrap your generator in list(my_gen), assuming it takes a small fraction of the total time consumed by your processing loop.
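
As a minimal sketch of the generator case, assuming we know the count up front:

from time import sleep
from tqdm import tqdm

squares = (i * i for i in range(50))   # a generator: len(...) is unavailable
for value in tqdm(squares, total=50):  # total= restores % complete and ETA
    sleep(0.1)                         # stand-in for real work

Without total=50, tqdm would only show a raw iteration count and rate, with no percentage or time remaining.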
