Conversation

@pablomlago (Collaborator):

TBA

@pablomlago pablomlago force-pushed the feat-round-refactor branch from 24a1c36 to 224d626 Compare June 24, 2025 08:39
@nickfraser nickfraser mentioned this pull request Sep 29, 2025
@pablomlago pablomlago marked this pull request as ready for review September 30, 2025 13:51
@nickfraser nickfraser added the next release PRs which should be merged for the next release label Oct 13, 2025

@Giuseppe5 (Collaborator):

Is there a chance to split this into more manageable chunks?

@nickfraser nickfraser requested a review from Giuseppe5 December 1, 2025 14:37

@Giuseppe5 Giuseppe5 left a comment (Collaborator)

First review

from brevitas_examples.common.learned_round.learned_round_method import MSELoss
from brevitas_examples.common.learned_round.learned_round_method import RegularisedMSELoss

OPTIMIZER_MAP = {

@Giuseppe5 (Collaborator):

Maybe we should implement a way to register new and/or custom optimizers to this map?

@pablomlago (Collaborator, Author):

I opted to remove this dict and have the optimizers retrieved from either torch.optim or brevitas.optim.
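
For illustration, a minimal sketch of what such a name-based lookup could look like (the helper resolve_optimizer_cls and the lookup order are assumptions, not the actual implementation in this PR):

import importlib

def resolve_optimizer_cls(name: str):
    # Hypothetical helper: look the optimizer class up by name, first in
    # torch.optim, then in brevitas.optim (exact module layout is an assumption).
    for module_name in ("torch.optim", "brevitas.optim"):
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            continue
        if hasattr(module, name):
            return getattr(module, name)
    raise ValueError(f"Unknown optimizer: {name}")

# e.g. resolve_optimizer_cls("Adam") -> torch.optim.Adam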


OPTIMIZER_MAP = {
    "sign_sgd": SignSGD,}
LR_SCHEDULER_MAP = {}

@Giuseppe5 (Collaborator):

Same as above

@pablomlago (Collaborator, Author):

I opted to remove this dict since, at the moment, there is no plan to introduce a custom one. Moreover, this could be registered in brevitas.optim.lr_scheduler and retrieved with logic similar to the one used for the optimizers.
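
A similar sketch for schedulers, looking them up in torch.optim.lr_scheduler (resolve_lr_scheduler_cls is a hypothetical name; a brevitas.optim.lr_scheduler module could be checked the same way if it is ever added):

import torch.optim.lr_scheduler as lr_scheduler

def resolve_lr_scheduler_cls(name: str):
    # Hypothetical helper mirroring the optimizer lookup above.
    if hasattr(lr_scheduler, name):
        return getattr(lr_scheduler, name)
    raise ValueError(f"Unknown LR scheduler: {name}")

# e.g. resolve_lr_scheduler_cls("StepLR") -> torch.optim.lr_scheduler.StepLR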


@property
def param_fn(self) -> Callable[[nn.Module, OrderedDict, str], bool]:
    return {

@Giuseppe5 (Collaborator):

Bit ugly; we can have a separate attribute/property with the dict, and this just returns self.pre_dict[self.value].
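
For reference, the suggested pattern sketched on a hypothetical enum (OptimizerTarget, the map property, and the placeholder functions below are illustrative, not the actual code):

from enum import Enum

def get_round_parameters(module, state_dict, name):
    ...  # placeholder for the real parameter-collection logic

def get_scale_parameters(module, state_dict, name):
    ...  # placeholder

class OptimizerTarget(Enum):
    ROUND = "round"
    SCALE = "scale"

    @property
    def _param_fn_map(self):
        # Separate property holding the lookup dict, as suggested.
        return {"round": get_round_parameters, "scale": get_scale_parameters}

    @property
    def param_fn(self):
        # The public property just indexes the dict with the enum value.
        return self._param_fn_map[self.value]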

@pablomlago (Collaborator, Author):

I created a Registry class in python_utils.py, and I use that for tracking/retrieving the different implementations.
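
A minimal sketch of what such a registry could look like (the actual Registry class in python_utils.py may differ):

class Registry:
    """Hypothetical minimal registry mapping string keys to implementations."""

    def __init__(self):
        self._entries = {}

    def register(self, name):
        # Intended to be used as a decorator.
        def wrapper(fn):
            self._entries[name] = fn
            return fn
        return wrapper

    def __getitem__(self, name):
        return self._entries[name]

PARAM_FN_REGISTRY = Registry()

@PARAM_FN_REGISTRY.register("round")
def get_round_parameters(module, state_dict, name):
    ...  # placeholder

# PARAM_FN_REGISTRY["round"] -> get_round_parameters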


@property
def loss_class(self) -> Type[BlockLoss]:
    return {

@Giuseppe5 (Collaborator):

As above


# Both `get_round_parameters` and `get_scale_parameters` are meant to be passed as the argument `get_target`
# of `_get_target_parameters`, which iterates over the modules of a model in a recursive function.
# In the case of `get_round_parameters` the return value indicates whether the submodules of a given module

@Giuseppe5 (Collaborator):

You say:

In the case of get_round_parameters

as if it were different for get_scale_parameters, but that is not really the case (plus or minus some very small differences). Overall the comment can be adjusted (and shortened).

@pablomlago (Collaborator, Author):

I polished it a bit to make it clearer that it is an example.

raise StopFwdException


def get_blocks(model: nn.Module, block_check_fn: Callable[[nn.Module, str],

@Giuseppe5 (Collaborator):

I believe we might have something similar to this for quantization
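
For context, a sketch of how such a block-collection helper typically works (the actual signature, return type, and traversal details in the PR may differ):

from typing import Callable, List

import torch.nn as nn

def get_blocks(model: nn.Module,
               block_check_fn: Callable[[nn.Module, str], bool]) -> List[nn.Module]:
    # Collect every submodule for which the predicate returns True, without
    # recursing into modules that were already identified as blocks.
    blocks = []

    def _visit(module: nn.Module, prefix: str) -> None:
        for name, child in module.named_children():
            full_name = f"{prefix}.{name}" if prefix else name
            if block_check_fn(child, full_name):
                blocks.append(child)
            else:
                _visit(child, full_name)

    _visit(model, "")
    return blocks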

optimizer.zero_grad()
for lr_scheduler in lr_schedulers:
    if lr_scheduler:
class LearnedRoundOptimizer:

@Giuseppe5 (Collaborator):

This should be a Trainer class of some sort.
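
For reference, the kind of loop being discussed, as a standalone sketch rather than the actual LearnedRoundOptimizer code:

def training_step(loss, optimizers, lr_schedulers):
    # Generic sketch of one optimization step over several optimizers and
    # their (optional) schedulers; a Trainer class would wrap logic like this.
    for optimizer in optimizers:
        optimizer.zero_grad()
    loss.backward()
    for optimizer in optimizers:
        optimizer.step()
    for lr_scheduler in lr_schedulers:
        if lr_scheduler is not None:
            lr_scheduler.step()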



@dataclass
class OptimizerArgs:

@Giuseppe5 (Collaborator):

Split optimizer/scheduler

@pablomlago (Collaborator, Author):

I'm not sure the case is strong enough for that: an LR scheduler instance is semantically tied to an optimizer, so it is sensible for the dataclass hierarchy to reflect this.
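
The nesting described above, sketched with hypothetical field names (only lr_scheduler_args appears in the surrounding diff; the rest is illustrative):

from dataclasses import dataclass, field
from typing import Any, Dict, Optional

@dataclass
class LRSchedulerArgs:
    lr_scheduler_class: str = "StepLR"
    lr_scheduler_kwargs: Dict[str, Any] = field(default_factory=dict)

@dataclass
class OptimizerArgs:
    optimizer_class: str = "SignSGD"
    lr: float = 1e-3
    # The scheduler args are nested under the optimizer they are tied to.
    lr_scheduler_args: Optional[LRSchedulerArgs] = None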

state_dict = self._get_target_parameters(model, optimizer_target.param_fn, state_dict)
return state_dict

def _create_optimizer_and_scheduler(

@Giuseppe5 (Collaborator):

split optimizer/scheduler

@pablomlago (Collaborator, Author):

Same as above.

if lr_scheduler_args is not None else None)
return optimizer, lr_scheduler

def _create_optimizers_lr_schedulers(

@Giuseppe5 (Collaborator):

The name of the function is confusingly similar to the one above.

@pablomlago (Collaborator, Author):

I opted to remove it after the changes to the optimizer instantiation.
