Open
Description
Our recipes are cluttered with logic that checks "if optim_in_bwd".
With a bit of engineering, we can make optimizer-in-backward a drop-in replacement for the regular optimizer and avoid code like this:
    if not self._optimizer_in_bwd:
        self._optimizer.zero_grad()
    else:
        for opt in self._optim_ckpt_wrapper.optim_map.values():
            opt.zero_grad()
That can be replaced with:
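    class MyOptWrapper:
        def __init__(self, optimizers):
            # `optimizers` is the existing wrapper that exposes `optim_map`
            self.optimizers = optimizers

        def zero_grad(self):
            # Fan the call out to every per-parameter optimizer
            for opt in self.optimizers.optim_map.values():
                opt.zero_grad()

    optimizer = MyOptWrapper(optimizers)
    optimizer.zero_grad()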
This may occasionally break things, but good test coverage should keep errors from reaching prod. For genuinely complex situations, e.g. checkpointing, we can keep an if/else, but we definitely don't need all 8 of the if/else checks that we have today.