5 changes: 5 additions & 0 deletions src/brevitas/optim/__init__.py
@@ -0,0 +1,5 @@
# Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from .cailey_sgd import CaileySGD
from .sign_sgd import SignSGD
97 changes: 97 additions & 0 deletions src/brevitas/utils/python_utils.py
@@ -2,8 +2,19 @@
# SPDX-License-Identifier: BSD-3-Clause

from contextlib import contextmanager
from dataclasses import is_dataclass
from enum import Enum
import functools
import json
from typing import Any
from typing import Callable
from typing import Dict
from typing import Generic
from typing import Iterable
from typing import List
from typing import Optional
from typing import TypeVar
from typing import Union


class AutoName(str, Enum):
@@ -64,3 +75,89 @@ def run(*args, **kwargs):
return function(*args, **kwargs)

return run


def convert_str_dict(passed_value: Dict) -> Dict:
"Safely checks that a passed value is a dictionary and converts any string values to their appropriate types."
for key, value in passed_value.items():
if isinstance(value, dict):
passed_value[key] = convert_str_dict(value)
elif isinstance(value, str):
# First check for bool and convert
if value.lower() in ("true", "false"):
passed_value[key] = value.lower() == "true"
# Check for digit
elif value.isdigit():
passed_value[key] = int(value)
elif value.replace(".", "", 1).isdigit():
passed_value[key] = float(value)

return passed_value
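

A minimal usage sketch of convert_str_dict (illustrative only, assuming the helper is imported from brevitas.utils.python_utils):

from brevitas.utils.python_utils import convert_str_dict

# String values coming from the CLI are converted in place where possible
cli_kwargs = {"use_amp": "true", "iters": "200", "scale": "0.5", "name": "adam"}
convert_str_dict(cli_kwargs)
# -> {"use_amp": True, "iters": 200, "scale": 0.5, "name": "adam"}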


def parse_dataclass_dicts(data_cls: Any, dict_attributes: List[str]) -> None:
"""
Parses the strings in `dict_attributes` of dataclass `data_cls` to dictionaries.
"""
assert is_dataclass(data_cls), f"data_cls must be a dataclass, but got {type(data_cls)}"
for attr in dict_attributes:
if not hasattr(data_cls, attr):
raise ValueError(f"Dataclass {type(data_cls).__name__} has no attribute named {attr}")
kwargs = getattr(data_cls, attr)

if kwargs is None:
kwargs = {}
elif isinstance(kwargs, str):
# Parse args that could be a `dict` sent in from the CLI as a string
kwargs = json.loads(kwargs)
# Convert str values to types if applicable
kwargs = convert_str_dict(kwargs)
elif isinstance(kwargs, dict):
pass
else:
# Raise an error if the attribute cannot be parsed into a dictionary
raise ValueError(
f"Value set for attribute {attr} of dataclass {type(data_cls).__name__} cannot be converted into a dictionary."
)
# Set the updated value
setattr(data_cls, attr, kwargs)
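

A minimal sketch of how parse_dataclass_dicts is meant to be used (illustrative; ExampleArgs is a hypothetical dataclass):

from dataclasses import dataclass
from typing import Dict, Optional, Union

from brevitas.utils.python_utils import parse_dataclass_dicts


@dataclass
class ExampleArgs:
    extra_kwargs: Optional[Union[Dict, str]] = None


args = ExampleArgs(extra_kwargs='{"momentum": "0.9", "nesterov": "true"}')
parse_dataclass_dicts(args, ["extra_kwargs"])
# args.extra_kwargs is now {"momentum": 0.9, "nesterov": True}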


T = TypeVar("T")


class Registry(Generic[T]):

def __init__(self, registry_name: Optional[str] = None) -> None:
self._registry_name = registry_name
self._registry: Dict[str, T] = {}

@property
def registry_name(self) -> str:
return "registry" if self._registry_name is None else self._registry_name

def register(self, names: Union[str, List[str]]) -> Callable[[T], T]:
if isinstance(names, str):
names = [names]

def decorator(value: T) -> T:
# Allow registering the same value to multiple keys
for name in names:
if name in self._registry:
raise ValueError(f"'{name}' is already registered in {self.registry_name}.")
self._registry[name] = value
return value

return decorator

def get_registered_keys(self) -> Iterable[str]:
return self._registry.keys()

def get(self, name: str) -> T:
try:
return self._registry[name]
except KeyError:
available = ", ".join(sorted(self._registry)) or "<empty>"
raise ValueError(
f"'{name}' not found in {self.registry_name}. The available values are: {available}"
)
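

A minimal usage sketch of Registry (illustrative; the registry instance and the MSELoss class are hypothetical):

from brevitas.utils.python_utils import Registry

LOSS_REGISTRY: Registry[type] = Registry("loss registry")


# The same value can be registered under several keys
@LOSS_REGISTRY.register(["mse", "l2"])
class MSELoss:
    pass


assert LOSS_REGISTRY.get("l2") is MSELoss
LOSS_REGISTRY.get("huber")  # raises ValueError listing the available keys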
248 changes: 248 additions & 0 deletions src/brevitas_examples/common/learned_round/learned_round_args.py
@@ -0,0 +1,248 @@
# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from dataclasses import dataclass
from dataclasses import field
from typing import Dict
from typing import List
from typing import Optional
from typing import Type
from typing import Union
import warnings

import torch
from torch.optim.optimizer import Optimizer

from brevitas import optim
from brevitas.inject.enum import LearnedRoundImplType
from brevitas.utils.python_utils import parse_dataclass_dicts
from brevitas_examples.common.learned_round.learned_round_method import BLOCK_LOSS_REGISTRY
from brevitas_examples.common.learned_round.learned_round_method import BlockLoss
from brevitas_examples.common.learned_round.learned_round_method import TARGET_PARAM_FN_REGISTRY
from brevitas_examples.common.learned_round.learned_round_method import TargetParamFn

OPTIMIZER_NAMESPACES = [torch.optim, optim]
LR_SCHEDULER_NAMESPACES = [torch.optim.lr_scheduler]


def _parse_optimizer_class(optimizer_str: str) -> Type[Optimizer]:
optimizer_namespace_keys = []
for namespace in OPTIMIZER_NAMESPACES:
optimizer_namespace_keys += [
(namespace, optimizer_key)
for optimizer_key in namespace.__dict__.keys()
# Make sure that only valid optimizer implementations are
# retrieved when matching the string passed by the user
if (
# Verify that the key matches the one passed by the user (case-insensitive)
optimizer_key.lower() == optimizer_str.lower() and
# Verify that key corresponds to a class
isinstance(namespace.__dict__[optimizer_key], type) and
# Make sure the abstract class is not used
optimizer_key != "Optimizer" and
# An optimizer implements zero_grad and step. Check that this
# is the case for the class retrieved from torch.optim
hasattr(namespace.__dict__[optimizer_key], 'step') and
callable(namespace.__dict__[optimizer_key].step) and
hasattr(namespace.__dict__[optimizer_key], 'zero_grad') and
callable(namespace.__dict__[optimizer_key].zero_grad))]

if len(optimizer_namespace_keys) == 0:
raise ValueError(
f"{optimizer_str} is not a valid optimizer in namespaces {[_namespace.__name__ for _namespace in OPTIMIZER_NAMESPACES]}."
)

namespace, optimizer_name = optimizer_namespace_keys[0]
if len(optimizer_namespace_keys) > 1:
warnings.warn(
f"There are multiple potential matches for optimizer {optimizer_str} ({[_optimizer_name for _, _optimizer_name in optimizer_namespace_keys]}). "
f"Defaulting to {optimizer_name} from {namespace.__name__}.")

optimizer_class = getattr(namespace, optimizer_name)
return optimizer_class
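

Illustrative calls of the helper above (a sketch, using the module-level names defined in this file):

_parse_optimizer_class("sgd")      # -> torch.optim.SGD
_parse_optimizer_class("signsgd")  # -> brevitas.optim.SignSGD
_parse_optimizer_class("unknown")  # -> ValueError listing the searched namespaces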


def _parse_lr_scheduler_class(lr_scheduler_str: str) -> Type:
lr_scheduler_namespace_keys = []
for namespace in LR_SCHEDULER_NAMESPACES:
lr_scheduler_namespace_keys += [
(namespace, lr_scheduler_key)
for lr_scheduler_key in namespace.__dict__.keys()
# Make sure that only valid LRScheduler implementations are
# retrieved when matching the string passed by the user
if
((
lr_scheduler_key.lower() == lr_scheduler_str.lower() or
lr_scheduler_key.lower() == lr_scheduler_str.lower() + "lr") and
# Verify that key corresponds to a class
isinstance(namespace.__dict__[lr_scheduler_key], type) and
# Make sure the abstract class is not retrieved
lr_scheduler_key != "LRScheduler" and
# A learning rate scheduler implements step. Check that this
# is the case for the class retrieved from the namespace
hasattr(namespace.__dict__[lr_scheduler_key], 'step') and
callable(namespace.__dict__[lr_scheduler_key].step))]

if len(lr_scheduler_namespace_keys) == 0:
raise ValueError(
f"{lr_scheduler_str} is not a valid lr scheduler in namespaces {[_namespace.__name__ for _namespace in LR_SCHEDULER_NAMESPACES]}."
)

namespace, lr_scheduler_name = lr_scheduler_namespace_keys[0]
if len(lr_scheduler_namespace_keys) > 1:
warnings.warn(
f"There are multiple potential matches for lr scheduler {lr_scheduler_str} ({[_lr_scheduler_name for _, _lr_scheduler_name in lr_scheduler_namespace_keys]}). "
f"Defaulting to {lr_scheduler_name} from {namespace.__name__}.")

lr_scheduler_class = getattr(namespace, lr_scheduler_name)
return lr_scheduler_class
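

Illustrative calls (a sketch): the extra "lr" suffix match lets short names resolve to the usual PyTorch scheduler classes:

_parse_lr_scheduler_class("linear")            # -> torch.optim.lr_scheduler.LinearLR (via the "linear" + "lr" match)
_parse_lr_scheduler_class("cosineannealinglr")  # -> torch.optim.lr_scheduler.CosineAnnealingLR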


@dataclass
class LRSchedulerArgs:
lr_scheduler_cls: Union[str, Type] = field(
default="linear",
metadata={"help": "The learning rate scheduler to use."},
)
lr_scheduler_kwargs: Optional[Union[Dict, str]] = field(
default=None,
metadata={"help": ("Extra keyword arguments for the learning rate "
"scheduler.")},
)

# The attributes in _DICT_ATTRIBUTES are parsed to dictionaries.
_DICT_ATTRIBUTES = ["lr_scheduler_kwargs"]

def __post_init__(self) -> None:
# Parse args that could be a `dict` sent in from the CLI as a string
parse_dataclass_dicts(self, self._DICT_ATTRIBUTES)
# Parse string to learning rate scheduler class if needed
self.lr_scheduler_cls = (
_parse_lr_scheduler_class(self.lr_scheduler_cls) if isinstance(
self.lr_scheduler_cls, str) else self.lr_scheduler_cls)


@dataclass
class OptimizerArgs:
Collaborator: Split optimizer/scheduler.

Collaborator (Author): I'm not sure the case for that is strong enough: a learning rate scheduler instance is semantically tied to an optimizer, so it is sensible for the dataclass hierarchy to reflect this.

target_params: Union[str, TargetParamFn] = field(
metadata={
"help": ("Targets to be optimized."),
"choices": TARGET_PARAM_FN_REGISTRY.get_registered_keys(),})
optimizer_cls: Union[str, Type[Optimizer]] = field(
default="adam",
metadata={"help": "The optimizer to use."},
)
lr: float = field(
default=1e-3,
metadata={"help": "Initial learning rate for the optimizer."},
)
optimizer_kwargs: Optional[Union[Dict, str]] = field(
default=None,
metadata={"help": "Extra keyword arguments for the optimizer."},
)
lr_scheduler_args: Optional[LRSchedulerArgs] = field(
default=None,
metadata={
"help": ("Hyperparameters of learning rate scheduler for the selected"
"optimizer.")},
)

_DICT_ATTRIBUTES = ["optimizer_kwargs"]

def __post_init__(self) -> None:
# Parse args that could be `dict` sent in from the CLI as a string
parse_dataclass_dicts(self, self._DICT_ATTRIBUTES)
# Parse optimizer name to class
self.optimizer_cls = (
_parse_optimizer_class(self.optimizer_cls)
if isinstance(self.optimizer_cls, str) else self.optimizer_cls)
# Initialize the target parametrizations
self.target_params = (
TARGET_PARAM_FN_REGISTRY.get(self.target_params)
if isinstance(self.target_params, str) else self.target_params)
if self.lr < 0:
raise ValueError(f"Expected a positive learning rate but {self.lr} was passed.")


@dataclass
class TrainingArgs:
optimizers_args: List[OptimizerArgs] = field(
metadata={"help": ("Hyperparameters of the optimizers to use during training.")})
batch_size: int = field(default=8, metadata={"help": "Batch size per GPU for training."})
iters: int = field(default=200, metadata={"help": "Number of training iterations."})
loss_cls: Union[str, Type[BlockLoss]] = field(
default="mse",
metadata={
"help": "Class of the loss to be used for rounding optimization.",
"choices": BLOCK_LOSS_REGISTRY.get_registered_keys()})
loss_kwargs: Optional[Union[Dict, str]] = field(
default=None,
metadata={"help": "Extra keyword arguments for the learned round loss."},
)
loss_scaling_factor: float = field(
default=1.,
metadata={"help": "Scaling factor for the loss."},
)
use_best_model: bool = field(
default=True,
metadata={
"help":
("Whether to use the best setting of the learned round found "
"during training.")})
use_amp: bool = field(
default=True,
Collaborator: Are we sure we want this to default to True?

Collaborator (Author): This was the previous default value, so we might want to leave it as is for backward compatibility.

metadata={"help": "Whether to train using PyTorch Automatic Mixed Precision."})
amp_dtype: Union[str, torch.dtype] = field(
default=torch.float16,
Collaborator: Are we sure we want this to have a default? I'd have the user specify it to make sure they know what they're doing.

Collaborator (Author): Same as above.

metadata={
"choices": ["float16", "bfloat16"], "help": "Dtype for mixed-precision training."})

_DICT_ATTRIBUTES = ["loss_kwargs"]

def __post_init__(self) -> None:
# Parse args that could be a `dict` sent in from the CLI as a string
parse_dataclass_dicts(self, self._DICT_ATTRIBUTES)

for optimizer_args in self.optimizers_args:
# Check if the optimizer has an attached learning rate scheduler
if optimizer_args.lr_scheduler_args is not None:
optimizer_args.lr_scheduler_args.lr_scheduler_kwargs["total_iters"] = self.iters
# Parse amp_dtype
self.amp_dtype = getattr(torch, self.amp_dtype) if isinstance(
self.amp_dtype, str) else self.amp_dtype
# Retrieve loss
self.loss_cls = (
BLOCK_LOSS_REGISTRY.get(self.loss_cls)
if isinstance(self.loss_cls, str) else self.loss_cls)
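

A sketch of how TrainingArgs ties things together (illustrative, reusing the OptimizerArgs sketch above): note that __post_init__ propagates the iteration count to each attached scheduler.

train_args = TrainingArgs(
    optimizers_args=[opt_args],
    iters=500,
    loss_cls="mse",         # resolved through BLOCK_LOSS_REGISTRY
    amp_dtype="bfloat16",   # resolved to torch.bfloat16
)
# opt_args.lr_scheduler_args.lr_scheduler_kwargs["total_iters"] == 500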


@dataclass
class LearnedRoundArgs:
learned_round_param: Union[str, LearnedRoundImplType] = field(
default="identity",
metadata={
"help": "Defines the functional form of the learned round parametrization.",
"choices": [param.value.lower() for param in LearnedRoundImplType]})
learned_round_kwargs: Optional[Union[Dict, str]] = field(
default=None,
metadata={"help": "Extra keyword arguments for the learned round parametrization."},
)
fast_update: bool = field(
default=True, metadata={"help": ("Whether to use fast update with learned round.")})
Collaborator: Add a line to say what the drawback is, if any. If there are none, why would anyone ever not use it?

Collaborator (Author): The drawback is that it requires implementing extra methods in the Cache, and it is not valid for every topology.


_DICT_ATTRIBUTES = ["learned_round_kwargs"]

def __post_init__(self) -> None:
# Parse args that could be a `dict` sent in from the CLI as a string
parse_dataclass_dicts(self, self._DICT_ATTRIBUTES)

self.learned_round_param = LearnedRoundImplType(
self.learned_round_param.upper()) if isinstance(
self.learned_round_param, str) else self.learned_round_param


@dataclass
class Config:
learned_round_args: LearnedRoundArgs = field(
metadata={"help": "Learned round parametrization."})
training_args: TrainingArgs = field(metadata={"help": "Hyperparameters for optimization."})
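

An end-to-end sketch of building a Config (illustrative, reusing the sketches above):

config = Config(
    learned_round_args=LearnedRoundArgs(learned_round_param="identity"),
    training_args=train_args,
)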