@@ -245,20 +245,24 @@ def create_weights(
             layer.register_parameter("input_scale", None)
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        # Block quant doesn't need to process weights after loading
+        # TODO(rob): refactor block quant into separate class.
         if self.block_quant:
+            assert self.quant_config.activation_scheme == "dynamic"
             if current_platform.is_rocm():
-                weight, weight_scale, _ = \
+                weight, weight_scale_inv, _ = \
                     normalize_e4m3fn_to_e4m3fnuz(
                         weight=layer.weight,
-                        weight_scale=layer.weight_scale_inv,
-                        input_scale=layer.input_scale)
-                layer.weight = Parameter(weight, requires_grad=False)
-                layer.weight_scale_inv = Parameter(weight_scale,
-                                                   requires_grad=False)
+                        weight_scale=layer.weight_scale_inv)
+            else:
+                weight = layer.weight.data
+                weight_scale_inv = layer.weight_scale_inv.data
+
+            # Torch.compile cannot use Parameter subclasses.
+            layer.weight = Parameter(weight, requires_grad=False)
+            layer.weight_scale_inv = Parameter(weight_scale_inv,
+                                               requires_grad=False)
             return
-        layer.weight = torch.nn.Parameter(layer.weight.data,
-                                          requires_grad=False)
+
         # If checkpoint not serialized fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
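Note on the pattern above: the new code pulls the raw tensor out of whatever Parameter subclass the weight loader registered and rewraps it as a plain, frozen torch.nn.Parameter, because torch.compile cannot trace Parameter subclasses. A minimal standalone sketch of that step (the helper name below is illustrative, not part of this diff):

    import torch
    from torch.nn import Parameter

    def _replace_with_plain_parameter(layer: torch.nn.Module, name: str) -> None:
        # Detach the underlying tensor from the (possibly subclassed)
        # Parameter and re-register it as a plain frozen Parameter so that
        # torch.compile can trace the layer after weight loading.
        tensor = getattr(layer, name).data
        setattr(layer, name, Parameter(tensor, requires_grad=False))

    # Roughly what the block-quant path above now does:
    # _replace_with_plain_parameter(layer, "weight")
    # _replace_with_plain_parameter(layer, "weight_scale_inv")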
@@ -507,8 +511,9 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
         layer.w2_input_scale = None
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        # Block quant doesn't need to process weights after loading
+        # TODO (rob): refactor block quant into separate class.
         if self.block_quant:
+            assert self.quant_config.activation_scheme == "dynamic"
             if current_platform.is_rocm():
                 w13_weight, w13_weight_scale_inv, w13_input_scale = \
                     normalize_e4m3fn_to_e4m3fnuz(
@@ -518,22 +523,21 @@ def process_weights_after_loading(self, layer: Module) -> None:
                 normalize_e4m3fn_to_e4m3fnuz(
                     layer.w2_weight, layer.w2_weight_scale_inv,
                     layer.w2_input_scale)
-                # Reset the parameter
-                layer.w13_weight = torch.nn.Parameter(w13_weight,
-                                                      requires_grad=False)
-                layer.w13_weight_scale_inv = torch.nn.Parameter(
-                    w13_weight_scale_inv, requires_grad=False)
-                if w13_input_scale is not None:
-                    layer.w13_input_scale = torch.nn.Parameter(
-                        w13_input_scale, requires_grad=False)
-                layer.w2_weight = torch.nn.Parameter(w2_weight,
-                                                     requires_grad=False)
-                layer.w2_weight_scale_inv = torch.nn.Parameter(
-                    w2_weight_scale_inv, requires_grad=False)
-                if w2_input_scale is not None:
-                    layer.w2_input_scale = torch.nn.Parameter(
-                        w2_input_scale, requires_grad=False)
+            else:
+                w13_weight = layer.w13_weight.data
+                w13_weight_scale_inv = layer.w13_weight_scale_inv.data
+                w2_weight = layer.w2_weight
+                w2_weight_scale_inv = layer.w2_weight_scale_inv
+
+            # torch.compile() cannot use Parameter subclasses.
+            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
+            layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv,
+                                                   requires_grad=False)
+            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
+            layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
+                                                  requires_grad=False)
             return
+
         # If checkpoint is fp16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If rocm, use float8_e4m3fnuz as dtype
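Both the linear and MoE paths call normalize_e4m3fn_to_e4m3fnuz on ROCm before rewrapping the tensors. Conceptually, this is a format translation from OCP float8_e4m3fn to AMD's float8_e4m3fnuz; a simplified sketch of the idea is below (the actual helper in the codebase also remaps the e4m3fn zero/NaN bit pattern and scales any input_scale, which this sketch omits):

    import torch

    def _to_e4m3fnuz(weight: torch.Tensor,
                     weight_scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # For the same bit pattern, a float8_e4m3fnuz value is half the
        # float8_e4m3fn value, so reinterpret the bytes and double the scale
        # to keep weight * scale numerically unchanged.
        weight_fnuz = weight.view(torch.int8).view(torch.float8_e4m3fnuz)
        return weight_fnuz, weight_scale * 2.0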