Commit 325f679

[BugFix] Fix Torch.Compile For DeepSeek (vllm-project#12594)

Co-authored-by: simon-mo <[email protected]>

1 parent: e3f7ff6

1 file changed (+29, −25 lines):
vllm/model_executor/layers/quantization/fp8.py

```diff
@@ -245,20 +245,24 @@ def create_weights(
         layer.register_parameter("input_scale", None)
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        # Block quant doesn't need to process weights after loading
+        # TODO(rob): refactor block quant into separate class.
         if self.block_quant:
+            assert self.quant_config.activation_scheme == "dynamic"
             if current_platform.is_rocm():
-                weight, weight_scale, _ = \
+                weight, weight_scale_inv, _ = \
                     normalize_e4m3fn_to_e4m3fnuz(
                         weight=layer.weight,
-                        weight_scale=layer.weight_scale_inv,
-                        input_scale=layer.input_scale)
-                layer.weight = Parameter(weight, requires_grad=False)
-                layer.weight_scale_inv = Parameter(weight_scale,
-                                                   requires_grad=False)
+                        weight_scale=layer.weight_scale_inv)
+            else:
+                weight = layer.weight.data
+                weight_scale_inv = layer.weight_scale_inv.data
+
+            # Torch.compile cannot use Parameter subclasses.
+            layer.weight = Parameter(weight, requires_grad=False)
+            layer.weight_scale_inv = Parameter(weight_scale_inv,
+                                               requires_grad=False)
             return
-        layer.weight = torch.nn.Parameter(layer.weight.data,
-                                          requires_grad=False)
+
         # If checkpoint not serialized fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
```
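The heart of the fix is the added comment "Torch.compile cannot use Parameter subclasses.": vLLM's weight loader populates block-quantized fp8 layers with custom `Parameter` subclasses, and leaving those on the module breaks Dynamo tracing, so the hunk rebinds the underlying tensors as plain `torch.nn.Parameter` once loading finishes, on both the ROCm and non-ROCm paths. A minimal sketch of that rewrap pattern follows; `LoaderParameter` is a hypothetical stand-in for vLLM's loader-side subclasses, not a real vLLM class.

```python
import torch
from torch.nn import Parameter


class LoaderParameter(Parameter):
    """Hypothetical stand-in for a vLLM weight-loader Parameter subclass."""

    def __new__(cls, data: torch.Tensor):
        return super().__new__(cls, data, requires_grad=False)


layer = torch.nn.Linear(8, 8, bias=False)
# After checkpoint loading, the module still holds the subclass instance.
layer.weight = LoaderParameter(layer.weight.data)

# The fix: rebind the same storage as a plain Parameter, mirroring
# process_weights_after_loading above; .data is a view, so no copy is made.
layer.weight = Parameter(layer.weight.data, requires_grad=False)
assert type(layer.weight) is Parameter

compiled = torch.compile(layer)
print(compiled(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
```

Because `.data` returns a view of the same storage, the rewrap costs no copy; the diff relies on the same trick (`layer.weight.data`, `layer.weight_scale_inv.data`) in its `else` branch.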
```diff
@@ -507,8 +511,9 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
         layer.w2_input_scale = None
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        # Block quant doesn't need to process weights after loading
+        # TODO (rob): refactor block quant into separate class.
         if self.block_quant:
+            assert self.quant_config.activation_scheme == "dynamic"
             if current_platform.is_rocm():
                 w13_weight, w13_weight_scale_inv, w13_input_scale = \
                     normalize_e4m3fn_to_e4m3fnuz(
@@ -518,22 +523,21 @@ def process_weights_after_loading(self, layer: Module) -> None:
                     normalize_e4m3fn_to_e4m3fnuz(
                         layer.w2_weight, layer.w2_weight_scale_inv,
                         layer.w2_input_scale)
-                # Reset the parameter
-                layer.w13_weight = torch.nn.Parameter(w13_weight,
-                                                      requires_grad=False)
-                layer.w13_weight_scale_inv = torch.nn.Parameter(
-                    w13_weight_scale_inv, requires_grad=False)
-                if w13_input_scale is not None:
-                    layer.w13_input_scale = torch.nn.Parameter(
-                        w13_input_scale, requires_grad=False)
-                layer.w2_weight = torch.nn.Parameter(w2_weight,
-                                                     requires_grad=False)
-                layer.w2_weight_scale_inv = torch.nn.Parameter(
-                    w2_weight_scale_inv, requires_grad=False)
-                if w2_input_scale is not None:
-                    layer.w2_input_scale = torch.nn.Parameter(
-                        w2_input_scale, requires_grad=False)
+            else:
+                w13_weight = layer.w13_weight.data
+                w13_weight_scale_inv = layer.w13_weight_scale_inv.data
+                w2_weight = layer.w2_weight
+                w2_weight_scale_inv = layer.w2_weight_scale_inv
+
+            # torch.compile() cannot use Parameter subclasses.
+            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
+            layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv,
+                                                   requires_grad=False)
+            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
+            layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
+                                                  requires_grad=False)
             return
+
         # If checkpoint is fp16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If rocm, use float8_e4m3fnuz as dtype
```
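Both methods route their ROCm branch through `normalize_e4m3fn_to_e4m3fnuz`, whose body the diff does not show. As a hedged sketch (reconstructed from the two fp8 formats' definitions, so details may differ from vLLM's actual helper): `float8_e4m3fnuz` reuses the `0x80` bit pattern for NaN instead of -0.0 and uses exponent bias 8 instead of 7, so each bit pattern decodes to half its `float8_e4m3fn` value and the scales must be doubled to compensate.

```python
from typing import Optional

import torch


def normalize_e4m3fn_to_e4m3fnuz(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
):
    """Sketch: convert e4m3fn weights and scales to e4m3fnuz equivalents."""
    assert weight.dtype == torch.float8_e4m3fn
    as_int8 = weight.view(torch.int8)
    # 0x80 encodes -0.0 in e4m3fn but NaN in e4m3fnuz; map it to +0.0.
    as_int8[as_int8 == -128] = 0
    weight = as_int8.view(torch.float8_e4m3fnuz)
    # With exponent bias 8 instead of 7, the same bits decode to half the
    # value, so double the scales to keep the dequantized product unchanged.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale
```

This also explains why the dense-layer hunk drops `input_scale=layer.input_scale` from the call: the new `assert self.quant_config.activation_scheme == "dynamic"` guarantees there is no static input scale to convert.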
