Skip to content

Commit 24a1c36

Browse files
committed
Minor fixes
1 parent d6c9aba commit 24a1c36

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

src/brevitas_examples/common/learned_round/learned_round_args.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@ def param_fn(self) -> Callable[[nn.Module, OrderedDict, str], bool]:
3232
TargetParametrizations.LEARNED_ROUND: get_round_parameters}[self]
3333

3434

35-
class BlockLoss(Enum):
35+
class BlockLossType(Enum):
3636
MSE = "mse"
3737
REGULARISED_MSE = "regularised_mse"
3838

3939
@property
4040
def loss_class(self) -> Type[BlockLoss]:
41-
return {BlockLoss.MSE: MSELoss, BlockLoss.REGULARISED_MSE: RegularisedMSELoss}[self]
41+
return {BlockLossType.MSE: MSELoss, BlockLossType.REGULARISED_MSE: RegularisedMSELoss}[self]
4242

4343

4444
# TODO: Decide whether it is worth grouping the get_target_parameters under a class
@@ -170,7 +170,10 @@ class TrainingArgs:
170170
batch_size: int = field(default=8, metadata={"help": "Batch size per GPU for training."})
171171
iters: int = field(default=200, metadata={"help": "Number of training iterations."})
172172
loss_cls: Union[str, Type[BlockLoss]] = field(
173-
default="mse", metadata={"help": "Class of the loss to be used for rounding optimization."})
173+
default="mse",
174+
metadata={
175+
"help": "Class of the loss to be used for rounding optimization.",
176+
"choices": [block_loss_type.value for block_loss_type in BlockLossType]})
174177
loss_kwargs: Optional[Union[Dict, str]] = field(
175178
default=None,
176179
metadata={"help": "Extra keyword arguments for the learned round loss."},
@@ -213,7 +216,7 @@ def __post_init__(self) -> None:
213216
self.amp_dtype, str) else self.amp_dtype
214217
# Retrieve loss
215218
self.loss_cls = (
216-
BlockLoss(self.loss_cls).loss_class
219+
BlockLossType(self.loss_cls).loss_class
217220
if isinstance(self.loss_cls, str) else self.loss_cls)
218221

219222

src/brevitas_examples/common/learned_round/learned_round_optimizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,7 @@ def _training_loop(
521521

522522
scaler = None
523523
if self.config.training_args.use_amp:
524-
scaler = GradScaler(device_type="cuda" if torch.cuda.is_available() else "cpu")
524+
scaler = GradScaler(device="cuda" if torch.cuda.is_available() else "cpu")
525525

526526
# Dictionary to store the rounding parameters yielding the lowest
527527
# training loss

0 commit comments

Comments
 (0)