A small collection of tools to manage deep learning with multiple sources of loss. Based on the paper Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics (Kendall et al., 2017).
Many deep learning situations call for handling multiple sources of loss. These may come from independent tasks, or from different parts of the same network. The naive solution is to simply take a sum of the individual losses. The dominant approach in the literature is to combine the multiple objectives as a weighted linear sum of the losses of the individual tasks,

$$
\mathcal{L}_{\text{total}} \;=\; \sum_i w_i \, \mathcal{L}_i ,
$$

where the weights $w_i$ are usually fixed by hand. Kendall et al. instead derive the weights from task-dependent (homoscedastic) uncertainty: each regression loss is weighted by $\frac{1}{2\sigma_i^2}$, where $\sigma_i$ is a learned noise parameter, with an additional $\log \sigma_i$ term that keeps the weight from collapsing to zero.
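Applying the same weighting to each of $n$ losses gives a combined objective that is minimized jointly over the network weights $W$ and the noise parameters (whether one learns $\sigma_i$, $\log \sigma_i$, or $\log \sigma_i^2$ is an implementation detail):

$$
\mathcal{L}(W, \sigma_1, \dots, \sigma_n) \;=\; \sum_{i=1}^{n} \frac{1}{2\sigma_i^2}\,\mathcal{L}_i(W) \;+\; \log \sigma_i .
$$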
Our implementation adapts Kendall's loss weighting approach using learnable noise parameters, one per loss. This allows the weights to be learned for any type of loss in a multi-task learning setup, without predefining a specific treatment for regression or classification tasks.
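For reference, below is a minimal sketch of how such a module can be written. It is an illustration of the idea only, not necessarily identical to the MultiNoiseLoss implementation in this repo; in particular, learning the log-variances $s_i = \log \sigma_i^2$ is an assumption made here for numerical stability.

```python
import torch
import torch.nn as nn

class UncertaintyWeightedLoss(nn.Module):
    """Sketch of uncertainty-based loss weighting (Kendall et al., 2017).

    One learnable parameter s_i = log(sigma_i^2) per loss; the total loss is
    sum_i 0.5 * (exp(-s_i) * loss_i + s_i),
    i.e. 1/(2*sigma_i^2) * loss_i + log(sigma_i) for each task i.
    """

    def __init__(self, n_losses):
        super().__init__()
        # One learnable log-variance per loss term, initialized to 0 (sigma_i = 1).
        self.noise_params = nn.Parameter(torch.zeros(n_losses))

    def forward(self, losses):
        total = 0.0
        for loss, s in zip(losses, self.noise_params):
            total = total + 0.5 * (torch.exp(-s) * loss + s)
        return total
```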
MultiNoiseLoss() is implemented as a torch module. A typical use case (with two classification losses, for example) is:
```python
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = Net().to(device)                            # Net is some torch module, e.g. with an MLP layer Net.mlp
multi_loss = MultiNoiseLoss(n_losses=2).to(device)  # MultiNoiseLoss is provided by this package

# Both the model parameters and the dynamic weighting parameters go into the optimizer.
optimizer = torch.optim.Adam([
    {'params': model.mlp.parameters()},
    {'params': multi_loss.noise_params}], lr=0.001)

# Separate learning-rate schedules: lambda1 decays the model's LR,
# lambda2 keeps the weighting parameters at the base LR.
lambda1 = lambda ep: 1 / (2**((ep+11)//10))
lambda2 = lambda ep: 1
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
```

N.B. We have to include the dynamic weighting parameters in the optimizer so that they are also updated, and we have to handle their learning rate. We could of course step that rate down along with the rest of the model, but this may not be desirable, and the right learning rate for the weighting really depends on the situation. Anecdotally, the learning rate for the dynamic weights is not so important: leaving it at 1e-2 seems to work fine.
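To make the scheduling explicit, here is a sketch of the outer epoch loop that drives the two schedules (num_epochs is a placeholder, and the body of each epoch is the training loop shown below):

```python
for epoch in range(num_epochs):  # num_epochs: placeholder for your own setting
    # ... run the training loop shown below for one epoch ...
    scheduler.step()
    # The two parameter groups now have different learning rates: model.mlp decays
    # according to lambda1, while multi_loss.noise_params stays at the base rate (lambda2).
    print(epoch, [group['lr'] for group in optimizer.param_groups])
```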
The loss is then called in the training loop as
```python
...
# Individual losses, computed as usual.
loss_1 = cross_entropy(predictions_1, targets_1)
loss_2 = cross_entropy(predictions_2, targets_2)

# Combine them into a single scalar with the learned weights.
loss = multi_loss([loss_1, loss_2])

loss.backward()
optimizer.step()
...
```

In the original paper, regression tasks are weighted by $\frac{1}{2\sigma^2}$ and classification tasks by (approximately) $\frac{1}{\sigma^2}$, with an additional $\log \sigma$ term per task that regularizes the learned noise parameters.
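For one regression task with loss $\mathcal{L}_1$ and one classification task with loss $\mathcal{L}_2$, the combined objective from the paper is

$$
\mathcal{L}(W, \sigma_1, \sigma_2) \;=\; \frac{1}{2\sigma_1^2}\,\mathcal{L}_1(W) \;+\; \frac{1}{\sigma_2^2}\,\mathcal{L}_2(W) \;+\; \log \sigma_1 \;+\; \log \sigma_2 .
$$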