-
Notifications
You must be signed in to change notification settings - Fork 65
Expand file tree
/
Copy pathfp8_gemm.py
More file actions
347 lines (293 loc) · 14.5 KB
/
fp8_gemm.py
File metadata and controls
347 lines (293 loc) · 14.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tiny utility to enable vLLM-style FP8 GEMM (W8A8) for arbitrary PyTorch models.
What it does
- Replaces nn.Linear modules with a drop-in module that:
- quantizes activations dynamically per forward call
- quantizes weights lazily on first CUDA forward (and caches them)
- dispatches GEMM via vLLM's Fp8LinearOp (cutlass/flashinfer/torch._scaled_mm)
Notes
- CUDA-only fast path; CPU (and unsupported cases) automatically fall back to
the original nn.Linear.
- Output of vLLM FP8 GEMM is fp16/bf16. If your input is fp32, you can either
keep fp32 (fallback) or enable casting to fp16/bf16 for speed.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, Optional, Literal
import torch
import torch.nn as nn
@dataclass(frozen=True)
class FP8GemmOptions:
    """Immutable settings that control how nn.Linear modules are wrapped.

    Raises:
        ValueError: if ``fp16_weight_storage`` is not one of ``"keep"``,
            ``"cpu_offload"`` or ``"discard"``.
    """

    # If True, non-fp16/bf16 inputs are cast to bf16 for the FP8 GEMM path.
    # If False, non-fp16/bf16 inputs fall back to the original weights (which
    # therefore must still be stored; see fp16_weight_storage).
    cast_inputs: bool = True

    # If True, the output is cast back to the original input dtype when the
    # input had to be cast for the fast path.
    cast_output_back: bool = True

    # What to do with the original (FP16/BF16) weights after wrapping.
    #
    # - "keep": keep the original nn.Linear inside the wrapped module.
    # - "cpu_offload": keep a CPU copy of the original weights to save GPU
    #   VRAM; it is used for CPU fallback and/or re-quantization.
    # - "discard" (default): drop the original weights once FP8 weights are
    #   materialized (lowest steady-state memory). In this mode, CPU fallback
    #   is not available and weights cannot be re-quantized if the FP8 cache
    #   is invalidated.
    fp16_weight_storage: Literal["keep", "cpu_offload", "discard"] = "discard"

    # If True, try to quantize weights immediately while wrapping (only works
    # when the original nn.Linear weights are already on CUDA). This enables
    # discarding/offloading the original weights right away, instead of
    # waiting for the first forward pass.
    materialize_fp8_on_wrap: bool = True

    def __post_init__(self) -> None:
        # Fail fast on a typo instead of deferring the error until the first
        # layer is wrapped.
        if self.fp16_weight_storage not in ("keep", "cpu_offload", "discard"):
            raise ValueError(
                f"invalid fp16_weight_storage={self.fp16_weight_storage!r}; "
                "expected one of {'keep','cpu_offload','discard'}"
            )
