
How to train a network with two datasets and two output heads?

I am trying to train on two datasets at the same time to get good results on both of them.

data_loader_iterator = iter(data_loader_second)
for batch_idx, (image1, label1) in enumerate(data_loader):
    image1 = image1.to(args.local_rank)
    label1 = label1.to(args.local_rank)
    label1 = label1.squeeze()
    try:
        image2, label2 = next(data_loader_iterator)
    except StopIteration:
        data_loader_iterator = iter(data_loader_second)
        image2, label2 = next(data_loader_iterator)
    image2 = image2.to(args.local_rank)
    label2 = label2.to(args.local_rank)
    label2 = label2.squeeze()
    
    embedding1 = backbone.forward(image1)
    embedding2 = backbone.forward(image2)
    output1 = head1.forward(embedding1, label1)
    output2 = head2.forward(embedding2, label2)
    loss1 = criterion(output1, label1)
    loss2 = criterion(output2, label2)
    loss = loss1 + loss2

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ...

head1 = HeadFactory(args.head_type, args.head_conf_file, 751577).get_head()
head2 = HeadFactory(args.head_type, args.head_conf_file, 253788).get_head()
...
optimizer = torch.optim.AdamW(params=[{"params": backbone.parameters()},{"params": head1.parameters()},{"params": head2.parameters()}], lr = args.lr, weight_decay=0.05)
...
criterion = torch.nn.CrossEntropyLoss().to(args.local_rank)

This doesn't work correctly. The run dies with
WARNING: torch.distributed.elastic.multiprocessing.api: Sending process 2919 closing signal SIGTERM
and I wonder how to declare the optimizer correctly for this setup.

It's unclear from your error what exactly went wrong and where.

Let's try a few things:

  1. How do you merge two datasets with different labels into the same DataLoader? And how can you ensure that the first label belongs to head1 while the second belongs to head2?
    I would recommend using two Datasets with two DataLoaders and iterating them together with zip (note that zip stops at the end of the shorter loader, so you may still want to re-create the shorter iterator the way you do now):
for batch_idx, ((image1, label1), (image2, label2)) in enumerate(zip(data_loader1, data_loader2)):
  # ...
  1. I wonder if the way you pass all the parameters to the optimizer is causing trouble.
    How about wrapping everything into a single nn.ModuleList and handing its parameters to the optimizer? (A combined sketch of both suggestions follows below.)
container = nn.ModuleList([backbone, head1, head2])
optimizer = torch.optim.AdamW(container.parameters(), lr=...)

If you need to wrap the model for parallelism, you can wrap the container.
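Putting the two suggestions together, a minimal sketch of the loop could look like the following. It assumes, as in your code, that the heads take the embedding together with the label and that criterion is the CrossEntropyLoss you already declared; data_loader1 and data_loader2 stand in for your two loaders:

import torch
import torch.nn as nn

# One optimizer over every trainable parameter (backbone + both heads).
container = nn.ModuleList([backbone, head1, head2])
optimizer = torch.optim.AdamW(container.parameters(), lr=args.lr, weight_decay=0.05)

for batch_idx, ((image1, label1), (image2, label2)) in enumerate(zip(data_loader1, data_loader2)):
    image1, label1 = image1.to(args.local_rank), label1.to(args.local_rank).squeeze()
    image2, label2 = image2.to(args.local_rank), label2.to(args.local_rank).squeeze()

    # Shared backbone, one head per dataset.
    output1 = head1(backbone(image1), label1)
    output2 = head2(backbone(image2), label2)

    # Sum the two losses so a single backward pass updates the backbone and both heads.
    loss = criterion(output1, label1) + criterion(output2, label2)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Whether or not you keep the three modules in one container, the important part is that a single optimizer step sees all three parameter groups and that the two losses are summed before calling backward.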


Additionally, are you training in a multi-GPU setting? In that case, you need to add a few pieces to the code:

  1. Wrap both the backbone and the heads with a distributed module such as DistributedDataParallel.
  2. Your data loaders need to be aware of parallelism. This is usually done via DistributedSampler (see the sketch below).
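For reference, a minimal sketch of those two pieces, assuming torch.distributed.init_process_group has already been called in your script (the elastic warning suggests you launch with torchrun), and with dataset1, dataset2 and args.batch_size standing in for your own objects; each module is wrapped separately here, matching point 1:

from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# Move each module to its GPU and wrap it before building the optimizer,
# so the optimizer references the device-resident parameters.
backbone = DDP(backbone.to(args.local_rank), device_ids=[args.local_rank])
head1 = DDP(head1.to(args.local_rank), device_ids=[args.local_rank])
head2 = DDP(head2.to(args.local_rank), device_ids=[args.local_rank])

# Give every rank its own shard of each dataset.
sampler1 = DistributedSampler(dataset1)
sampler2 = DistributedSampler(dataset2)
data_loader1 = DataLoader(dataset1, batch_size=args.batch_size, sampler=sampler1)
data_loader2 = DataLoader(dataset2, batch_size=args.batch_size, sampler=sampler2)

# At the start of every epoch, call sampler1.set_epoch(epoch) and
# sampler2.set_epoch(epoch) so the shuffling order changes across epochs.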
