@@ -32,13 +32,13 @@ def param_fn(self) -> Callable[[nn.Module, OrderedDict, str], bool]:
         TargetParametrizations.LEARNED_ROUND: get_round_parameters}[self]
 
 
-class BlockLoss(Enum):
+class BlockLossType(Enum):
     MSE = "mse"
     REGULARISED_MSE = "regularised_mse"
 
     @property
     def loss_class(self) -> Type[BlockLoss]:
-        return {BlockLoss.MSE: MSELoss, BlockLoss.REGULARISED_MSE: RegularisedMSELoss}[self]
+        return {BlockLossType.MSE: MSELoss, BlockLossType.REGULARISED_MSE: RegularisedMSELoss}[self]
 
 
 # TODO: Decide whether it is worth grouping the get_target_parameters under a class
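For context on the rename, here is a minimal, self-contained sketch of how the enum maps a string alias to a loss class. `MSELoss` and `RegularisedMSELoss` are stand-ins for this module's real loss classes; only the enum itself mirrors the diff.

```python
from enum import Enum
from typing import Type


class MSELoss:
    """Stand-in for the module's real MSE block loss."""


class RegularisedMSELoss:
    """Stand-in for the module's real regularised MSE block loss."""


class BlockLossType(Enum):
    MSE = "mse"
    REGULARISED_MSE = "regularised_mse"

    @property
    def loss_class(self) -> Type:
        # Each member resolves to its concrete loss implementation.
        return {
            BlockLossType.MSE: MSELoss,
            BlockLossType.REGULARISED_MSE: RegularisedMSELoss,
        }[self]


# A string alias round-trips to the class via the enum:
assert BlockLossType("mse").loss_class is MSELoss
```

Renaming the enum to `BlockLossType` also avoids shadowing the `BlockLoss` base class that the `loss_class` property's return annotation refers to.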
@@ -170,7 +170,10 @@ class TrainingArgs:
     batch_size: int = field(default=8, metadata={"help": "Batch size per GPU for training."})
     iters: int = field(default=200, metadata={"help": "Number of training iterations."})
     loss_cls: Union[str, Type[BlockLoss]] = field(
-        default="mse", metadata={"help": "Class of the loss to be used for rounding optimization."})
+        default="mse",
+        metadata={
+            "help": "Class of the loss to be used for rounding optimization.",
+            "choices": [block_loss_type.value for block_loss_type in BlockLossType]})
     loss_kwargs: Optional[Union[Dict, str]] = field(
         default=None,
         metadata={"help": "Extra keyword arguments for the learned round loss."},
@@ -213,7 +216,7 @@ def __post_init__(self) -> None:
             self.amp_dtype, str) else self.amp_dtype
         # Retrieve loss
         self.loss_cls = (
-            BlockLoss(self.loss_cls).loss_class
+            BlockLossType(self.loss_cls).loss_class
             if isinstance(self.loss_cls, str) else self.loss_cls)