@@ -194,8 +194,7 @@ def _teacher_forced_flow_loss_forward(
194194 input_ids = self .fuse_traj_tokens (input_ids , traj_data_vlm )
195195 device = input_ids .device
196196
197- # Append <traj_future_start> so the expert attends through the full prompt
198- # that inference would have generated up to the action block.
197+ # Append <traj_future_start> so the expert attends through the full prompt.
199198 traj_future_start_id = self .tokenizer .convert_tokens_to_ids (
200199 to_special_token ("traj_future_start" )
201200 )
@@ -407,28 +406,21 @@ def quantize_model(model, args, tokenizer=None, calibration_forward_loop=None):
407406 quant_cfg = copy .deepcopy (mtq .NVFP4_DEFAULT_CFG )
408407 else :
409408 raise RuntimeError ("Unsupported quantization format" )
410- # Keep the entire vision tower in high precision. We must clear the NVFP4 quantizer
411- # *type* here, not merely disable it: the QuantConv3d in the vision patch-embed routes to
412- # a JIT-compiled implicit-GEMM CUDA kernel whenever its quantizers are NVFP4-typed (num_bits
413- # == (2, 1) with a dynamic block config) -- even when `enable=False`. That path requires
414- # CUDA_HOME (kernel compilation) and would also fake-quantize the vision weights we intend to
415- # leave untouched. Passing a non-NVFP4 cfg (num_bits=8) together with enable=False keeps these
416- # modules on the plain, unquantized forward path. Harmless for FP8 (already disabled there).
409+ # Keep the vision tower in high precision. Pass a non-NVFP4 cfg (num_bits=8) with
410+ # enable=False, not just enable=False: an NVFP4-typed QuantConv3d routes to a JIT
411+ # implicit-GEMM CUDA kernel (needs CUDA_HOME) even when disabled.
417412 quant_cfg ["quant_cfg" ].append (
418413 {"quantizer_name" : "*vlm.model.visual*" , "enable" : False , "cfg" : {"num_bits" : 8 }}
419414 )
420415
421- if args .quant_format == "nvfp4" :
422- # NVFP4 packs weights in blocks of 16 along the input (K) dimension. A Linear whose
423- # in_features is not a multiple of 16 gets K-padded when its weight is packed, and
424- # ModelOpt's packed-weight dequantize path cannot reshape the padded buffer back to the
425- # logical shape (it raises e.g. "shape '[512, 60]' is invalid for input of size 32768").
426- # Such layers also never satisfy the real-quant GEMM's K % 64 == 0 requirement, so they
427- # would only ever run on the (now-broken) dequantize fallback. Keep them in high precision.
428- # In AlpamayoR1 these are the small action-projection heads (e.g. the Fourier-feature
429- # encoder input), so the size/speed impact of leaving them unquantized is negligible.
416+ if args .quant_format == "nvfp4" or getattr (args , "real_quant" , False ):
417+ # Keep Linear layers whose in/out features aren't multiples of 16 in high precision:
418+ # they break the real-quant GEMM backends (NVFP4 block packing, FP8 torch._scaled_mm).
419+ # In AlpamayoR1 these are the small action-projection heads, so the impact is negligible.
430420 for _name , _module in model .named_modules ():
431- if isinstance (_module , torch .nn .Linear ) and _module .in_features % 16 != 0 :
421+ if isinstance (_module , torch .nn .Linear ) and (
422+ _module .in_features % 16 != 0 or _module .out_features % 16 != 0
423+ ):
432424 quant_cfg ["quant_cfg" ].append ({"quantizer_name" : f"{ _name } .*" , "enable" : False })
433425
434426 model = mtq .quantize (model , quant_cfg , forward_loop = calibrate_loop )
@@ -519,15 +511,24 @@ def loss_func(output, batch):
519511 print (f"[autoquant-loss] loss={ loss .item ():.6g} finite={ torch .isfinite (loss ).item ()} " )
520512 return loss
521513
522- # try:
514+ # Mirror the quantize_model exclusions via disabled_layers (fnmatch against module names),
515+ # since the AutoQuantize search also includes NVFP4: keep the vision tower unquantized, and
516+ # exclude Linear layers whose in/out features aren't multiples of 16.
517+ disabled_layers = ["*lm_head*" , "*vlm.model.visual*" ]
518+ for _name , _module in model .named_modules ():
519+ if isinstance (_module , torch .nn .Linear ) and (
520+ _module .in_features % 16 != 0 or _module .out_features % 16 != 0
521+ ):
522+ disabled_layers .append (_name )
523+
523524 model , search_state = mtq .auto_quantize (
524525 model ,
525526 constraints = {"effective_bits" : args .auto_quantize_bits },
526527 quantization_formats = ["NVFP4_DEFAULT_CFG" , "FP8_DEFAULT_CFG" ],
527528 data_loader = data_loader ,
528529 forward_step = forward_step ,
529530 loss_func = loss_func ,
530- disabled_layers = "*lm_head*" ,
531+ disabled_layers = disabled_layers ,
531532 verbose = True ,
532533 )
533534
@@ -565,13 +566,13 @@ def main():
565566 ap .add_argument (
566567 "--auto_quantize_bits" ,
567568 type = float ,
568- default = 4.8 ,
569+ default = 6.5 ,
569570 help = "Effective-bits budget for AutoQuantize (only used when --quantize auto)" ,
570571 )
571572 ap .add_argument (
572573 "--parquet" ,
573574 type = str ,
574- default = "1005_7cam_gold_eval_metadb_public .parquet" ,
575+ default = "0417_16rows_train_set_for_calibration_25.10 .parquet" ,
575576 help = "Parquet file with clip_ids for calibration" ,
576577 )
577578 ap .add_argument ("--t0_us" , type = int , default = 5_100_000 )
@@ -616,6 +617,7 @@ def main():
616617 weight_only = False ,
617618 debug = True ,
618619 auto_quantize_bits = args .auto_quantize_bits ,
620+ real_quant = args .real_quant ,
619621 )
620622 if args .quantize == "auto" :
621623 model = auto_quantize_model (
@@ -654,19 +656,13 @@ def main():
654656 print (f"Saving quantized checkpoint to { args .output_dir !r} ..." )
655657
656658 if args .real_quant :
657- # Real (packed) quantization. `mtq.compress` replaces the quantized linears with
658- # RealQuantLinear modules whose weights are packed into the low-precision storage
659- # format (NVFP4 = E2M1 nibbles + per-block FP8 scales) and enables ModelOpt's
660- # real-quant GEMM kernels, so inference runs on the hardware NVFP4 path rather than
661- # fake-quant fp16. We then save through the ModelOpt-patched `save_pretrained`, which
662- # writes the packed weights *and* a `modelopt_state.pth` recording the quantize +
663- # real_quantize modes (including the packed-tensor metadata/scales). Reloading via
664- # `AlpamayoR1.from_pretrained` with ModelOpt HF checkpointing enabled replays those
665- # modes and re-wraps the packed weights, so the checkpoint loads and runs real-quantized.
659+ # Real (packed) quantization. `mtq.compress` packs weights into the low-precision
660+ # storage format and enables ModelOpt's real-quant GEMM kernels. The ModelOpt-patched
661+ # `save_pretrained` writes the packed weights plus a `modelopt_state.pth`, which
662+ # `AlpamayoR1.from_pretrained` replays to reload and run real-quantized.
666663 #
667- # NOTE: `export_hf_checkpoint` (the unified vLLM/TRT-LLM deployment format) is
668- # intentionally not used here: that format has no `modelopt_state.pth`, so a custom
669- # model class like AlpamayoR1 cannot reload it through `from_pretrained`.
664+ # NOTE: `export_hf_checkpoint` (the vLLM/TRT-LLM deployment format) isn't used here: it
665+ # has no `modelopt_state.pth`, so a custom model class can't reload it via from_pretrained.
670666 mtq .compress (model )
671667 model .eval ()
672668 with torch .inference_mode ():
0 commit comments