class FP8Linear(nn.Module):
    """Drop-in replacement for nn.Linear that uses vLLM FP8 GEMM when possible.

    The weight is quantized to FP8 lazily (or on demand through
    ``materialize_fp8_weight``) and cached in non-persistent buffers;
    activations are quantized dynamically on every CUDA forward call.
    Non-CUDA inputs, and non-fp16/bf16 inputs when ``options.cast_inputs`` is
    False, are served from the original weights if
    ``options.fp16_weight_storage`` still provides them.
    """

    def __init__(self, linear: nn.Linear, *, options: FP8GemmOptions):
        """Wrap ``linear`` according to ``options``.

        Args:
            linear: the module to replace.
            options: wrapping options (storage mode, input/output casting).

        Raises:
            TypeError: if ``linear`` is not an ``nn.Linear``.
            ValueError: if ``options.fp16_weight_storage`` is unknown, or is
                ``"discard"`` while ``options.cast_inputs`` is False.
        """
        super().__init__()
        if not isinstance(linear, nn.Linear):
            raise TypeError(f"expected nn.Linear, got {type(linear)}")
        if options.fp16_weight_storage not in ("keep", "cpu_offload", "discard"):
            raise ValueError(
                f"invalid fp16_weight_storage={options.fp16_weight_storage!r}; "
                "expected one of {'keep','cpu_offload','discard'}"
            )
        if options.fp16_weight_storage == "discard" and not options.cast_inputs:
            # Without FP16 weights, we cannot fall back for non-fp16/bf16 inputs.
            raise ValueError(
                "fp16_weight_storage='discard' requires cast_inputs=True "
                "(otherwise non-fp16/bf16 inputs would need FP16 fallback)."
            )

        keep_linear = options.fp16_weight_storage == "keep"

        # Keep the original nn.Linear module only in "keep" mode.
        self.linear: Optional[nn.Linear] = linear if keep_linear else None
        self.options = options

        # Optional CPU copies (stored as bf16) for fallback and/or
        # re-quantization.
        self._fp16_weight_cpu: Optional[torch.Tensor] = None  # [N, K]
        self._fp16_bias_cpu: Optional[torch.Tensor] = None  # [N]

        # Bias for the fast path when we are not keeping the original Linear.
        # (In "keep" mode we rely on self.linear.bias.)
        self.bias: Optional[nn.Parameter] = None
        if not keep_linear:
            if linear.bias is not None:
                self.bias = nn.Parameter(linear.bias.detach().clone())
            # Stash a bf16 copy of the original weights on CPU. It is kept
            # until FP8 weights are materialized, then optionally discarded.
            self._fp16_weight_cpu = linear.weight.detach().to(
                device="cpu", dtype=torch.bfloat16
            ).contiguous()
            if linear.bias is not None:
                self._fp16_bias_cpu = linear.bias.detach().to(
                    device="cpu", dtype=torch.bfloat16
                ).contiguous()

        self._init_fp8_backend()

        # Lazy weight cache (per-device). Register these as non-persistent
        # buffers so module.to()/cpu()/cuda() also migrates the FP8 cache.
        self.register_buffer("_fp8_weight", None, persistent=False)  # [K, N] view
        self.register_buffer("_fp8_weight_scale", None, persistent=False)  # scalar or vec
        self._weight_cache_device: Optional[torch.device] = None

        # Track when weights change (best-effort) in "keep" mode.
        # Users can also call invalidate_weight_cache() explicitly after
        # weight updates.
        self._last_weight_version: Optional[int] = None

    def _init_fp8_backend(self) -> None:
        """Create the vLLM GEMM op and quantization-op handles for this module."""
        # vLLM is imported lazily so that importing this file does not require
        # it. We avoid reading vLLM global config, so we force pad_output=False
        # to keep this usable as a standalone utility.
        from vllm.model_executor.layers.quantization.utils.quant_utils import (
            GroupShape,
        )
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            Fp8LinearOp,
            maybe_create_device_identity,
        )
        # CUDA-only quant ops live here.
        from vllm import _custom_ops as ops

        maybe_create_device_identity()
        self._fp8_linear_op = Fp8LinearOp(
            act_quant_static=False,
            act_quant_group_shape=GroupShape.PER_TOKEN,
            pad_output=False,
        )
        self._ops = ops

    @classmethod
    def from_linear(cls, linear: nn.Linear, *, options: FP8GemmOptions) -> "FP8Linear":
        """Alternate constructor; equivalent to ``cls(linear, options=options)``."""
        # In "keep" mode, we keep the original Linear module instance so
        # state_dict stays natural (weights/bias remain at linear.weight /
        # linear.bias).
        return cls(linear, options=options)

    def __deepcopy__(self, memo):
        """Return an independent copy, including any valid FP8 weight cache.

        The copy is assembled field by field rather than by re-running
        ``__init__``, so it also works once the original weights have been
        discarded and it preserves the live ``bias`` parameter as-is.

        Raises:
            RuntimeError: if the module holds neither original weights nor an
                FP8 weight cache (there is nothing usable to copy).
        """
        if id(self) in memo:
            return memo[id(self)]

        import copy

        has_cache = self._fp8_weight is not None and self._fp8_weight_scale is not None
        if self.linear is None and self._fp16_weight_cpu is None and not has_cache:
            raise RuntimeError(
                "FP8Linear cannot be deep-copied: it has neither FP16 weights "
                "nor an FP8 weight cache."
            )

        cloned = type(self).__new__(type(self))
        nn.Module.__init__(cloned)
        cloned._init_fp8_backend()
        memo[id(self)] = cloned

        cloned.training = self.training
        cloned.options = self.options
        cloned.linear = copy.deepcopy(self.linear, memo)
        cloned._fp16_weight_cpu = (
            self._fp16_weight_cpu.detach().clone()
            if self._fp16_weight_cpu is not None else None
        )
        cloned._fp16_bias_cpu = (
            self._fp16_bias_cpu.detach().clone()
            if self._fp16_bias_cpu is not None else None
        )
        cloned.bias = copy.deepcopy(self.bias, memo)

        clone_version: Optional[int] = None
        if self.linear is not None:
            # Mirror the staleness rule of _maybe_requantize_weight: a cache
            # built from an older weight version must not be carried over,
            # otherwise the clone would treat it as current.
            src_version = getattr(self.linear.weight, "_version", None)
            if isinstance(src_version, int) and src_version != self._last_weight_version:
                has_cache = False
            v = getattr(cloned.linear.weight, "_version", None)
            clone_version = v if isinstance(v, int) else None

        cloned.register_buffer(
            "_fp8_weight",
            self._fp8_weight.detach().clone() if has_cache else None,
            persistent=False,
        )
        cloned.register_buffer(
            "_fp8_weight_scale",
            self._fp8_weight_scale.detach().clone() if has_cache else None,
            persistent=False,
        )
        cloned._weight_cache_device = cloned._cached_fp8_device()
        cloned._last_weight_version = clone_version if has_cache else None
        return cloned

    def invalidate_weight_cache(self) -> None:
        """Drop the cached FP8 weight so that it is rebuilt on next use."""
        self._fp8_weight = None
        self._fp8_weight_scale = None
        self._weight_cache_device = None
        self._last_weight_version = None

    def _cached_fp8_device(self) -> Optional[torch.device]:
        """Return the device holding the FP8 cache, or None if it is unusable."""
        if self._fp8_weight is None or self._fp8_weight_scale is None:
            return None
        if self._fp8_weight.device != self._fp8_weight_scale.device:
            return None
        return self._fp8_weight.device

    def materialize_fp8_weight(self, device: torch.device) -> None:
        """Force FP8 weight materialization on the given device."""
        self._maybe_requantize_weight(device)

    def _maybe_requantize_weight(self, device: torch.device) -> None:
        """(Re)build the FP8 weight cache on ``device`` unless it is current.

        Raises:
            RuntimeError: if a rebuild is needed but no original weights are
                stored any more.
        """
        # Normalize so that an index-less "cuda" compares equal to the device
        # reported by the cached tensors (which always carries an index).
        device = torch.device(device)
        if device.type == "cuda" and device.index is None:
            device = torch.device("cuda", torch.cuda.current_device())

        # Detect weight changes (best-effort) and/or device changes.
        have_cache = self._fp8_weight is not None and self._fp8_weight_scale is not None
        version: Optional[int] = None
        version_ok = True
        if self.linear is not None:
            v = getattr(self.linear.weight, "_version", None)
            version = v if isinstance(v, int) else None
            version_ok = version is None or version == self._last_weight_version
        if have_cache and self._cached_fp8_device() == device and version_ok:
            return

        # vLLM convention for CUTLASS: quantize original [N, K] weight, then
        # pass transpose *view* [K, N] into scaled GEMM kernels, which yields
        # stride(0)==1 as expected by cutlass_scaled_mm.
        if self.linear is not None:
            w_src = self.linear.weight.detach()
        elif self._fp16_weight_cpu is not None:
            w_src = self._fp16_weight_cpu
        else:
            raise RuntimeError(
                "FP8Linear has no FP16 weight source available to (re)quantize. "
                "This can happen if fp16_weight_storage='discard' and the FP8 cache was "
                "invalidated."
            )

        w_n_k = w_src.to(device=device, dtype=torch.bfloat16, non_blocking=True).contiguous()
        qweight_n_k, w_scale = self._ops.scaled_fp8_quant(w_n_k, scale=None)
        self._fp8_weight = qweight_n_k.t()
        self._fp8_weight_scale = w_scale
        self._weight_cache_device = self._cached_fp8_device()
        self._last_weight_version = version

        # If requested, discard FP16 weights once FP8 is materialized.
        if self.options.fp16_weight_storage == "discard":
            self._fp16_weight_cpu = None
            self._fp16_bias_cpu = None

    def _fp16_fallback(self, x: torch.Tensor) -> Optional[torch.Tensor]:
        """Apply the unquantized layer to ``x``.

        Returns None when no original weights are stored, so that the caller
        can raise an error specific to its situation.
        """
        if self.linear is not None:
            return self.linear(x)
        if self._fp16_weight_cpu is None:
            return None
        weight = self._fp16_weight_cpu.to(device=x.device, dtype=x.dtype)
        bias = self._fp16_bias_cpu
        if bias is not None:
            bias = bias.to(device=x.device, dtype=x.dtype)
        return torch.nn.functional.linear(x, weight, bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the linear layer, using the FP8 GEMM path for CUDA inputs.

        Raises:
            RuntimeError: if a fallback to the original weights is required
                but they are no longer stored.
        """
        # CPU / non-CUDA: fall back.
        if not x.is_cuda:
            out = self._fp16_fallback(x)
            if out is None:
                raise RuntimeError(
                    "FP8Linear cannot run on CPU because FP16 weights are not kept. "
                    "Use fp16_weight_storage='cpu_offload' (or 'keep') for CPU fallback."
                )
            return out

        # vLLM fp8 GEMM only supports fp16/bf16 outputs.
        in_dtype = x.dtype
        if in_dtype in (torch.float16, torch.bfloat16):
            x_fp = x
            out_dtype = in_dtype
        elif self.options.cast_inputs:
            x_fp = x.to(torch.bfloat16)
            out_dtype = torch.bfloat16
        else:
            # Fall back if we still have FP16 weights.
            out = self._fp16_fallback(x)
            if out is None:
                raise RuntimeError(
                    "cast_inputs=False requires FP16 weights for fallback, but they were discarded."
                )
            return out

        self._maybe_requantize_weight(x_fp.device)

        bias = self.linear.bias if self.linear is not None else self.bias
        if bias is not None:
            if bias.device != x_fp.device:
                bias = bias.to(device=x_fp.device, non_blocking=True)
            if bias.dtype != out_dtype:
                bias = bias.to(dtype=out_dtype)

        y = self._fp8_linear_op.apply(
            input=x_fp,
            weight=self._fp8_weight,  # type: ignore[arg-type]
            weight_scale=self._fp8_weight_scale,  # type: ignore[arg-type]
            out_dtype=out_dtype,
            input_scale=None,  # dynamic activation scaling
            bias=bias,
        )
        if self.options.cast_inputs and self.options.cast_output_back and y.dtype != in_dtype:
            return y.to(in_dtype)
        return y
def enable_fp8_gemm(
    model: nn.Module,
    *,
    options: FP8GemmOptions = FP8GemmOptions(),
    module_filter: Optional[Callable[[str, nn.Module], bool]] = None,
    inplace: bool = True,
) -> nn.Module:
    """
    Replace nn.Linear modules in a model with FP8Linear to accelerate GEMMs.

    Only descendants of ``model`` are replaced; ``model`` itself is returned
    unchanged even if it is an nn.Linear. Modules that are already FP8Linear
    are left alone, so calling this function again on the same model does not
    wrap anything twice. An nn.Linear instance reachable through several
    parents is wrapped once and the same wrapper is installed in each place.

    Args:
        model: Any torch.nn.Module.
        options: FP8GemmOptions controlling casting / fallback behavior.
        module_filter: Optional predicate (name, module) -> bool to decide
            whether to wrap a given module. If None, wraps all nn.Linear.
        inplace: If True, modifies model in-place and returns it. If False,
            a deep copy of the model is modified and returned instead.

    Returns:
        The modified model (same object if inplace=True).
    """
    if not inplace:
        import copy

        model = copy.deepcopy(model)

    import weakref

    # Original Linear -> its wrapper. Keys are held weakly so that the memo
    # does not by itself keep discarded FP16 weights alive.
    wrapped = weakref.WeakKeyDictionary()

    def should_wrap(name: str, m: nn.Module) -> bool:
        if not isinstance(m, nn.Linear):
            return False
        if module_filter is None:
            return True
        return bool(module_filter(name, m))

    def _recurse(prefix: str, parent: nn.Module) -> None:
        for child_name, child in list(parent.named_children()):
            if isinstance(child, FP8Linear):
                # Already wrapped. Descending further would find the inner
                # nn.Linear kept in "keep" mode and wrap it a second time.
                continue
            full_name = f"{prefix}.{child_name}" if prefix else child_name
            if not should_wrap(full_name, child):
                _recurse(full_name, child)
                continue
            fp8_mod = wrapped.get(child)
            if fp8_mod is None:
                fp8_mod = FP8Linear.from_linear(child, options=options)
                # Optionally materialize immediately while the original weight
                # is already on CUDA, so we can discard/offload FP16 weights
                # right away.
                if options.materialize_fp8_on_wrap and child.weight.is_cuda:
                    fp8_mod.materialize_fp8_weight(child.weight.device)
                wrapped[child] = fp8_mod
            setattr(parent, child_name, fp8_mod)

    _recurse("", model)
    return model