7373 cutlass_fp8_supported ,
7474 normalize_e4m3fn_to_e4m3fnuz ,
7575)
76- from vllm .model_executor .model_loader .weight_utils import initialize_single_dummy_weight
76+ from vllm .model_executor .model_loader .reload .layerwise import (
77+ initialize_online_processing ,
78+ )
7779from vllm .model_executor .parameter import (
7880 BlockQuantScaleParameter ,
7981 ModelWeightParameter ,
@@ -496,8 +498,8 @@ def apply(
496498
497499
498500class Fp8OnlineLinearMethod (Fp8LinearMethod ):
499- """Online version of Fp8LinearMethod, loads the fp16/bf16 checkpoint
500- and quantized the weights during loading."""
501+ """Online version of Fp8LinearMethod which loads a full precision checkpoint
502+ and quantizes weights during loading."""
501503
502504 uses_meta_device : bool = True
503505
@@ -519,84 +521,25 @@ def create_weights(
519521 layer .orig_dtype = params_dtype
520522 layer .weight_block_size = None
521523
522- # WEIGHT
523- def patched_weight_loader (param , loaded_weight , * args , ** kwargs ):
524- # track how many elements we have updated
525- if not hasattr (layer , "_loaded_numel" ):
526- layer ._loaded_numel = 0
527-
528- # when the first `loaded_weight` is about to be
529- # loaded to `param`, materialize `param` just-in-time
530- weight = ModelWeightParameter (
531- data = torch .empty_like (layer .weight , device = layer ._load_device ),
532- input_dim = 1 ,
533- output_dim = 0 ,
534- weight_loader = patched_weight_loader ,
535- )
536- _copy_missing_attrs (layer .weight , weight )
537- layer .register_parameter ("weight" , weight )
538- del layer ._load_device
539-
540- # refresh the reference to `param` to reflect just-in-time
541- # materialization
542- param = layer .weight
543-
544- # load the current weight chunk
545- copy_numel_counter = CopyNumelCounter ()
546- with copy_numel_counter :
547- res = weight_loader (param , loaded_weight , * args , ** kwargs ) # type: ignore[misc]
548- layer ._loaded_numel += copy_numel_counter .copied_numel
549-
550- # if we have loaded all of the elements, call
551- # process_weights_after_loading
552- target_loaded_numel = layer .weight .numel ()
553- if layer ._loaded_numel == target_loaded_numel :
554- self .process_weights_after_loading (layer )
555-
556- # Prevent the usual `process_weights_after_loading` call from doing
557- # anything
558- layer ._already_called_process_weights_after_loading = True
559-
560- # Note that we keep `layer._loaded_numel` around just in case
561- # there is logic added to vllm in the future which calls a
562- # weight loader twice - we do not want to re-initialize in
563- # that case.
564-
565- return res
566-
567524 weight = ModelWeightParameter (
568525 data = torch .empty (
569526 output_size_per_partition ,
570527 input_size_per_partition ,
571- # materialized just-in-time in `patched_weight_loader`
572- device = "meta" ,
528+ device = "meta" , # materialized and processed during loading
573529 dtype = params_dtype ,
574530 ),
575531 input_dim = 1 ,
576532 output_dim = 0 ,
577- weight_loader = patched_weight_loader ,
533+ weight_loader = weight_loader ,
578534 )
579- # stash the correct device for `patched_weight_loader`
580- layer ._load_device = torch .get_default_device ()
581535 layer .register_parameter ("weight" , weight )
582536
537+ initialize_online_processing (layer )
538+
583539 def process_weights_after_loading (self , layer : Module ) -> None :
584540 if getattr (layer , "_already_called_process_weights_after_loading" , False ):
585541 return
586542
587- # deferred initialization of randomly initialized weights for the
588- # `--load_format dummy` feature
589- if layer .weight .device == torch .device ("meta" ):
590- weight = ModelWeightParameter (
591- data = torch .empty_like (layer .weight , device = layer ._load_device ),
592- input_dim = 1 ,
593- output_dim = 0 ,
594- weight_loader = layer .weight .weight_loader ,
595- )
596- _copy_missing_attrs (layer .weight , weight )
597- layer .register_parameter ("weight" , weight )
598- initialize_single_dummy_weight (layer .weight )
599-
600543 # TODO(future): support block_quant in online quant path
601544 assert not self .block_quant
602545
@@ -845,9 +788,6 @@ def _setup_kernel(
845788 )
846789
847790 def process_weights_after_loading (self , layer : Module ) -> None :
848- if getattr (layer , "_already_called_process_weights_after_loading" , False ):
849- return
850-
851791 # Allow for accessing weights and scales in standard way.
852792 w13 = layer .w13_weight
853793 w2 = layer .w2_weight
@@ -892,9 +832,6 @@ def process_weights_after_loading(self, layer: Module) -> None:
892832 layer , w13 , w2 , w13_scale , w2_scale , w13_input_scale , w2_input_scale
893833 )
894834
895- # Prevent duplicate processing (e.g., during weight reload)
896- layer ._already_called_process_weights_after_loading = True
897-
898835 def maybe_make_prepare_finalize (
899836 self ,
900837 routing_tables : tuple [torch .Tensor , torch .Tensor , torch .Tensor ] | None = None ,
@@ -1013,86 +950,12 @@ def create_weights(
1013950 layer .orig_dtype = params_dtype
1014951 layer .weight_block_size = None
1015952
1016- # We are doing online quantization, patch the weight loaded
1017- # to call `process_weights_after_loading` in a streaming fashion
1018- # as soon as the last weight chunk is loaded.
1019- weight_loader = extra_weight_attrs ["weight_loader" ]
1020- # create a new holder to prevent modifying behavior of any other
1021- # objects which might depend on the old one
1022- new_extra_weight_attrs = extra_weight_attrs
1023-
1024- def patched_weight_loader (param , loaded_weight , * args , ** kwargs ):
1025- # add a counter to track how many elements we have updated
1026- if not hasattr (layer , "_loaded_numel" ):
1027- layer ._loaded_numel = 0
1028-
1029- # save the ids of original w13 and w2 so that we can
1030- # distinguish which one `param` should map to further
1031- # down in this file
1032- layer ._w13_weight_orig_id = id (layer .w13_weight )
1033- layer ._w2_weight_orig_id = id (layer .w2_weight )
1034-
1035- # when the first `loaded_weight` is about to be
1036- # loaded to `param`, materialize `param` just-in-time
1037-
1038- w13_weight = torch .nn .Parameter (
1039- torch .empty_like (layer .w13_weight , device = layer ._load_device ),
1040- requires_grad = False ,
1041- )
1042- set_weight_attrs (w13_weight , extra_weight_attrs )
1043- _copy_missing_attrs (layer .w13_weight , w13_weight )
1044- layer .register_parameter ("w13_weight" , w13_weight )
1045-
1046- w2_weight = torch .nn .Parameter (
1047- torch .empty_like (layer .w2_weight , device = layer ._load_device ),
1048- requires_grad = False ,
1049- )
1050- set_weight_attrs (w2_weight , extra_weight_attrs )
1051- _copy_missing_attrs (layer .w2_weight , w2_weight )
1052- layer .register_parameter ("w2_weight" , w2_weight )
1053- del layer ._load_device
1054-
1055- # refresh the reference to `param` to reflect just-in-time
1056- # materialization
1057- if id (param ) == layer ._w13_weight_orig_id :
1058- param = layer .w13_weight
1059- elif id (param ) == layer ._w2_weight_orig_id :
1060- param = layer .w2_weight
1061-
1062- # load the current weight chunk
1063- copy_numel_counter = CopyNumelCounter ()
1064- with copy_numel_counter :
1065- res = weight_loader (param , loaded_weight , * args , ** kwargs ) # type: ignore[misc]
1066- layer ._loaded_numel += copy_numel_counter .copied_numel
1067-
1068- # if we have loaded all of the elements, call
1069- # process_weights_after_loading
1070- target_loaded_numel = layer .w13_weight .numel () + layer .w2_weight .numel ()
1071- if layer ._loaded_numel == target_loaded_numel :
1072- self .process_weights_after_loading (layer )
1073-
1074- # Prevent the usual `process_weights_after_loading` call
1075- # from doing anything
1076- layer ._already_called_process_weights_after_loading = True
1077-
1078- # Note that we keep `layer._loaded_numel`,
1079- # `layer._w13_weight_orig_id` and `layer._w2_weight_orig_id`
1080- # around because if EP is on, weight loaders for non-local
1081- # experts will run but not actually copy any elements, and we
1082- # need to not re-initialize in that case.
1083-
1084- return res
1085-
1086- new_extra_weight_attrs ["weight_loader" ] = patched_weight_loader
1087- extra_weight_attrs = new_extra_weight_attrs
1088-
1089953 # WEIGHTS
1090954 w13_weight = torch .nn .Parameter (
1091955 torch .empty (
1092956 num_experts ,
1093957 2 * intermediate_size_per_partition ,
1094958 hidden_size ,
1095- # materialized just-in-time in `patched_weight_loader`
1096959 device = "meta" ,
1097960 dtype = params_dtype ,
1098961 ),
@@ -1106,91 +969,53 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs):
1106969 num_experts ,
1107970 hidden_size ,
1108971 intermediate_size_per_partition ,
1109- # materialized just-in-time in `patched_weight_loader`
1110- device = "meta" ,
972+ device = "meta" , # materialized and processed during loading
1111973 dtype = params_dtype ,
1112974 ),
1113975 requires_grad = False ,
1114976 )
1115977 layer .register_parameter ("w2_weight" , w2_weight )
1116978 set_weight_attrs (w2_weight , extra_weight_attrs )
1117- # stash the correct device for `patched_weight_loader`
1118- layer ._load_device = torch .get_default_device ()
1119979
1120980 # BIASES (for models like GPT-OSS that have biased MoE)
1121981 if self .moe .has_bias :
1122- # Use the original weight_loader (not patched) for biases
1123- orig_extra_weight_attrs = dict (extra_weight_attrs )
1124- orig_extra_weight_attrs ["weight_loader" ] = weight_loader
1125982 w13_bias = torch .nn .Parameter (
1126983 torch .zeros (
1127984 num_experts ,
1128985 2 * intermediate_size_per_partition ,
986+ device = "meta" , # materialized and processed during loading
1129987 dtype = layer .orig_dtype ,
1130988 ),
1131989 requires_grad = False ,
1132990 )
1133991 layer .register_parameter ("w13_bias" , w13_bias )
1134- set_weight_attrs (w13_bias , orig_extra_weight_attrs )
992+ set_weight_attrs (w13_bias , extra_weight_attrs )
993+
1135994 w2_bias = torch .nn .Parameter (
1136- torch .zeros (num_experts , hidden_size , dtype = layer .orig_dtype ),
995+ torch .zeros (
996+ num_experts ,
997+ hidden_size ,
998+ device = "meta" , # materialized and processed during loading
999+ dtype = layer .orig_dtype ,
1000+ ),
11371001 requires_grad = False ,
11381002 )
11391003 layer .register_parameter ("w2_bias" , w2_bias )
1140- set_weight_attrs (w2_bias , orig_extra_weight_attrs )
1141-
1142- # WEIGHT_SCALES
1143- # Allocate 2 scales for w1 and w3 respectively.
1144- # They will be combined to a single scale after weight loading.
1145- w13_weight_scale = torch .nn .Parameter (
1146- torch .ones (num_experts , dtype = torch .float32 ), requires_grad = False
1147- )
1148- w2_weight_scale = torch .nn .Parameter (
1149- torch .ones (num_experts , dtype = torch .float32 ), requires_grad = False
1150- )
1151- layer .register_parameter ("w13_weight_scale" , w13_weight_scale )
1152- layer .register_parameter ("w2_weight_scale" , w2_weight_scale )
1153- set_weight_attrs (w13_weight_scale , extra_weight_attrs )
1154- set_weight_attrs (w2_weight_scale , extra_weight_attrs )
1004+ set_weight_attrs (w2_bias , extra_weight_attrs )
11551005
1156- layer .w13_input_scale = None
1157- layer .w2_input_scale = None
1006+ initialize_online_processing (layer )
11581007
11591008 def process_weights_after_loading (self , layer : Module ) -> None :
11601009 if getattr (layer , "_already_called_process_weights_after_loading" , False ):
11611010 return
11621011
1163- # deferred initialization of randomly initialized weights for the
1164- # `--load_format dummy` feature
1165- if layer .w13_weight .device == torch .device ("meta" ):
1166- w13_weight = torch .nn .Parameter (
1167- torch .empty_like (layer .w13_weight , device = layer ._load_device ),
1168- requires_grad = False ,
1169- )
1170- set_weight_attrs (
1171- w13_weight , {"weight_loader" : layer .w13_weight .weight_loader }
1172- )
1173- _copy_missing_attrs (layer .w13_weight , w13_weight )
1174- layer .register_parameter ("w13_weight" , w13_weight )
1175- initialize_single_dummy_weight (layer .w13_weight )
1176- if layer .w2_weight .device == torch .device ("meta" ):
1177- w2_weight = torch .nn .Parameter (
1178- torch .empty_like (layer .w2_weight , device = layer ._load_device ),
1179- requires_grad = False ,
1180- )
1181- set_weight_attrs (
1182- w2_weight , {"weight_loader" : layer .w2_weight .weight_loader }
1183- )
1184- _copy_missing_attrs (layer .w2_weight , w2_weight )
1185- layer .register_parameter ("w2_weight" , w2_weight )
1186- initialize_single_dummy_weight (layer .w2_weight )
1187-
1188- # If checkpoint is fp16, quantize in place.
11891012 fp8_dtype = current_platform .fp8_dtype ()
11901013 w13 = torch .empty_like (layer .w13_weight , dtype = fp8_dtype )
11911014 w2 = torch .empty_like (layer .w2_weight , dtype = fp8_dtype )
1192- w13_scale = layer .w13_weight_scale
1193- w2_scale = layer .w2_weight_scale
1015+ w13_scale = torch .ones (layer .num_experts , dtype = torch .float32 )
1016+ w2_scale = torch .ones (layer .num_experts , dtype = torch .float32 )
1017+ layer .w13_input_scale = None
1018+ layer .w2_input_scale = None
11941019
11951020 for expert in range (layer .local_num_experts ):
11961021 w13 [expert , :, :], w13_scale [expert ] = ops .scaled_fp8_quant (
@@ -1207,8 +1032,8 @@ def process_weights_after_loading(self, layer: Module) -> None:
12071032 w2 ,
12081033 w13_scale ,
12091034 w2_scale ,
1210- layer .w13_input_scale ,
1211- layer .w2_input_scale ,
1035+ w13_input_scale = layer .w13_input_scale ,
1036+ w2_input_scale = layer .w2_input_scale ,
12121037 )
12131038
12141039 # Prevent duplicate processing (e.g., during weight reload)
0 commit comments