Since the gradient projection methods worked well, check out TorchJD for automatically balancing losses in a conflict-free way. It could be a clean way to scale up this approach.
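In case it helps, here is roughly what drop-in usage looks like. This is a minimal sketch based on my reading of the TorchJD README: the toy model, losses, and training loop are just illustrative, and the exact `torchjd.backward` signature may have changed, so double-check the current docs.

```python
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

import torchjd
from torchjd.aggregation import UPGrad  # conflict-free gradient aggregator

# Toy two-task setup: one shared trunk with a 2-dimensional output head.
model = Sequential(Linear(10, 5), ReLU(), Linear(5, 2))
optimizer = SGD(model.parameters(), lr=0.1)
loss_fn = MSELoss()
aggregator = UPGrad()

x = torch.randn(16, 10)
target1 = torch.randn(16)
target2 = torch.randn(16)

for _ in range(100):
    out = model(x)
    loss1 = loss_fn(out[:, 0], target1)
    loss2 = loss_fn(out[:, 1], target2)

    optimizer.zero_grad()
    # Instead of calling .backward() on a hand-weighted sum of losses,
    # aggregate the per-loss gradients so the combined update does not
    # conflict with any individual loss.
    torchjd.backward([loss1, loss2], aggregator)
    optimizer.step()
```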
Training becomes roughly 2× slower, but convergence is faster, and while it doesn't entirely eliminate the need for loss weightings, it helps substantially.
Gradient projection shows up as a single point rather than a curve, since it has no obvious hyperparameter to vary. TorchJD addresses this: it lets you explicitly vary the weighting along the Pareto front.
Nice work!