From cb707077b0df8bb02adadc7eda63c97a8c6d6c76 Mon Sep 17 00:00:00 2001
From: weixiao-huang <hwx.simle@gmail.com>
Date: Sun, 7 Sep 2025 14:56:43 +0800
Subject: [PATCH] [BugFix] use _wrap_parameter_or_copy instead of using
 Parameter and add missing scale attributes

Signed-off-by: weixiao-huang <hwx.simle@gmail.com>
---
 .../model_executor/layers/quantization/fp8.py | 40 ++++++++++++++-----
 .../layers/quantization/kv_cache.py           |  7 ++++
 2 files changed, 36 insertions(+), 11 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 65e0b7062..46f1885cc 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -65,6 +65,25 @@ def _is_col_major(x: torch.Tensor) -> bool:
     return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m
 
 
+def _wrap_parameter_or_copy(layer: torch.nn.Module, name: str,
+                            weight: torch.Tensor):
+    layer_weight = getattr(layer, name)
+    if isinstance(layer_weight, Parameter):
+        # If it is already a Parameter, assume it has the right shape and
+        # copy in place so the data pointer stays unchanged for CUDA Graph.
+        layer_weight.copy_(weight)
+    else:
+        # torch.compile() cannot use Parameter subclasses, so rebind the
+        # attribute to a plain Parameter; after this first wrap the weights
+        # are plain Parameters and stay compatible with torch.compile().
+        param = Parameter(weight, requires_grad=False)
+        if hasattr(layer_weight, "weight_loader"):
+            # Keep the weight_loader attribute so the weight can still be
+            # loaded correctly on subsequent weight updates.
+            param.weight_loader = layer_weight.weight_loader
+        setattr(layer, name, param)
+
+
 class Fp8Config(QuantizationConfig):
     """Config class for FP8."""
 
@@ -387,10 +406,9 @@ class Fp8LinearMethod(LinearMethodBase):
             weight = self._maybe_pad_weight(weight)
 
-            # Torch.compile cannot use Parameter subclasses.
-            layer.weight = Parameter(weight, requires_grad=False)
-            layer.weight_scale_inv = Parameter(weight_scale_inv,
-                                               requires_grad=False)
+            _wrap_parameter_or_copy(layer, "weight", weight)
+            _wrap_parameter_or_copy(layer, "weight_scale_inv",
+                                    weight_scale_inv)
 
         # If checkpoint not serialized fp8, quantize the weights.
         elif not self.quant_config.is_checkpoint_fp8_serialized:
@@ -740,13 +758,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             w2_weight = layer.w2_weight
             w2_weight_scale_inv = layer.w2_weight_scale_inv
 
-            # torch.compile() cannot use Parameter subclasses.
-            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
-            layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv,
-                                                   requires_grad=False)
-            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
-            layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
-                                                  requires_grad=False)
+            _wrap_parameter_or_copy(layer, "w13_weight", w13_weight)
+            _wrap_parameter_or_copy(layer, "w13_weight_scale_inv",
+                                    w13_weight_scale_inv)
+            _wrap_parameter_or_copy(layer, "w2_weight", w2_weight)
+            _wrap_parameter_or_copy(layer, "w2_weight_scale_inv",
+                                    w2_weight_scale_inv)
+
             if self.rocm_aiter_moe_enabled:
                 # reshaping weights is required for aiter moe kernel.
                 shuffled_w13, shuffled_w2 = shuffle_weights(
diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py
index e5604670f..0483e9b32 100644
--- a/vllm/model_executor/layers/quantization/kv_cache.py
+++ b/vllm/model_executor/layers/quantization/kv_cache.py
@@ -48,6 +48,13 @@ class BaseKVCacheMethod(QuantizeMethodBase):
             f"{self.__class__.__name__}.apply should not be called.")
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # A weight update may lack these attributes; create them if missing.
+        if not hasattr(layer, "q_scale"):
+            assert not hasattr(layer, "k_scale")
+            assert not hasattr(layer, "v_scale")
+            assert not hasattr(layer, "prob_scale")
+            self.create_weights(layer)
+
         # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
         # regardless whether the kv-scale is available in the checkpoint.
         # No need to process kv scales after loading if we are going to
--
2.39.3 (Apple Git-146)
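
As a companion note (placed after the trailer, where git am ignores trailing
text): below is a minimal standalone sketch of the wrap-or-copy pattern the
patch introduces; it is not part of the patch and is independent of vLLM.
The helper name mirrors the patch, the nn.Linear harness is hypothetical,
and it copies via .data because a default Linear weight requires grad.

import torch
from torch.nn import Parameter


def wrap_parameter_or_copy(layer: torch.nn.Module, name: str,
                           weight: torch.Tensor) -> None:
    # First update: rebind the attribute to a plain Parameter (Parameter
    # subclasses are not usable under torch.compile). Later updates: copy
    # in place so any data pointer captured by CUDA Graphs stays valid.
    current = getattr(layer, name)
    if isinstance(current, Parameter):
        current.data.copy_(weight)  # in place: pointer unchanged
    else:
        param = Parameter(weight, requires_grad=False)
        if hasattr(current, "weight_loader"):
            # Preserve the loader hook so later weight updates still work.
            param.weight_loader = current.weight_loader
        setattr(layer, name, param)


layer = torch.nn.Linear(4, 4, bias=False)
ptr_before = layer.weight.data_ptr()
wrap_parameter_or_copy(layer, "weight", torch.randn(4, 4))
assert layer.weight.data_ptr() == ptr_before  # storage was reused in place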