Feat (ex/common): learned round refactor #1323
base: dev
Conversation
Is there a chance to split this into more manageable chunks?
Giuseppe5 left a comment:
First review
from brevitas_examples.common.learned_round.learned_round_method import MSELoss
from brevitas_examples.common.learned_round.learned_round_method import RegularisedMSELoss

OPTIMIZER_MAP = {
Maybe we should implement a way to register new and/or custom optimizers to this map?
I opted for removing this dict, and having the optimizers retrieved from either torch.optim or brevitas.optim.
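For illustration, a minimal sketch of how a by-name lookup over torch.optim with a brevitas.optim fallback could work; the helper name resolve_optimizer_class is hypothetical, and the same pattern would extend to LR schedulers via torch.optim.lr_scheduler:

```python
from typing import Type

import torch.optim

import brevitas.optim  # assumed to expose custom optimizers such as SignSGD


def resolve_optimizer_class(name: str) -> Type[torch.optim.Optimizer]:
    # Look the class name up in torch.optim first, then fall back to brevitas.optim.
    for module in (torch.optim, brevitas.optim):
        cls = getattr(module, name, None)
        if isinstance(cls, type) and issubclass(cls, torch.optim.Optimizer):
            return cls
    raise ValueError(f"Unknown optimizer: {name!r}")
```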
OPTIMIZER_MAP = {
    "sign_sgd": SignSGD,}
LR_SCHEDULER_MAP = {}
Same as above
I opted for removing this dict since, at the moment, there is no plan to introduce a custom one. Moreover, this could be registered in brevitas.optim.lr_scheduler and use logic similar to the one used for the optimizers.
@property
def param_fn(self) -> Callable[[nn.Module, OrderedDict, str], bool]:
    return {
A bit ugly; we can have a separate attribute/property with the dict, and this just returns self.pre_dict[self.value].
I created a Registry class in python_utils.py, and I use that for tracking/retrieving the different implementations.
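A minimal sketch of what such a registry could look like, assuming a decorator-based register/retrieve API; the actual class in python_utils.py may differ:

```python
class Registry:
    """Maps string keys to implementations so they can be looked up by name."""

    def __init__(self) -> None:
        self._registry = {}

    def register(self, name: str):
        # Intended to be used as a decorator, e.g. @LOSS_REGISTRY.register("mse").
        def decorator(obj):
            self._registry[name] = obj
            return obj

        return decorator

    def retrieve(self, name: str):
        if name not in self._registry:
            raise KeyError(f"No implementation registered under {name!r}")
        return self._registry[name]
```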
@property
def loss_class(self) -> Type[BlockLoss]:
    return {
As above
# Both `get_round_parameters` and `get_scale_parameters` are meant to be passed as the argument `get_target`
# of `_get_target_parameters`, which iterates over the modules of a model in a recursive function.
# In the case of `get_round_parameters` the return value indicates whether the submodules of a given module
You say "In the case of `get_round_parameters`" as if it were different for `get_scale_parameters`, but that is not really the case (plus/minus some very small differences). Overall the comment can be adjusted (and shortened).
I polished it a bit to make it clearer that it is an example.
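For context, a rough sketch of the recursion the comment describes, where get_target decides per module whether to record parameters and whether to keep descending; this is illustrative, not the exact implementation:

```python
from collections import OrderedDict
from typing import Callable

import torch.nn as nn


def _get_target_parameters(
        module: nn.Module,
        get_target: Callable[[nn.Module, OrderedDict, str], bool],
        state_dict: OrderedDict,
        prefix: str = "") -> OrderedDict:
    # get_target may record parameters of `module` into state_dict; its boolean
    # return value decides whether the children of `module` are visited too.
    if get_target(module, state_dict, prefix):
        for name, child in module.named_children():
            _get_target_parameters(child, get_target, state_dict, f"{prefix}{name}.")
    return state_dict
```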
raise StopFwdException

def get_blocks(model: nn.Module, block_check_fn: Callable[[nn.Module, str],
I believe we might have something similar to this for quantization
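For reference, a sketch of the behaviour implied by the signature above: collect every submodule for which block_check_fn returns True, without descending further into matched blocks. The real implementation may differ:

```python
from typing import Callable, List

import torch.nn as nn


def get_blocks(model: nn.Module,
               block_check_fn: Callable[[nn.Module, str], bool]) -> List[nn.Module]:
    blocks = []

    def _visit(module: nn.Module, prefix: str = "") -> None:
        for name, child in module.named_children():
            full_name = prefix + name
            if block_check_fn(child, full_name):
                # Treat the matching submodule as a block and stop descending.
                blocks.append(child)
            else:
                _visit(child, full_name + ".")

    _visit(model)
    return blocks
```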
optimizer.zero_grad()
for lr_scheduler in lr_schedulers:
    if lr_scheduler:
class LearnedRoundOptimizer:
This should be a Trainer class of some sort.
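A sketch of the kind of Trainer-style wrapper being suggested, owning the optimizers/schedulers and exposing a single step; the class and method names are hypothetical:

```python
class LearnedRoundTrainer:
    """Thin wrapper that owns the optimizers/schedulers and exposes one step."""

    def __init__(self, optimizers, lr_schedulers):
        self.optimizers = optimizers
        self.lr_schedulers = lr_schedulers

    def step(self, loss):
        for optimizer in self.optimizers:
            optimizer.zero_grad()
        loss.backward()
        for optimizer in self.optimizers:
            optimizer.step()
        # Schedulers are optional, hence the truthiness check as in the diff above.
        for lr_scheduler in self.lr_schedulers:
            if lr_scheduler:
                lr_scheduler.step()
```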
@dataclass
class OptimizerArgs:
Split optimizer/scheduler
I'm not sure the case is strong enough for that: an lr scheduler instance is semantically tied to an optimizer, so it is sensible for the dataclass hierarchy to reflect this.
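The nesting described above could look roughly like this; field names are illustrative, not the actual dataclasses:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class LRSchedulerArgs:
    lr_scheduler_class: str = "LinearLR"
    lr_scheduler_kwargs: Dict[str, Any] = field(default_factory=dict)


@dataclass
class OptimizerArgs:
    optimizer_class: str = "SignSGD"
    lr: float = 1e-3
    optimizer_kwargs: Dict[str, Any] = field(default_factory=dict)
    # Nested because a scheduler instance is always bound to one optimizer.
    lr_scheduler_args: Optional[LRSchedulerArgs] = None
```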
state_dict = self._get_target_parameters(model, optimizer_target.param_fn, state_dict)
return state_dict

def _create_optimizer_and_scheduler(
Split optimizer/scheduler
Same as above.
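A sketch of a combined factory consistent with that choice, reusing the hypothetical resolve_optimizer_class helper and OptimizerArgs dataclasses from the earlier sketches:

```python
import torch.optim


def _create_optimizer_and_scheduler(params, optimizer_args):
    # Resolve and build the optimizer.
    optimizer_cls = resolve_optimizer_class(optimizer_args.optimizer_class)
    optimizer = optimizer_cls(params, lr=optimizer_args.lr, **optimizer_args.optimizer_kwargs)
    # The scheduler is optional and always attached to the optimizer just created.
    lr_scheduler_args = optimizer_args.lr_scheduler_args
    lr_scheduler = (
        getattr(torch.optim.lr_scheduler, lr_scheduler_args.lr_scheduler_class)(
            optimizer, **lr_scheduler_args.lr_scheduler_kwargs)
        if lr_scheduler_args is not None else None)
    return optimizer, lr_scheduler
```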
if lr_scheduler_args is not None else None)
return optimizer, lr_scheduler

def _create_optimizers_lr_schedulers(
The name of the function is confusingly similar to the one above.
I opted for removing it after the changes to the optimizer instantiation.
TBA