Skip to content

Commit 4041de8

Browse files
kylesayrs and JiantaoXu
authored and committed
[QeRL] Compose online quantization with quantized reloading (vllm-project#38032)
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent b4981f8 commit 4041de8

File tree

10 files changed

+184
-260
lines changed

10 files changed

+184
-260
lines changed

tests/model_executor/model_loader/test_reload.py

Lines changed: 57 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -148,3 +148,60 @@ def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner):
148148
mul_perp = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0]
149149
add_perp = llm.generate_prompt_perplexity(["3 4 = 7"], mask=["3 4 ="])[0]
150150
assert add_perp < mul_perp
151+
152+
153+
@pytest.mark.parametrize("tp_size", [2])
154+
@pytest.mark.parametrize(
155+
"base_model,mul_model,add_model,quantization",
156+
[
157+
(
158+
"Qwen/Qwen3-0.6B",
159+
"inference-optimization/Qwen3-0.6B-debug-multiply",
160+
"inference-optimization/Qwen3-0.6B-debug-add",
161+
"fp8",
162+
),
163+
(
164+
"inference-optimization/DeepSeek-V3-debug-empty",
165+
"inference-optimization/DeepSeek-V3-debug-multiply",
166+
"inference-optimization/DeepSeek-V3-debug-add",
167+
"fp8",
168+
),
169+
(
170+
"Qwen/Qwen3-0.6B",
171+
"inference-optimization/Qwen3-0.6B-debug-multiply",
172+
"inference-optimization/Qwen3-0.6B-debug-add",
173+
"mxfp8",
174+
),
175+
# ( TODO: support mxfp4 & mla
176+
# "inference-optimization/DeepSeek-V3-debug-empty",
177+
# "inference-optimization/DeepSeek-V3-debug-multiply",
178+
# "inference-optimization/DeepSeek-V3-debug-add",
179+
# "mxfp8",
180+
# ),
181+
],
182+
)
183+
def test_online_quantize_reload(
184+
base_model, mul_model, add_model, quantization, tp_size, vllm_runner
185+
):
186+
if cuda_device_count_stateless() < tp_size:
187+
pytest.skip(reason="Not enough CUDA devices")
188+
189+
if quantization == "fp8" and not current_platform.supports_fp8():
190+
pytest.skip(reason="Requires FP8 support")
191+
192+
with vllm_runner(
193+
model_name=base_model,
194+
quantization=quantization,
195+
tensor_parallel_size=tp_size,
196+
enable_expert_parallel=(tp_size > 1 and "DeepSeek" in base_model),
197+
enable_prefix_caching=False,
198+
) as llm:
199+
llm.collective_rpc("reload_weights", kwargs={"weights_path": mul_model})
200+
mul_perp = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0]
201+
add_perp = llm.generate_prompt_perplexity(["3 4 = 7"], mask=["3 4 ="])[0]
202+
assert mul_perp < add_perp
203+
204+
llm.collective_rpc("reload_weights", kwargs={"weights_path": add_model})
205+
mul_perp = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0]
206+
add_perp = llm.generate_prompt_perplexity(["3 4 = 7"], mask=["3 4 ="])[0]
207+
assert add_perp < mul_perp

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 27 additions & 202 deletions
Original file line number | Diff line number | Diff line change
@@ -73,7 +73,9 @@
7373
cutlass_fp8_supported,
7474
normalize_e4m3fn_to_e4m3fnuz,
7575
)
76-
from vllm.model_executor.model_loader.weight_utils import initialize_single_dummy_weight
76+
from vllm.model_executor.model_loader.reload.layerwise import (
77+
initialize_online_processing,
78+
)
7779
from vllm.model_executor.parameter import (
7880
BlockQuantScaleParameter,
7981
ModelWeightParameter,
@@ -496,8 +498,8 @@ def apply(
496498

497499

498500
class Fp8OnlineLinearMethod(Fp8LinearMethod):
499-
"""Online version of Fp8LinearMethod, loads the fp16/bf16 checkpoint
500-
and quantized the weights during loading."""
501+
"""Online version of Fp8LinearMethod which loads a full precision checkpoint
502+
and quantizes weights during loading."""
501503

502504
uses_meta_device: bool = True
503505

@@ -519,84 +521,25 @@ def create_weights(
519521
layer.orig_dtype = params_dtype
520522
layer.weight_block_size = None
521523

522-
# WEIGHT
523-
def patched_weight_loader(param, loaded_weight, *args, **kwargs):
524-
# track how many elements we have updated
525-
if not hasattr(layer, "_loaded_numel"):
526-
layer._loaded_numel = 0
527-
528-
# when the first `loaded_weight` is about to be
529-
# loaded to `param`, materialize `param` just-in-time
530-
weight = ModelWeightParameter(
531-
data=torch.empty_like(layer.weight, device=layer._load_device),
532-
input_dim=1,
533-
output_dim=0,
534-
weight_loader=patched_weight_loader,
535-
)
536-
_copy_missing_attrs(layer.weight, weight)
537-
layer.register_parameter("weight", weight)
538-
del layer._load_device
539-
540-
# refresh the reference to `param` to reflect just-in-time
541-
# materialization
542-
param = layer.weight
543-
544-
# load the current weight chunk
545-
copy_numel_counter = CopyNumelCounter()
546-
with copy_numel_counter:
547-
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
548-
layer._loaded_numel += copy_numel_counter.copied_numel
549-
550-
# if we have loaded all of the elements, call
551-
# process_weights_after_loading
552-
target_loaded_numel = layer.weight.numel()
553-
if layer._loaded_numel == target_loaded_numel:
554-
self.process_weights_after_loading(layer)
555-
556-
# Prevent the usual `process_weights_after_loading` call from doing
557-
# anything
558-
layer._already_called_process_weights_after_loading = True
559-
560-
# Note that we keep `layer._loaded_numel` around just in case
561-
# there is logic added to vllm in the future which calls a
562-
# weight loader twice - we do not want to re-initialize in
563-
# that case.
564-
565-
return res
566-
567524
weight = ModelWeightParameter(
568525
data=torch.empty(
569526
output_size_per_partition,
570527
input_size_per_partition,
571-
# materialized just-in-time in `patched_weight_loader`
572-
device="meta",
528+
device="meta", # materialized and processed during loading
573529
dtype=params_dtype,
574530
),
575531
input_dim=1,
576532
output_dim=0,
577-
weight_loader=patched_weight_loader,
533+
weight_loader=weight_loader,
578534
)
579-
# stash the correct device for `patched_weight_loader`
580-
layer._load_device = torch.get_default_device()
581535
layer.register_parameter("weight", weight)
582536

537+
initialize_online_processing(layer)
538+
583539
def process_weights_after_loading(self, layer: Module) -> None:
584540
if getattr(layer, "_already_called_process_weights_after_loading", False):
585541
return
586542

587-
# deferred initialization of randomly initialized weights for the
588-
# `--load_format dummy` feature
589-
if layer.weight.device == torch.device("meta"):
590-
weight = ModelWeightParameter(
591-
data=torch.empty_like(layer.weight, device=layer._load_device),
592-
input_dim=1,
593-
output_dim=0,
594-
weight_loader=layer.weight.weight_loader,
595-
)
596-
_copy_missing_attrs(layer.weight, weight)
597-
layer.register_parameter("weight", weight)
598-
initialize_single_dummy_weight(layer.weight)
599-
600543
# TODO(future): support block_quant in online quant path
601544
assert not self.block_quant
602545

@@ -845,9 +788,6 @@ def _setup_kernel(
845788
)
846789

847790
def process_weights_after_loading(self, layer: Module) -> None:
848-
if getattr(layer, "_already_called_process_weights_after_loading", False):
849-
return
850-
851791
# Allow for accessing weights and scales in standard way.
852792
w13 = layer.w13_weight
853793
w2 = layer.w2_weight
@@ -892,9 +832,6 @@ def process_weights_after_loading(self, layer: Module) -> None:
892832
layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
893833
)
894834

895-
# Prevent duplicate processing (e.g., during weight reload)
896-
layer._already_called_process_weights_after_loading = True
897-
898835
def maybe_make_prepare_finalize(
899836
self,
900837
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
@@ -1013,86 +950,12 @@ def create_weights(
1013950
layer.orig_dtype = params_dtype
1014951
layer.weight_block_size = None
1015952

1016-
# We are doing online quantization, patch the weight loaded
1017-
# to call `process_weights_after_loading` in a streaming fashion
1018-
# as soon as the last weight chunk is loaded.
1019-
weight_loader = extra_weight_attrs["weight_loader"]
1020-
# create a new holder to prevent modifying behavior of any other
1021-
# objects which might depend on the old one
1022-
new_extra_weight_attrs = extra_weight_attrs
1023-
1024-
def patched_weight_loader(param, loaded_weight, *args, **kwargs):
1025-
# add a counter to track how many elements we have updated
1026-
if not hasattr(layer, "_loaded_numel"):
1027-
layer._loaded_numel = 0
1028-
1029-
# save the ids of original w13 and w2 so that we can
1030-
# distinguish which one `param` should map to further
1031-
# down in this file
1032-
layer._w13_weight_orig_id = id(layer.w13_weight)
1033-
layer._w2_weight_orig_id = id(layer.w2_weight)
1034-
1035-
# when the first `loaded_weight` is about to be
1036-
# loaded to `param`, materialize `param` just-in-time
1037-
1038-
w13_weight = torch.nn.Parameter(
1039-
torch.empty_like(layer.w13_weight, device=layer._load_device),
1040-
requires_grad=False,
1041-
)
1042-
set_weight_attrs(w13_weight, extra_weight_attrs)
1043-
_copy_missing_attrs(layer.w13_weight, w13_weight)
1044-
layer.register_parameter("w13_weight", w13_weight)
1045-
1046-
w2_weight = torch.nn.Parameter(
1047-
torch.empty_like(layer.w2_weight, device=layer._load_device),
1048-
requires_grad=False,
1049-
)
1050-
set_weight_attrs(w2_weight, extra_weight_attrs)
1051-
_copy_missing_attrs(layer.w2_weight, w2_weight)
1052-
layer.register_parameter("w2_weight", w2_weight)
1053-
del layer._load_device
1054-
1055-
# refresh the reference to `param` to reflect just-in-time
1056-
# materialization
1057-
if id(param) == layer._w13_weight_orig_id:
1058-
param = layer.w13_weight
1059-
elif id(param) == layer._w2_weight_orig_id:
1060-
param = layer.w2_weight
1061-
1062-
# load the current weight chunk
1063-
copy_numel_counter = CopyNumelCounter()
1064-
with copy_numel_counter:
1065-
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
1066-
layer._loaded_numel += copy_numel_counter.copied_numel
1067-
1068-
# if we have loaded all of the elements, call
1069-
# process_weights_after_loading
1070-
target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
1071-
if layer._loaded_numel == target_loaded_numel:
1072-
self.process_weights_after_loading(layer)
1073-
1074-
# Prevent the usual `process_weights_after_loading` call
1075-
# from doing anything
1076-
layer._already_called_process_weights_after_loading = True
1077-
1078-
# Note that we keep `layer._loaded_numel`,
1079-
# `layer._w13_weight_orig_id` and `layer._w2_weight_orig_id`
1080-
# around because if EP is on, weight loaders for non-local
1081-
# experts will run but not actually copy any elements, and we
1082-
# need to not re-initialize in that case.
1083-
1084-
return res
1085-
1086-
new_extra_weight_attrs["weight_loader"] = patched_weight_loader
1087-
extra_weight_attrs = new_extra_weight_attrs
1088-
1089953
# WEIGHTS
1090954
w13_weight = torch.nn.Parameter(
1091955
torch.empty(
1092956
num_experts,
1093957
2 * intermediate_size_per_partition,
1094958
hidden_size,
1095-
# materialized just-in-time in `patched_weight_loader`
1096959
device="meta",
1097960
dtype=params_dtype,
1098961
),
@@ -1106,91 +969,53 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs):
1106969
num_experts,
1107970
hidden_size,
1108971
intermediate_size_per_partition,
1109-
# materialized just-in-time in `patched_weight_loader`
1110-
device="meta",
972+
device="meta", # materialized and processed during loading
1111973
dtype=params_dtype,
1112974
),
1113975
requires_grad=False,
1114976
)
1115977
layer.register_parameter("w2_weight", w2_weight)
1116978
set_weight_attrs(w2_weight, extra_weight_attrs)
1117-
# stash the correct device for `patched_weight_loader`
1118-
layer._load_device = torch.get_default_device()
1119979

1120980
# BIASES (for models like GPT-OSS that have biased MoE)
1121981
if self.moe.has_bias:
1122-
# Use the original weight_loader (not patched) for biases
1123-
orig_extra_weight_attrs = dict(extra_weight_attrs)
1124-
orig_extra_weight_attrs["weight_loader"] = weight_loader
1125982
w13_bias = torch.nn.Parameter(
1126983
torch.zeros(
1127984
num_experts,
1128985
2 * intermediate_size_per_partition,
986+
device="meta", # materialized and processed during loading
1129987
dtype=layer.orig_dtype,
1130988
),
1131989
requires_grad=False,
1132990
)
1133991
layer.register_parameter("w13_bias", w13_bias)
1134-
set_weight_attrs(w13_bias, orig_extra_weight_attrs)
992+
set_weight_attrs(w13_bias, extra_weight_attrs)
993+
1135994
w2_bias = torch.nn.Parameter(
1136-
torch.zeros(num_experts, hidden_size, dtype=layer.orig_dtype),
995+
torch.zeros(
996+
num_experts,
997+
hidden_size,
998+
device="meta", # materialized and processed during loading
999+
dtype=layer.orig_dtype,
1000+
),
11371001
requires_grad=False,
11381002
)
11391003
layer.register_parameter("w2_bias", w2_bias)
1140-
set_weight_attrs(w2_bias, orig_extra_weight_attrs)
1141-
1142-
# WEIGHT_SCALES
1143-
# Allocate 2 scales for w1 and w3 respectively.
1144-
# They will be combined to a single scale after weight loading.
1145-
w13_weight_scale = torch.nn.Parameter(
1146-
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
1147-
)
1148-
w2_weight_scale = torch.nn.Parameter(
1149-
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
1150-
)
1151-
layer.register_parameter("w13_weight_scale", w13_weight_scale)
1152-
layer.register_parameter("w2_weight_scale", w2_weight_scale)
1153-
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
1154-
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
1004+
set_weight_attrs(w2_bias, extra_weight_attrs)
11551005

1156-
layer.w13_input_scale = None
1157-
layer.w2_input_scale = None
1006+
initialize_online_processing(layer)
11581007

11591008
def process_weights_after_loading(self, layer: Module) -> None:
11601009
if getattr(layer, "_already_called_process_weights_after_loading", False):
11611010
return
11621011

1163-
# deferred initialization of randomly initialized weights for the
1164-
# `--load_format dummy` feature
1165-
if layer.w13_weight.device == torch.device("meta"):
1166-
w13_weight = torch.nn.Parameter(
1167-
torch.empty_like(layer.w13_weight, device=layer._load_device),
1168-
requires_grad=False,
1169-
)
1170-
set_weight_attrs(
1171-
w13_weight, {"weight_loader": layer.w13_weight.weight_loader}
1172-
)
1173-
_copy_missing_attrs(layer.w13_weight, w13_weight)
1174-
layer.register_parameter("w13_weight", w13_weight)
1175-
initialize_single_dummy_weight(layer.w13_weight)
1176-
if layer.w2_weight.device == torch.device("meta"):
1177-
w2_weight = torch.nn.Parameter(
1178-
torch.empty_like(layer.w2_weight, device=layer._load_device),
1179-
requires_grad=False,
1180-
)
1181-
set_weight_attrs(
1182-
w2_weight, {"weight_loader": layer.w2_weight.weight_loader}
1183-
)
1184-
_copy_missing_attrs(layer.w2_weight, w2_weight)
1185-
layer.register_parameter("w2_weight", w2_weight)
1186-
initialize_single_dummy_weight(layer.w2_weight)
1187-
1188-
# If checkpoint is fp16, quantize in place.
11891012
fp8_dtype = current_platform.fp8_dtype()
11901013
w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
11911014
w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
1192-
w13_scale = layer.w13_weight_scale
1193-
w2_scale = layer.w2_weight_scale
1015+
w13_scale = torch.ones(layer.num_experts, dtype=torch.float32)
1016+
w2_scale = torch.ones(layer.num_experts, dtype=torch.float32)
1017+
layer.w13_input_scale = None
1018+
layer.w2_input_scale = None
11941019

11951020
for expert in range(layer.local_num_experts):
11961021
w13[expert, :, :], w13_scale[expert] = ops.scaled_fp8_quant(
@@ -1207,8 +1032,8 @@ def process_weights_after_loading(self, layer: Module) -> None:
12071032
w2,
12081033
w13_scale,
12091034
w2_scale,
1210-
layer.w13_input_scale,
1211-
layer.w2_input_scale,
1035+
w13_input_scale=layer.w13_input_scale,
1036+
w2_input_scale=layer.w2_input_scale,
12121037
)
12131038

12141039
# Prevent duplicate processing (e.g., during weight reload)

0 commit comments

Comments
 (0)