
improve _optim_ckpt_wrapper so it is a drop in replacement of optimizer #2052

Open
@felipemello1

Description

Our recipes are cluttered with logic that checks "if optim_in_bwd".

With a bit of engineering, we can make _optim_ckpt_wrapper a drop-in replacement for the optimizer and avoid code like this:

if not self._optimizer_in_bwd:
    self._optimizer.zero_grad()
else:
    for opt in self._optim_ckpt_wrapper.optim_map.values():
        opt.zero_grad()

That can be replaced with:

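class MyOptWrapper:
    def __init__(self, optimizers):
        # `optimizers` is expected to expose `optim_map`, e.g. the existing _optim_ckpt_wrapper
        self.optimizers = optimizers

    def zero_grad(self):
        # Same call as a regular optimizer, fanned out to every wrapped optimizer
        for opt in self.optimizers.optim_map.values():
            opt.zero_grad()

optimizer = MyOptWrapper(optimizers)
optimizer.zero_grad()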

This may occasionally break things, but good testing should keep errors from reaching prod. For genuinely complex cases, e.g. checkpointing, we can still use if/else, but we definitely don't need all 8 of the if/else checks we have today.
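
To make the idea a bit more concrete, here is a rough sketch of how the same pattern could cover more of the optimizer interface than just zero_grad. Everything in it is an assumption for illustration: the names DropInOptimWrapper and FakeOptimizer are hypothetical, FakeOptimizer only stands in for a real torch.optim.Optimizer so the snippet runs without PyTorch, and treating step() as a no-op assumes the per-parameter optimizers already step inside backward hooks.

```python
class DropInOptimWrapper:
    """Hypothetical wrapper that fans optimizer calls out to per-parameter optimizers."""

    def __init__(self, optim_map):
        # optim_map: dict mapping a parameter name to an optimizer-like object
        self.optim_map = optim_map

    def zero_grad(self, set_to_none=True):
        for opt in self.optim_map.values():
            opt.zero_grad(set_to_none=set_to_none)

    def step(self):
        # Assumption: with optimizer-in-backward the updates already happened
        # during backward, so the recipe can call step() unconditionally.
        return None

    def state_dict(self):
        return {name: opt.state_dict() for name, opt in self.optim_map.items()}

    def load_state_dict(self, state_dict):
        for name, opt in self.optim_map.items():
            opt.load_state_dict(state_dict[name])


class FakeOptimizer:
    """Stand-in for a real optimizer, only here to make the sketch runnable."""

    def __init__(self):
        self.zero_grad_calls = 0
        self.state = {}

    def zero_grad(self, set_to_none=True):
        self.zero_grad_calls += 1

    def state_dict(self):
        return dict(self.state)

    def load_state_dict(self, state_dict):
        self.state = dict(state_dict)


# The recipe code no longer needs to know which kind of optimizer it holds.
optimizer = DropInOptimWrapper(
    {"layer.weight": FakeOptimizer(), "layer.bias": FakeOptimizer()}
)
optimizer.zero_grad()
optimizer.step()
```

With something along these lines, the training loop could call zero_grad() and step() the same way in both modes, and the remaining if/else checks would be limited to cases like checkpointing where the state layout really differs.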

Labels: best practice (Things we should be doing but aren't)
