@@ -410,7 +410,12 @@ def mse_calibrate(
410410 quant_func = partial (_mse_quant_func , quantizer = module ),
411411 )
412412
413- # Identify weight quantizers by checking if they have corresponding weight parameters
413+ # Collect weight quantizers (standard + fused-experts per-expert lists).
414+ try :
415+ from modelopt .torch .quantization .plugins .huggingface import _QuantFusedExperts as _qfe_cls
416+ except ImportError :
417+ _qfe_cls = None # type: ignore[misc]
418+
414419 name_to_module = dict (model .named_modules ())
415420 for parent_module in name_to_module .values ():
416421 if parent_module in seen_modules :
@@ -421,22 +426,56 @@ def mse_calibrate(
421426 if isinstance (weight_quantizer , TensorQuantizer ) and weight_quantizer .is_enabled :
422427 if getattr (weight_quantizer , "_calibrator" , None ) is not None :
423428 weight_quantizers .append ((parent_module , weight_name , weight_quantizer ))
424- # _QuantFusedExperts stores per-expert weight quantizers as nn.ModuleList named
425- # {param_name}_weight_quantizers (plural). Detect this pattern and enqueue each
426- # per-expert quantizer individually. The isinstance(qlist, nn.ModuleList) +
427- # isinstance(wq, TensorQuantizer) check below guards against false positives on
428- # unrelated modules that happen to have similarly-named attributes.
429- for param_name , _ in parent_module .named_parameters (recurse = False ):
430- qlist = getattr (parent_module , f"{ param_name } _weight_quantizers" , None )
431- if not isinstance (qlist , nn .ModuleList ):
432- continue
433- for expert_idx , wq in enumerate (qlist ):
434- if isinstance (wq , TensorQuantizer ) and wq .is_enabled :
435- if getattr (wq , "_calibrator" , None ) is not None :
436- weight_quantizers .append ((parent_module , (param_name , expert_idx ), wq ))
429+ # Enqueue per-expert quantizers from {param}_weight_quantizers ModuleLists.
430+ if _qfe_cls is not None and isinstance (parent_module , _qfe_cls ):
431+ for param_name , param in parent_module .named_parameters (recurse = False ):
432+ qlist = getattr (parent_module , f"{ param_name } _weight_quantizers" , None )
433+ if not isinstance (qlist , nn .ModuleList ):
434+ continue
435+ if len (qlist ) != param .shape [0 ]:
436+ warnings .warn (
437+ f"Skipping { param_name } _weight_quantizers: list length { len (qlist )} "
438+ f"does not match parameter leading dimension { param .shape [0 ]} . "
439+ "This may indicate a misconfigured fused-experts module." ,
440+ stacklevel = 2 ,
441+ )
442+ continue
443+ for expert_idx , wq in enumerate (qlist ):
444+ if isinstance (wq , TensorQuantizer ) and wq .is_enabled :
445+ if getattr (wq , "_calibrator" , None ) is not None :
446+ weight_quantizers .append ((parent_module , (param_name , expert_idx ), wq ))
437447
438448 seen_modules .add (parent_module )
439449
450+ # Warn about enabled weight quantizers that weren't scheduled for MSE calibration.
451+ picked_ids = {id (wq ) for _ , _ , wq in weight_quantizers }
452+
453+ def _is_active_unpicked (q : Any ) -> bool :
454+ return (
455+ isinstance (q , TensorQuantizer )
456+ and q .is_enabled
457+ and getattr (q , "_calibrator" , None ) is not None
458+ and id (q ) not in picked_ids
459+ )
460+
461+ missed : list [str ] = []
462+ for mod_name , module in name_to_module .items ():
463+ for attr_name , attr in module ._modules .items ():
464+ if isinstance (attr , TensorQuantizer ) and attr_name .endswith ("weight_quantizer" ):
465+ if _is_active_unpicked (attr ):
466+ missed .append (f"{ mod_name } .{ attr_name } " )
467+ elif isinstance (attr , nn .ModuleList ) and attr_name .endswith ("_weight_quantizers" ):
468+ for i , wq in enumerate (attr ):
469+ if _is_active_unpicked (wq ):
470+ missed .append (f"{ mod_name } .{ attr_name } [{ i } ]" )
471+ if missed :
472+ warnings .warn (
473+ f"MSE weight calibration: { len (missed )} weight quantizer(s) are enabled but were "
474+ f"not scheduled for calibration and will retain max-calibration amax values. "
475+ f"First { min (5 , len (missed ))} : { missed [:5 ]} " ,
476+ stacklevel = 2 ,
477+ )
478+
440479 # Step 3: Calibrate weight quantizers ONE AT A TIME with immediate amax computation
441480 # This prevents massive memory accumulation seen in large models
442481 for idx , (parent_module , weight_name , weight_quantizer ) in enumerate (