@@ -77,7 +77,10 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
7777 # will add async-offload support to your cast and improve performance.
7878 if input is not None :
7979 if dtype is None :
80- dtype = input .dtype
80+ if isinstance (input , QuantizedTensor ):
81+ dtype = input ._layout_params ["orig_dtype" ]
82+ else :
83+ dtype = input .dtype
8184 if bias_dtype is None :
8285 bias_dtype = dtype
8386 if device is None :
@@ -534,18 +537,7 @@ def forward(self, *args, **kwargs):
534537# ==============================================================================
535538# Mixed Precision Operations
536539# ==============================================================================
537- from .quant_ops import QuantizedTensor
538-
539- QUANT_FORMAT_MIXINS = {
540- "float8_e4m3fn" : {
541- "dtype" : torch .float8_e4m3fn ,
542- "layout_type" : "TensorCoreFP8Layout" ,
543- "parameters" : {
544- "weight_scale" : torch .nn .Parameter (torch .zeros ((), dtype = torch .float32 ), requires_grad = False ),
545- "input_scale" : torch .nn .Parameter (torch .zeros ((), dtype = torch .float32 ), requires_grad = False ),
546- }
547- }
548- }
540+ from .quant_ops import QuantizedTensor , QUANT_ALGOS
549541
550542class MixedPrecisionOps (disable_weight_init ):
551543 _layer_quant_config = {}
@@ -596,23 +588,24 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
596588 if quant_format is None :
597589 raise ValueError (f"Unknown quantization format for layer { layer_name } " )
598590
599- mixin = QUANT_FORMAT_MIXINS [quant_format ]
600- self .layout_type = mixin [ "layout_type " ]
591+ qconfig = QUANT_ALGOS [quant_format ]
592+ self .layout_type = qconfig [ "comfy_tensor_layout " ]
601593
602- scale_key = f"{ prefix } weight_scale"
594+ weight_scale_key = f"{ prefix } weight_scale"
603595 layout_params = {
604- 'scale' : state_dict .pop (scale_key , None ),
605- 'orig_dtype' : MixedPrecisionOps ._compute_dtype
596+ 'scale' : state_dict .pop (weight_scale_key , None ),
597+ 'orig_dtype' : MixedPrecisionOps ._compute_dtype ,
598+ 'block_size' : qconfig .get ("group_size" , None ),
606599 }
607600 if layout_params ['scale' ] is not None :
608- manually_loaded_keys .append (scale_key )
601+ manually_loaded_keys .append (weight_scale_key )
609602
610603 self .weight = torch .nn .Parameter (
611- QuantizedTensor (weight .to (device = device , dtype = mixin [ "dtype" ] ), self .layout_type , layout_params ),
604+ QuantizedTensor (weight .to (device = device ), self .layout_type , layout_params ),
612605 requires_grad = False
613606 )
614607
615- for param_name , param_value in mixin ["parameters" ]. items () :
608+ for param_name in qconfig ["parameters" ]:
616609 param_key = f"{ prefix } { param_name } "
617610 _v = state_dict .pop (param_key , None )
618611 if _v is None :
@@ -643,7 +636,7 @@ def forward(self, input, *args, **kwargs):
643636 if (getattr (self , 'layout_type' , None ) is not None and
644637 getattr (self , 'input_scale' , None ) is not None and
645638 not isinstance (input , QuantizedTensor )):
646- input = QuantizedTensor .from_float (input , self .layout_type , scale = self .input_scale , fp8_dtype = self .weight .dtype )
639+ input = QuantizedTensor .from_float (input , self .layout_type , scale = self .input_scale , dtype = self .weight .dtype )
647640 return self ._forward (input , self .weight , self .bias )
648641
649642
0 commit comments