7373 cutlass_fp8_supported ,
7474 normalize_e4m3fn_to_e4m3fnuz ,
7575)
76- from vllm .model_executor .model_loader .reload .layerwise import (
77- initialize_online_processing ,
78- )
76+ from vllm .model_executor .model_loader .weight_utils import initialize_single_dummy_weight
7977from vllm .model_executor .parameter import (
8078 BlockQuantScaleParameter ,
8179 ModelWeightParameter ,
@@ -498,8 +496,8 @@ def apply(
498496
499497
500498class Fp8OnlineLinearMethod (Fp8LinearMethod ):
501- """Online version of Fp8LinearMethod which loads a full precision checkpoint
502- and quantizes weights during loading."""
499+ """Online version of Fp8LinearMethod, loads the fp16/bf16 checkpoint
502+ and quantizes the weights during loading."""
503501
504502 uses_meta_device : bool = True
505503
@@ -521,25 +519,84 @@ def create_weights(
521519 layer .orig_dtype = params_dtype
522520 layer .weight_block_size = None
523521
522+ # WEIGHT
def patched_weight_loader(param, loaded_weight, *args, **kwargs):
    """Wrap ``weight_loader`` to drive streaming online fp8 quantization.

    On the first call, materializes the meta-device ``layer.weight``
    just-in-time on the stashed load device.  Every call counts how many
    elements were actually copied; once the whole weight has been loaded,
    ``process_weights_after_loading`` runs immediately so the full
    precision weight never has to be held for the whole model at once.
    """
    if not hasattr(layer, "_loaded_numel"):
        # First chunk for this layer: start the copy counter and
        # materialize `param` just-in-time on the real load device.
        layer._loaded_numel = 0

        weight = ModelWeightParameter(
            data=torch.empty_like(layer.weight, device=layer._load_device),
            input_dim=1,
            output_dim=0,
            weight_loader=patched_weight_loader,
        )
        _copy_missing_attrs(layer.weight, weight)
        layer.register_parameter("weight", weight)
        # No longer needed once the weight is materialized.
        del layer._load_device

    # Refresh the reference to `param`: the caller may still hold the
    # stale meta-device parameter created in `create_weights`.
    param = layer.weight

    # Load the current weight chunk, counting how many elements were
    # actually copied (a param may be filled by several chunks, e.g.
    # with tensor-parallel sharded checkpoints).
    copy_numel_counter = CopyNumelCounter()
    with copy_numel_counter:
        res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
    layer._loaded_numel += copy_numel_counter.copied_numel

    # If we have loaded all of the elements, quantize now instead of
    # waiting for the model-wide `process_weights_after_loading` pass.
    target_loaded_numel = layer.weight.numel()
    if layer._loaded_numel == target_loaded_numel:
        self.process_weights_after_loading(layer)

        # Prevent the usual `process_weights_after_loading` call from
        # doing anything.
        layer._already_called_process_weights_after_loading = True

    # Note that we keep `layer._loaded_numel` around just in case
    # there is logic added to vllm in the future which calls a
    # weight loader twice - we do not want to re-initialize in
    # that case.

    return res
566+
524567 weight = ModelWeightParameter (
525568 data = torch .empty (
526569 output_size_per_partition ,
527570 input_size_per_partition ,
528- device = "meta" , # materialized and processed during loading
571+ # materialized just-in-time in `patched_weight_loader`
572+ device = "meta" ,
529573 dtype = params_dtype ,
530574 ),
531575 input_dim = 1 ,
532576 output_dim = 0 ,
533- weight_loader = weight_loader ,
577+ weight_loader = patched_weight_loader ,
534578 )
579+ # stash the correct device for `patched_weight_loader`
580+ layer ._load_device = torch .get_default_device ()
535581 layer .register_parameter ("weight" , weight )
536582
537- initialize_online_processing (layer )
538-
539583 def process_weights_after_loading (self , layer : Module ) -> None :
540584 if getattr (layer , "_already_called_process_weights_after_loading" , False ):
541585 return
542586
587+ # deferred initialization of randomly initialized weights for the
588+ # `--load_format dummy` feature
589+ if layer .weight .device == torch .device ("meta" ):
590+ weight = ModelWeightParameter (
591+ data = torch .empty_like (layer .weight , device = layer ._load_device ),
592+ input_dim = 1 ,
593+ output_dim = 0 ,
594+ weight_loader = layer .weight .weight_loader ,
595+ )
596+ _copy_missing_attrs (layer .weight , weight )
597+ layer .register_parameter ("weight" , weight )
598+ initialize_single_dummy_weight (layer .weight )
599+
543600 # TODO(future): support block_quant in online quant path
544601 assert not self .block_quant
545602
@@ -788,6 +845,9 @@ def _setup_kernel(
788845 )
789846
790847 def process_weights_after_loading (self , layer : Module ) -> None :
848+ if getattr (layer , "_already_called_process_weights_after_loading" , False ):
849+ return
850+
791851 # Allow for accessing weights and scales in standard way.
792852 w13 = layer .w13_weight
793853 w2 = layer .w2_weight
@@ -832,6 +892,9 @@ def process_weights_after_loading(self, layer: Module) -> None:
832892 layer , w13 , w2 , w13_scale , w2_scale , w13_input_scale , w2_input_scale
833893 )
834894
895+ # Prevent duplicate processing (e.g., during weight reload)
896+ layer ._already_called_process_weights_after_loading = True
897+
835898 def maybe_make_prepare_finalize (
836899 self ,
837900 routing_tables : tuple [torch .Tensor , torch .Tensor , torch .Tensor ] | None = None ,
@@ -950,12 +1013,86 @@ def create_weights(
9501013 layer .orig_dtype = params_dtype
9511014 layer .weight_block_size = None
9521015
1016+ # We are doing online quantization, patch the weight loaded
1017+ # to call `process_weights_after_loading` in a streaming fashion
1018+ # as soon as the last weight chunk is loaded.
1019+ weight_loader = extra_weight_attrs ["weight_loader" ]
1020+ # create a new holder to prevent modifying behavior of any other
1021+ # objects which might depend on the old one
1022+ new_extra_weight_attrs = extra_weight_attrs
1023+
def patched_weight_loader(param, loaded_weight, *args, **kwargs):
    """Wrap ``weight_loader`` to drive streaming online fp8 quantization
    of the MoE expert weights.

    On the first call, materializes the meta-device ``w13_weight`` and
    ``w2_weight`` just-in-time on the stashed load device.  Every call
    counts how many elements were actually copied; once both expert
    weights are fully loaded, ``process_weights_after_loading`` runs
    immediately.
    """
    if not hasattr(layer, "_loaded_numel"):
        # First chunk for this layer: start the copy counter.
        layer._loaded_numel = 0

        # Save the ids of the original (meta) w13 and w2 so that we can
        # distinguish which one a stale `param` reference should map to
        # further down in this function.
        layer._w13_weight_orig_id = id(layer.w13_weight)
        layer._w2_weight_orig_id = id(layer.w2_weight)

        # Materialize both expert weights just-in-time on the real load
        # device.
        w13_weight = torch.nn.Parameter(
            torch.empty_like(layer.w13_weight, device=layer._load_device),
            requires_grad=False,
        )
        set_weight_attrs(w13_weight, extra_weight_attrs)
        _copy_missing_attrs(layer.w13_weight, w13_weight)
        layer.register_parameter("w13_weight", w13_weight)

        w2_weight = torch.nn.Parameter(
            torch.empty_like(layer.w2_weight, device=layer._load_device),
            requires_grad=False,
        )
        set_weight_attrs(w2_weight, extra_weight_attrs)
        _copy_missing_attrs(layer.w2_weight, w2_weight)
        layer.register_parameter("w2_weight", w2_weight)
        # No longer needed once the weights are materialized.
        del layer._load_device

    # Refresh the reference to `param`: the caller may still hold one of
    # the stale meta-device parameters created in `create_weights`.
    if id(param) == layer._w13_weight_orig_id:
        param = layer.w13_weight
    elif id(param) == layer._w2_weight_orig_id:
        param = layer.w2_weight

    # Load the current weight chunk, counting how many elements were
    # actually copied.
    copy_numel_counter = CopyNumelCounter()
    with copy_numel_counter:
        res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
    layer._loaded_numel += copy_numel_counter.copied_numel

    # If we have loaded all of the elements of both expert weights,
    # quantize now instead of waiting for the model-wide
    # `process_weights_after_loading` pass.
    target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
    if layer._loaded_numel == target_loaded_numel:
        self.process_weights_after_loading(layer)

        # Prevent the usual `process_weights_after_loading` call
        # from doing anything.
        layer._already_called_process_weights_after_loading = True

    # Note that we keep `layer._loaded_numel`,
    # `layer._w13_weight_orig_id` and `layer._w2_weight_orig_id`
    # around because if EP is on, weight loaders for non-local
    # experts will run but not actually copy any elements, and we
    # need to not re-initialize in that case.

    return res
1085+
1086+ new_extra_weight_attrs ["weight_loader" ] = patched_weight_loader
1087+ extra_weight_attrs = new_extra_weight_attrs
1088+
9531089 # WEIGHTS
9541090 w13_weight = torch .nn .Parameter (
9551091 torch .empty (
9561092 num_experts ,
9571093 2 * intermediate_size_per_partition ,
9581094 hidden_size ,
1095+ # materialized just-in-time in `patched_weight_loader`
9591096 device = "meta" ,
9601097 dtype = params_dtype ,
9611098 ),
@@ -969,53 +1106,91 @@ def create_weights(
9691106 num_experts ,
9701107 hidden_size ,
9711108 intermediate_size_per_partition ,
972- device = "meta" , # materialized and processed during loading
1109+ # materialized just-in-time in `patched_weight_loader`
1110+ device = "meta" ,
9731111 dtype = params_dtype ,
9741112 ),
9751113 requires_grad = False ,
9761114 )
9771115 layer .register_parameter ("w2_weight" , w2_weight )
9781116 set_weight_attrs (w2_weight , extra_weight_attrs )
1117+ # stash the correct device for `patched_weight_loader`
1118+ layer ._load_device = torch .get_default_device ()
9791119
9801120 # BIASES (for models like GPT-OSS that have biased MoE)
9811121 if self .moe .has_bias :
1122+ # Use the original weight_loader (not patched) for biases
1123+ orig_extra_weight_attrs = dict (extra_weight_attrs )
1124+ orig_extra_weight_attrs ["weight_loader" ] = weight_loader
9821125 w13_bias = torch .nn .Parameter (
9831126 torch .zeros (
9841127 num_experts ,
9851128 2 * intermediate_size_per_partition ,
986- device = "meta" , # materialized and processed during loading
9871129 dtype = layer .orig_dtype ,
9881130 ),
9891131 requires_grad = False ,
9901132 )
9911133 layer .register_parameter ("w13_bias" , w13_bias )
992- set_weight_attrs (w13_bias , extra_weight_attrs )
993-
1134+ set_weight_attrs (w13_bias , orig_extra_weight_attrs )
9941135 w2_bias = torch .nn .Parameter (
995- torch .zeros (
996- num_experts ,
997- hidden_size ,
998- device = "meta" , # materialized and processed during loading
999- dtype = layer .orig_dtype ,
1000- ),
1136+ torch .zeros (num_experts , hidden_size , dtype = layer .orig_dtype ),
10011137 requires_grad = False ,
10021138 )
10031139 layer .register_parameter ("w2_bias" , w2_bias )
1004- set_weight_attrs (w2_bias , extra_weight_attrs )
1140+ set_weight_attrs (w2_bias , orig_extra_weight_attrs )
1141+
1142+ # WEIGHT_SCALES
1143+ # Allocate 2 scales for w1 and w3 respectively.
1144+ # They will be combined to a single scale after weight loading.
1145+ w13_weight_scale = torch .nn .Parameter (
1146+ torch .ones (num_experts , dtype = torch .float32 ), requires_grad = False
1147+ )
1148+ w2_weight_scale = torch .nn .Parameter (
1149+ torch .ones (num_experts , dtype = torch .float32 ), requires_grad = False
1150+ )
1151+ layer .register_parameter ("w13_weight_scale" , w13_weight_scale )
1152+ layer .register_parameter ("w2_weight_scale" , w2_weight_scale )
1153+ set_weight_attrs (w13_weight_scale , extra_weight_attrs )
1154+ set_weight_attrs (w2_weight_scale , extra_weight_attrs )
10051155
1006- initialize_online_processing (layer )
1156+ layer .w13_input_scale = None
1157+ layer .w2_input_scale = None
10071158
10081159 def process_weights_after_loading (self , layer : Module ) -> None :
10091160 if getattr (layer , "_already_called_process_weights_after_loading" , False ):
10101161 return
10111162
1163+ # deferred initialization of randomly initialized weights for the
1164+ # `--load_format dummy` feature
1165+ if layer .w13_weight .device == torch .device ("meta" ):
1166+ w13_weight = torch .nn .Parameter (
1167+ torch .empty_like (layer .w13_weight , device = layer ._load_device ),
1168+ requires_grad = False ,
1169+ )
1170+ set_weight_attrs (
1171+ w13_weight , {"weight_loader" : layer .w13_weight .weight_loader }
1172+ )
1173+ _copy_missing_attrs (layer .w13_weight , w13_weight )
1174+ layer .register_parameter ("w13_weight" , w13_weight )
1175+ initialize_single_dummy_weight (layer .w13_weight )
1176+ if layer .w2_weight .device == torch .device ("meta" ):
1177+ w2_weight = torch .nn .Parameter (
1178+ torch .empty_like (layer .w2_weight , device = layer ._load_device ),
1179+ requires_grad = False ,
1180+ )
1181+ set_weight_attrs (
1182+ w2_weight , {"weight_loader" : layer .w2_weight .weight_loader }
1183+ )
1184+ _copy_missing_attrs (layer .w2_weight , w2_weight )
1185+ layer .register_parameter ("w2_weight" , w2_weight )
1186+ initialize_single_dummy_weight (layer .w2_weight )
1187+
1188+ # If checkpoint is fp16, quantize in place.
10121189 fp8_dtype = current_platform .fp8_dtype ()
10131190 w13 = torch .empty_like (layer .w13_weight , dtype = fp8_dtype )
10141191 w2 = torch .empty_like (layer .w2_weight , dtype = fp8_dtype )
1015- w13_scale = torch .ones (layer .num_experts , dtype = torch .float32 )
1016- w2_scale = torch .ones (layer .num_experts , dtype = torch .float32 )
1017- layer .w13_input_scale = None
1018- layer .w2_input_scale = None
1192+ w13_scale = layer .w13_weight_scale
1193+ w2_scale = layer .w2_weight_scale
10191194
10201195 for expert in range (layer .local_num_experts ):
10211196 w13 [expert , :, :], w13_scale [expert ] = ops .scaled_fp8_quant (
@@ -1032,8 +1207,8 @@ def process_weights_after_loading(self, layer: Module) -> None:
10321207 w2 ,
10331208 w13_scale ,
10341209 w2_scale ,
1035- w13_input_scale = layer .w13_input_scale ,
1036- w2_input_scale = layer .w2_input_scale ,
1210+ layer .w13_input_scale ,
1211+ layer .w2_input_scale ,
10371212 )
10381213
10391214 # Prevent duplicate processing (e.g., during weight reload)
0 commit comments