Wandb training kills kernel in Jupyter Lab

In Jupyter I can train my model with batch_size=8, but when I use wandb, the process is always killed after 9 iterations and the kernel restarts. What's even stranger is that the same code worked on Colab, but with my GPU (RTX 3080) I can never finish training.

Does anyone have any idea how to overcome this issue?

Edit: I noticed that the kernel dies every time it tries to log the gradients to wandb. Can this be solved?

Code with wandb:

def train_batch(images, labels, model, optimizer, criterion):
    images, labels = images.to(device), labels.to(device)
    
    # Forward pass ➡
    outputs = model(images)
    loss = criterion(outputs, labels)
    
    # Backward pass ⬅
    optimizer.zero_grad()
    loss.backward()

    # Step with optimizer
    optimizer.step()
    
    size = images.size(0)
    del images, labels
    return loss, size

from loss import YoloLoss

# train the model
def train(model, train_dl, criterion, optimizer, config, is_one_batch):
    # Tell wandb to watch what the model gets up to: gradients, weights, and more!
    wandb.watch(model, criterion, log="all", log_freq=10)

    example_ct = 0  # number of examples seen
    batch_ct = 0
    
    # enumerate epochs
    for epoch in range(config.epochs):
        running_loss = 0.0
        
        if not is_one_batch:
            for i, (inputs, _, targets) in enumerate(train_dl):
                loss, batch_size = train_batch(inputs, targets, model, optimizer, criterion)
                running_loss += loss.item() * batch_size
        else:
            # for one batch only
            loss, batch_size = train_batch(train_dl[0], train_dl[2], model, optimizer, criterion)
            running_loss += loss.item() * batch_size
            
        epoch_loss = running_loss / len(train_dl)
#         loss_values.append(epoch_loss)
        wandb.log({"epoch": epoch, "avg_batch_loss": epoch_loss})
#         wandb.log({"epoch": epoch, "loss": loss}, step=example_ct)
        print("Average epoch loss {}".format(epoch_loss))

def make(config, is_one_batch, data_predefined=True):
    optimizers = {
        "Adam":torch.optim.Adam,
        "SGD":torch.optim.SGD
    }
    
    if data_predefined:
        train_dl, test_dl = train_dl_predef, test_dl_predef
    else:
        train_dl, test_dl = dataset.prepare_data()
        
    if is_one_batch:
        train_dl = next(iter(train_dl))
        test_dl = train_dl
    
    # Make the model
    model = architecture.darknet(config.batch_norm)
    model.to(device)

    # Make the loss and optimizer
    criterion = YoloLoss()
    optimizer = optimizers[config.optimizer](
        model.parameters(), 
        lr=config.learning_rate,
        momentum=config.momentum
    )
    
    return model, train_dl, test_dl, criterion, optimizer
        
def model_pipeline(hyp, is_one_batch=False, device=device):
    with wandb.init(project="YOLO-recreated", entity="bindas1", config=hyp):
        config = wandb.config
        
        # make the model, data, and optimization problem
        model, train_dl, test_dl, criterion, optimizer = make(config, is_one_batch)
        
        # and use them to train the model
        train(model, train_dl, criterion, optimizer, config, is_one_batch)
        
    return model

Code without wandb:

def train_model(train_dl, model, is_one_batch=False):
    # define the optimization
    criterion = YoloLoss()
    optimizer = SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
    
    # for loss plotting
    loss_values = []
    
    # enumerate epochs
    for epoch in tqdm(range(EPOCHS)):
        if epoch % 10 == 0:
            print(epoch)
        running_loss = 0.0
        
        if not is_one_batch:
        # enumerate mini batches
            for i, (inputs, _, targets) in enumerate(train_dl):
                inputs = inputs.to(device)
                targets = targets.to(device)
                # clear the gradients
                optimizer.zero_grad()
                # compute the model output
                yhat = model(inputs)
                # calculate loss
                loss = criterion(yhat, targets)
                # credit assignment
                loss.backward()
#                 print(loss)
                running_loss += loss.item() * inputs.size(0)
                # update model weights
                optimizer.step()
        else:
            # for one batch only
            with torch.autograd.detect_anomaly():
                inputs, targets = train_dl[0].to(device), train_dl[2].to(device)
                optimizer.zero_grad()
                # compute the model output
                yhat = model(inputs)
                # calculate loss
                loss = criterion(yhat, targets)
                # credit assignment
                loss.backward()
                print(loss)
                running_loss += loss.item() * inputs.size(0)
                # update model weights
                optimizer.step()
        loss_values.append(running_loss / len(train_dl))
    
    plot_loss(loss_values)

model = architecture.darknet()
model.to(device)
optimizer = SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
train_dl_main, test_dl_main = train_dl_predef, test_dl_predef
one_batch = next(iter(train_dl_main))
train_model(one_batch, model, is_one_batch=True)

Hmm, strange. So in your edit you're saying that it works OK if you remove wandb.watch?
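
If removing it helps, you may not have to give up wandb.watch entirely. Here is a minimal sketch of lighter alternatives, assuming the same model and criterion objects from your training code (log and log_freq are standard wandb.watch arguments; the exact values here are only suggestions):

# Skip gradient histograms but still log parameter histograms, less often
wandb.watch(model, criterion, log="parameters", log_freq=100)

# Or disable histogram logging entirely and keep only your own wandb.log calls
wandb.watch(model, criterion, log=None)

Gradient histograms are collected for every parameter tensor, so on a large model they add a noticeable memory cost on top of training itself; raising log_freq or setting log=None is usually the first thing to try.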

To double-check, have you tried the original code on the latest version of wandb (0.12.7)?
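
If not, something along these lines (run from a notebook cell, then restart the kernel so the new version is actually loaded) should get you onto the latest release:

# upgrade wandb from inside Jupyter, then restart the kernel
!pip install --upgrade wandb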
