简体   繁体   English

在训练深度学习时配置进度条

[英]Configuring a progress bar while training for Deep Learning

I have this tiny training function upcycled from a tutorial.我从教程中升级了这个小型培训 function。

def train(epoch, tokenizer, model, device, loader, optimizer):
model.train()
with tqdm.tqdm(loader, unit="batch") as tepoch:
  for _,data in enumerate(loader, 0):
      y = data['target_ids'].to(device, dtype = torch.long)
      y_ids = y[:, :-1].contiguous()
      lm_labels = y[:, 1:].clone().detach()
      lm_labels[y[:, 1:] == tokenizer.pad_token_id] = -100
      ids = data['source_ids'].to(device, dtype = torch.long)
      mask = data['source_mask'].to(device, dtype = torch.long)

      outputs = model(input_ids = ids, attention_mask = mask, decoder_input_ids=y_ids, labels=lm_labels)
      loss = outputs[0]

      tepoch.set_description(f"Epoch {epoch}")
      tepoch.set_postfix(loss=loss.item())
      
      if _%10 == 0:
          wandb.log({"Training Loss": loss.item()})

      if _%1000==0:
          print(f'Epoch: {epoch}, Loss:  {loss.item()}')
  
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      # xm.optimizer_step(optimizer)
      # xm.mark_step()

The function trains fine, the problem is that I can't seem to make the progress bar work correctly. function 训练良好,问题是我似乎无法使进度条正常工作。 I played around with it, but haven't found a configuration that correctly updates the loss and tells me how much time is left.我玩过它,但没有找到正确更新损失并告诉我还剩多少时间的配置。 Does anyone have any pointers on what I might be doing wrong?有没有人对我可能做错了什么有任何指示? Thanks in advance!提前致谢!

In case anyone else has run in my same issue, thanks to the previous response I was able to configure the progress bar as I wanted with just a little tweak of what I was doing before:万一其他人遇到了我的同一个问题,多亏了之前的回复,我能够按照我的意愿配置进度条,只需稍微调整一下我之前所做的事情:

def train(epoch, tokenizer, model, device, loader, optimizer):
  model.train()    
  for _,data in tqdm(enumerate(loader, 0), unit="batch", total=len(loader)):

everything stays the same, and now I have a progress bar showing percentage and loss.一切都保持不变,现在我有一个显示百分比和损失的进度条。 I prefer this solution because it allows me to keep the other logging functions I had without further changes.我更喜欢这个解决方案,因为它允许我保留我拥有的其他日志记录功能而无需进一步更改。

preliminaries预赛

Let's import in the conventional way:让我们以常规方式导入:

from tqdm import tqdm

iterable可迭代的

A tqdm progress bar is useful when used with an iterable, and you don't appear to be doing that.与可迭代对象一起使用时,tqdm 进度条很有用,而您似乎没有这样做。 Or rather, you gave it an iterable, but then you didn't iterate on it there, you didn't really give tqdm a chance to repeatedly call next(...) .或者更确切地说,你给了它一个可迭代的,但是你没有在那里迭代它,你并没有真正给 tqdm 一个重复调用next(...)的机会。

generic example通用示例

We usually add a progress bar by replacing我们通常通过替换来添加进度条

for i in my_iterable:
    sleep(1)

with

for i in tqdm(my_iterable):
    sleep(1)

where the sleep could be any time consuming I/O or computation. sleep可能是任何耗时的 I/O 或计算。

The progress bar has an opportunity to update each time through the loop.进度条每次循环都有机会更新。

your specific code你的具体代码

Roughly, you wrote:粗略地说,你写道:

with tqdm(loader) as tepoch:
    for _, data in enumerate(loader):

I recommend you simplify this, twice.我建议您将其简化两次。 Firstly, no need for enumerate:首先,不需要枚举:

    for data in loader:

Second and more importantly, remove the with :其次,更重要的是,删除with

for data in tqdm(loader):

This is the "plain vanilla" approach to using tqdm.这是使用 tqdm 的“普通”方法。


Now, I'll grant you, there's some fancy details farther down.现在,我承认,后面还有一些花哨的细节。 You're attempting to report progress by setting description and postfix, and I imagine one might set additional attributes on tepoch .您试图通过设置描述和后缀来报告进度,我想可能会在tepoch上设置其他属性。 But it appears to be fancier than appropriate for your needs ATM, so I recommend deleting that to arrive at a simpler solution.但它似乎比您需要的 ATM 更合适,所以我建议删除它以获得更简单的解决方案。


container容器

Tqdm works nicely with iterables, and even better with a certain kind of iterable: a container. Tqdm 可以很好地与可迭代对象一起工作,甚至更适合某种可迭代对象:容器。 Or more generally, with iterables that offer len(...) , which includes range(...) .或者更一般地说,使用提供len(...)的迭代器,其中包括range(...)

Tqdm defaults to trying to ask its argument for its length. Tqdm 默认尝试询问其参数的长度。 If that's available then tqdm knows how close we are to the end, so rather than just reporting iterations per second it will also report the fraction completed and will estimate time to completion.如果可用,则 tqdm 知道我们离结束有多近,因此它不仅会报告每秒的迭代次数,还会报告已完成的分数并估计完成时间。 If you offer a generator with no len(...) , but you know the total number of items it will generate, then it is definitely worth specifying it, eg tqdm(my_gen, total=50) .如果您提供一个没有len(...)的生成器,但您知道它将生成的项目总数,那么绝对值得指定它,例如tqdm(my_gen, total=50) The resulting progress bar will be much more informative.生成的进度条将提供更多信息。 An alternative is to wrap your generator in list(my_gen) , assuming that that takes a small fraction of the total time consumed by your processing loop.另一种方法是将生成器包装在list(my_gen)中,假设这仅占处理循环消耗的总时间的一小部分。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM