Skip to content

Commit fffdd56

Browse files
author
Zhewen Li
committed
Revert "[QeRL] Compose online quantization with quantized reloading (vllm-project#38032)"
This reverts commit 648edcf.
1 parent 5b8c30d commit fffdd56

File tree

10 files changed

+254
-180
lines changed

10 files changed

+254
-180
lines changed

tests/model_executor/model_loader/test_reload.py

Lines changed: 0 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -148,60 +148,3 @@ def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner):
148148
mul_perp = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0]
149149
add_perp = llm.generate_prompt_perplexity(["3 4 = 7"], mask=["3 4 ="])[0]
150150
assert add_perp < mul_perp
151-
152-
153-
@pytest.mark.parametrize("tp_size", [2])
154-
@pytest.mark.parametrize(
155-
"base_model,mul_model,add_model,quantization",
156-
[
157-
(
158-
"Qwen/Qwen3-0.6B",
159-
"inference-optimization/Qwen3-0.6B-debug-multiply",
160-
"inference-optimization/Qwen3-0.6B-debug-add",
161-
"fp8",
162-
),
163-
(
164-
"inference-optimization/DeepSeek-V3-debug-empty",
165-
"inference-optimization/DeepSeek-V3-debug-multiply",
166-
"inference-optimization/DeepSeek-V3-debug-add",
167-
"fp8",
168-
),
169-
(
170-
"Qwen/Qwen3-0.6B",
171-
"inference-optimization/Qwen3-0.6B-debug-multiply",
172-
"inference-optimization/Qwen3-0.6B-debug-add",
173-
"mxfp8",
174-
),
175-
# ( TODO: support mxfp4 & mla
176-
# "inference-optimization/DeepSeek-V3-debug-empty",
177-
# "inference-optimization/DeepSeek-V3-debug-multiply",
178-
# "inference-optimization/DeepSeek-V3-debug-add",
179-
# "mxfp8",
180-
# ),
181-
],
182-
)
183-
def test_online_quantize_reload(
184-
base_model, mul_model, add_model, quantization, tp_size, vllm_runner
185-
):
186-
if cuda_device_count_stateless() < tp_size:
187-
pytest.skip(reason="Not enough CUDA devices")
188-
189-
if quantization == "fp8" and not current_platform.supports_fp8():
190-
pytest.skip(reason="Requires FP8 support")
191-
192-
with vllm_runner(
193-
model_name=base_model,
194-
quantization=quantization,
195-
tensor_parallel_size=tp_size,
196-
enable_expert_parallel=(tp_size > 1 and "DeepSeek" in base_model),
197-
enable_prefix_caching=False,
198-
) as llm:
199-
llm.collective_rpc("reload_weights", kwargs={"weights_path": mul_model})
200-
mul_perp = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0]
201-
add_perp = llm.generate_prompt_perplexity(["3 4 = 7"], mask=["3 4 ="])[0]
202-
assert mul_perp < add_perp
203-
204-
llm.collective_rpc("reload_weights", kwargs={"weights_path": add_model})
205-
mul_perp = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0]
206-
add_perp = llm.generate_prompt_perplexity(["3 4 = 7"], mask=["3 4 ="])[0]
207-
assert add_perp < mul_perp

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 202 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,7 @@
7373
cutlass_fp8_supported,
7474
normalize_e4m3fn_to_e4m3fnuz,
7575
)
76-
from vllm.model_executor.model_loader.reload.layerwise import (
77-
initialize_online_processing,
78-
)
76+
from vllm.model_executor.model_loader.weight_utils import initialize_single_dummy_weight
7977
from vllm.model_executor.parameter import (
8078
BlockQuantScaleParameter,
8179
ModelWeightParameter,
@@ -498,8 +496,8 @@ def apply(
498496

499497

500498
class Fp8OnlineLinearMethod(Fp8LinearMethod):
501-
"""Online version of Fp8LinearMethod which loads a full precision checkpoint
502-
and quantizes weights during loading."""
499+
"""Online version of Fp8LinearMethod, loads the fp16/bf16 checkpoint
500+
and quantizes the weights during loading."""
503501

504502
uses_meta_device: bool = True
505503

@@ -521,25 +519,84 @@ def create_weights(
521519
layer.orig_dtype = params_dtype
522520
layer.weight_block_size = None
523521

522+
# WEIGHT
523+
def patched_weight_loader(param, loaded_weight, *args, **kwargs):
524+
# track how many elements we have updated
525+
if not hasattr(layer, "_loaded_numel"):
526+
layer._loaded_numel = 0
527+
528+
# when the first `loaded_weight` is about to be
529+
# loaded to `param`, materialize `param` just-in-time
530+
weight = ModelWeightParameter(
531+
data=torch.empty_like(layer.weight, device=layer._load_device),
532+
input_dim=1,
533+
output_dim=0,
534+
weight_loader=patched_weight_loader,
535+
)
536+
_copy_missing_attrs(layer.weight, weight)
537+
layer.register_parameter("weight", weight)
538+
del layer._load_device
539+
540+
# refresh the reference to `param` to reflect just-in-time
541+
# materialization
542+
param = layer.weight
543+
544+
# load the current weight chunk
545+
copy_numel_counter = CopyNumelCounter()
546+
with copy_numel_counter:
547+
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
548+
layer._loaded_numel += copy_numel_counter.copied_numel
549+
550+
# if we have loaded all of the elements, call
551+
# process_weights_after_loading
552+
target_loaded_numel = layer.weight.numel()
553+
if layer._loaded_numel == target_loaded_numel:
554+
self.process_weights_after_loading(layer)
555+
556+
# Prevent the usual `process_weights_after_loading` call from doing
557+
# anything
558+
layer._already_called_process_weights_after_loading = True
559+
560+
# Note that we keep `layer._loaded_numel` around just in case
561+
# there is logic added to vllm in the future which calls a
562+
# weight loader twice - we do not want to re-initialize in
563+
# that case.
564+
565+
return res
566+
524567
weight = ModelWeightParameter(
525568
data=torch.empty(
526569
output_size_per_partition,
527570
input_size_per_partition,
528-
device="meta", # materialized and processed during loading
571+
# materialized just-in-time in `patched_weight_loader`
572+
device="meta",
529573
dtype=params_dtype,
530574
),
531575
input_dim=1,
532576
output_dim=0,
533-
weight_loader=weight_loader,
577+
weight_loader=patched_weight_loader,
534578
)
579+
# stash the correct device for `patched_weight_loader`
580+
layer._load_device = torch.get_default_device()
535581
layer.register_parameter("weight", weight)
536582

537-
initialize_online_processing(layer)
538-
539583
def process_weights_after_loading(self, layer: Module) -> None:
540584
if getattr(layer, "_already_called_process_weights_after_loading", False):
541585
return
542586

587+
# deferred initialization of randomly initialized weights for the
588+
# `--load_format dummy` feature
589+
if layer.weight.device == torch.device("meta"):
590+
weight = ModelWeightParameter(
591+
data=torch.empty_like(layer.weight, device=layer._load_device),
592+
input_dim=1,
593+
output_dim=0,
594+
weight_loader=layer.weight.weight_loader,
595+
)
596+
_copy_missing_attrs(layer.weight, weight)
597+
layer.register_parameter("weight", weight)
598+
initialize_single_dummy_weight(layer.weight)
599+
543600
# TODO(future): support block_quant in online quant path
544601
assert not self.block_quant
545602

@@ -788,6 +845,9 @@ def _setup_kernel(
788845
)
789846

790847
def process_weights_after_loading(self, layer: Module) -> None:
848+
if getattr(layer, "_already_called_process_weights_after_loading", False):
849+
return
850+
791851
# Allow for accessing weights and scales in standard way.
792852
w13 = layer.w13_weight
793853
w2 = layer.w2_weight
@@ -832,6 +892,9 @@ def process_weights_after_loading(self, layer: Module) -> None:
832892
layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
833893
)
834894

895+
# Prevent duplicate processing (e.g., during weight reload)
896+
layer._already_called_process_weights_after_loading = True
897+
835898
def maybe_make_prepare_finalize(
836899
self,
837900
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
@@ -950,12 +1013,86 @@ def create_weights(
9501013
layer.orig_dtype = params_dtype
9511014
layer.weight_block_size = None
9521015

1016+
# We are doing online quantization, patch the weight loader
1017+
# to call `process_weights_after_loading` in a streaming fashion
1018+
# as soon as the last weight chunk is loaded.
1019+
weight_loader = extra_weight_attrs["weight_loader"]
1020+
# create a new holder to prevent modifying behavior of any other
1021+
# objects which might depend on the old one
1022+
new_extra_weight_attrs = extra_weight_attrs
1023+
1024+
def patched_weight_loader(param, loaded_weight, *args, **kwargs):
1025+
# add a counter to track how many elements we have updated
1026+
if not hasattr(layer, "_loaded_numel"):
1027+
layer._loaded_numel = 0
1028+
1029+
# save the ids of original w13 and w2 so that we can
1030+
# distinguish which one `param` should map to further
1031+
# down in this file
1032+
layer._w13_weight_orig_id = id(layer.w13_weight)
1033+
layer._w2_weight_orig_id = id(layer.w2_weight)
1034+
1035+
# when the first `loaded_weight` is about to be
1036+
# loaded to `param`, materialize `param` just-in-time
1037+
1038+
w13_weight = torch.nn.Parameter(
1039+
torch.empty_like(layer.w13_weight, device=layer._load_device),
1040+
requires_grad=False,
1041+
)
1042+
set_weight_attrs(w13_weight, extra_weight_attrs)
1043+
_copy_missing_attrs(layer.w13_weight, w13_weight)
1044+
layer.register_parameter("w13_weight", w13_weight)
1045+
1046+
w2_weight = torch.nn.Parameter(
1047+
torch.empty_like(layer.w2_weight, device=layer._load_device),
1048+
requires_grad=False,
1049+
)
1050+
set_weight_attrs(w2_weight, extra_weight_attrs)
1051+
_copy_missing_attrs(layer.w2_weight, w2_weight)
1052+
layer.register_parameter("w2_weight", w2_weight)
1053+
del layer._load_device
1054+
1055+
# refresh the reference to `param` to reflect just-in-time
1056+
# materialization
1057+
if id(param) == layer._w13_weight_orig_id:
1058+
param = layer.w13_weight
1059+
elif id(param) == layer._w2_weight_orig_id:
1060+
param = layer.w2_weight
1061+
1062+
# load the current weight chunk
1063+
copy_numel_counter = CopyNumelCounter()
1064+
with copy_numel_counter:
1065+
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
1066+
layer._loaded_numel += copy_numel_counter.copied_numel
1067+
1068+
# if we have loaded all of the elements, call
1069+
# process_weights_after_loading
1070+
target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
1071+
if layer._loaded_numel == target_loaded_numel:
1072+
self.process_weights_after_loading(layer)
1073+
1074+
# Prevent the usual `process_weights_after_loading` call
1075+
# from doing anything
1076+
layer._already_called_process_weights_after_loading = True
1077+
1078+
# Note that we keep `layer._loaded_numel`,
1079+
# `layer._w13_weight_orig_id` and `layer._w2_weight_orig_id`
1080+
# around because if EP is on, weight loaders for non-local
1081+
# experts will run but not actually copy any elements, and we
1082+
# need to not re-initialize in that case.
1083+
1084+
return res
1085+
1086+
new_extra_weight_attrs["weight_loader"] = patched_weight_loader
1087+
extra_weight_attrs = new_extra_weight_attrs
1088+
9531089
# WEIGHTS
9541090
w13_weight = torch.nn.Parameter(
9551091
torch.empty(
9561092
num_experts,
9571093
2 * intermediate_size_per_partition,
9581094
hidden_size,
1095+
# materialized just-in-time in `patched_weight_loader`
9591096
device="meta",
9601097
dtype=params_dtype,
9611098
),
@@ -969,53 +1106,91 @@ def create_weights(
9691106
num_experts,
9701107
hidden_size,
9711108
intermediate_size_per_partition,
972-
device="meta", # materialized and processed during loading
1109+
# materialized just-in-time in `patched_weight_loader`
1110+
device="meta",
9731111
dtype=params_dtype,
9741112
),
9751113
requires_grad=False,
9761114
)
9771115
layer.register_parameter("w2_weight", w2_weight)
9781116
set_weight_attrs(w2_weight, extra_weight_attrs)
1117+
# stash the correct device for `patched_weight_loader`
1118+
layer._load_device = torch.get_default_device()
9791119

9801120
# BIASES (for models like GPT-OSS that have biased MoE)
9811121
if self.moe.has_bias:
1122+
# Use the original weight_loader (not patched) for biases
1123+
orig_extra_weight_attrs = dict(extra_weight_attrs)
1124+
orig_extra_weight_attrs["weight_loader"] = weight_loader
9821125
w13_bias = torch.nn.Parameter(
9831126
torch.zeros(
9841127
num_experts,
9851128
2 * intermediate_size_per_partition,
986-
device="meta", # materialized and processed during loading
9871129
dtype=layer.orig_dtype,
9881130
),
9891131
requires_grad=False,
9901132
)
9911133
layer.register_parameter("w13_bias", w13_bias)
992-
set_weight_attrs(w13_bias, extra_weight_attrs)
993-
1134+
set_weight_attrs(w13_bias, orig_extra_weight_attrs)
9941135
w2_bias = torch.nn.Parameter(
995-
torch.zeros(
996-
num_experts,
997-
hidden_size,
998-
device="meta", # materialized and processed during loading
999-
dtype=layer.orig_dtype,
1000-
),
1136+
torch.zeros(num_experts, hidden_size, dtype=layer.orig_dtype),
10011137
requires_grad=False,
10021138
)
10031139
layer.register_parameter("w2_bias", w2_bias)
1004-
set_weight_attrs(w2_bias, extra_weight_attrs)
1140+
set_weight_attrs(w2_bias, orig_extra_weight_attrs)
1141+
1142+
# WEIGHT_SCALES
1143+
# Allocate 2 scales for w1 and w3 respectively.
1144+
# They will be combined to a single scale after weight loading.
1145+
w13_weight_scale = torch.nn.Parameter(
1146+
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
1147+
)
1148+
w2_weight_scale = torch.nn.Parameter(
1149+
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
1150+
)
1151+
layer.register_parameter("w13_weight_scale", w13_weight_scale)
1152+
layer.register_parameter("w2_weight_scale", w2_weight_scale)
1153+
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
1154+
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
10051155

1006-
initialize_online_processing(layer)
1156+
layer.w13_input_scale = None
1157+
layer.w2_input_scale = None
10071158

10081159
def process_weights_after_loading(self, layer: Module) -> None:
10091160
if getattr(layer, "_already_called_process_weights_after_loading", False):
10101161
return
10111162

1163+
# deferred initialization of randomly initialized weights for the
1164+
# `--load_format dummy` feature
1165+
if layer.w13_weight.device == torch.device("meta"):
1166+
w13_weight = torch.nn.Parameter(
1167+
torch.empty_like(layer.w13_weight, device=layer._load_device),
1168+
requires_grad=False,
1169+
)
1170+
set_weight_attrs(
1171+
w13_weight, {"weight_loader": layer.w13_weight.weight_loader}
1172+
)
1173+
_copy_missing_attrs(layer.w13_weight, w13_weight)
1174+
layer.register_parameter("w13_weight", w13_weight)
1175+
initialize_single_dummy_weight(layer.w13_weight)
1176+
if layer.w2_weight.device == torch.device("meta"):
1177+
w2_weight = torch.nn.Parameter(
1178+
torch.empty_like(layer.w2_weight, device=layer._load_device),
1179+
requires_grad=False,
1180+
)
1181+
set_weight_attrs(
1182+
w2_weight, {"weight_loader": layer.w2_weight.weight_loader}
1183+
)
1184+
_copy_missing_attrs(layer.w2_weight, w2_weight)
1185+
layer.register_parameter("w2_weight", w2_weight)
1186+
initialize_single_dummy_weight(layer.w2_weight)
1187+
1188+
# If checkpoint is fp16, quantize in place.
10121189
fp8_dtype = current_platform.fp8_dtype()
10131190
w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
10141191
w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
1015-
w13_scale = torch.ones(layer.num_experts, dtype=torch.float32)
1016-
w2_scale = torch.ones(layer.num_experts, dtype=torch.float32)
1017-
layer.w13_input_scale = None
1018-
layer.w2_input_scale = None
1192+
w13_scale = layer.w13_weight_scale
1193+
w2_scale = layer.w2_weight_scale
10191194

10201195
for expert in range(layer.local_num_experts):
10211196
w13[expert, :, :], w13_scale[expert] = ops.scaled_fp8_quant(
@@ -1032,8 +1207,8 @@ def process_weights_after_loading(self, layer: Module) -> None:
10321207
w2,
10331208
w13_scale,
10341209
w2_scale,
1035-
w13_input_scale=layer.w13_input_scale,
1036-
w2_input_scale=layer.w2_input_scale,
1210+
layer.w13_input_scale,
1211+
layer.w2_input_scale,
10371212
)
10381213

10391214
# Prevent duplicate processing (e.g., during weight reload)

0 commit comments

Comments
 (0)