Skip to content

Commit c8491c6

Browse files
authored
svdq: add fused gelu mlp/proj pass (#1047)
* svdq: add fused gelu mlp pass * svdq: add fused gelu mlp/proj pass * svdq: add fused gelu mlp/proj pass * svdq: add fused gelu mlp/proj pass
1 parent 56b4356 commit c8491c6

9 files changed

Lines changed: 880 additions & 24 deletions

File tree

csrc/kernels/svdq/gemm_w4a4_launch_impl.cuh

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -251,34 +251,38 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
251251
};
252252

253253
if (qout.valid() && oscales.valid()) {
254-
// dispatchBool(qout_unsigned, [&]<bool USE_UNSIGNED>() {
255-
256-
static constexpr float SHIFT_GELU = 0.171875f;
257-
258-
constexpr bool USE_UNSIGNED = !USE_FP4;
259-
using EpilogueQuantize = typename GEMM::EpilogueQuantize<false, USE_UNSIGNED, USE_FP4>;
254+
// Use signed INT4 (matching fc2.quantize output).
255+
// No unsigned shift — the non-fused path never applies +0.171875
256+
// before fc2.quantize, so we must match that behaviour exactly.
257+
constexpr bool USE_UNSIGNED = false;
258+
259+
// GELU is always applied in-place via EpilogueGelu (MidEpilogue)
260+
// so that *every* downstream epilogue — including EpilogueDefault
261+
// (fp16 `out`) and EpilogueLoraDown (`lora_act_out`) — sees
262+
// post-GELU accumulator values. EpilogueQuantize only quantizes.
263+
constexpr bool FUSE_GELU = false;
264+
using EpilogueQuantize = typename GEMM::EpilogueQuantize<FUSE_GELU, USE_UNSIGNED, USE_FP4>;
260265
auto argsQuantize = typename EpilogueQuantize::Arguments{
261266
.qout = qout.data_ptr<packed_act_t>(),
262267
.oscales = oscales.data_ptr<typename EpilogueQuantize::oscales_t>(),
263-
.shift_value = USE_FP4 ? 0.0f : SHIFT_GELU,
268+
.shift_value = 0.0f,
264269
.smooth_factor = smooth_factor.data_ptr<packed_wscale_t>()};
265270

266-
// TODO: check if gelu is needed
267271
if (out.valid()) {
268272
launch_lora.template operator()<
269273
typename GEMM::EpilogueCombination<typename GEMM::EpilogueDefault, EpilogueQuantize>,
270-
typename Epilogues::EpilogueGelu>({typename GEMM::EpilogueDefault::Arguments{
271-
.out = out.data_ptr<half_t>(),
272-
.actualM = actualM,
273-
.actualN = actualN,
274-
},
275-
argsQuantize},
276-
{});
274+
typename Epilogues::EpilogueGelu>(
275+
{typename GEMM::EpilogueDefault::Arguments{
276+
.out = out.data_ptr<half_t>(),
277+
.actualM = actualM,
278+
.actualN = actualN,
279+
},
280+
argsQuantize},
281+
{});
277282
} else {
278283
launch_lora.template operator()<EpilogueQuantize, typename Epilogues::EpilogueGelu>(
279284
argsQuantize, {});
280285
}
281-
282286
} else if (out_linearattn.valid()) {
283287
assert(out_vk.valid());
284288

docs/user_guide/QUANTIZATION.md

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -896,3 +896,66 @@ Quick examples for enabling SVDQ runtime kernel from the generate CLI:
896896
python3 -m cache_dit.generate flux --svdq-int4-r64-dq --compile # v1, baseline
897897
python3 -m cache_dit.generate flux --svdq-int4-r64-dq --svdq-runtime v2 --compile # v2
898898
```
899+
900+
## SVDQ with Fused MLP
901+
902+
When SVDQ quantizes a diffusers transformer whose ``FeedForward`` blocks use plain GELU activation, the default execution path runs three GPU kernels per MLP block: the first quantized GEMM, a separate GELU activation kernel, and the second quantized GEMM. Enabling fused MLP combines the first GEMM and GELU into a single kernel via ``svdq_gemm_w4a4_ext``, eliminating one kernel launch per block.
903+
904+
The feature is controlled by ``svdq_kwargs["fused_mlp"]`` and works automatically with most diffusers transformer families — no per-model configuration is needed.
905+
906+
**Quick start (CLI)**
907+
908+
```bash
909+
# Add --svdq-fused-mlp to any SVDQ generate command:
910+
python3 -m cache_dit.generate flux --svdq-int4-r32-dq --svdq-fused-mlp
911+
```
912+
913+
**Quick start (Python API — dynamic quantization)**
914+
915+
```python
916+
import torch
917+
import cache_dit
918+
from diffusers import FluxPipeline
919+
from cache_dit.quantization import QuantizeConfig
920+
921+
pipe = FluxPipeline.from_pretrained(
922+
"black-forest-labs/FLUX.1-dev",
923+
torch_dtype=torch.bfloat16,
924+
).to("cuda")
925+
926+
quant_config = QuantizeConfig(
927+
quant_type="svdq_int4_r32_dq",
928+
svdq_kwargs={"fused_mlp": True},
929+
)
930+
pipe.transformer = cache_dit.load(pipe.transformer, quant_config)
931+
932+
image = pipe("A cat holding a sign that says hello world").images[0]
933+
image.save("flux_fused_mlp.png")
934+
```
935+
936+
**Quick start (Python API — PTQ with serialized checkpoint)**
937+
938+
```python
939+
quant_config = QuantizeConfig(
940+
quant_type="svdq_int4_r32",
941+
serialize_to="./flux-svdq/",
942+
svdq_kwargs={"fused_mlp": True},
943+
calibrate_fn=my_calibrate_fn,
944+
)
945+
cache_dit.quantize(pipe.transformer, quant_config)
946+
947+
# Later, at inference time:
948+
pipe.transformer = cache_dit.load(
949+
pipe.transformer,
950+
"./flux-svdq/svdq_int4_r32.safetensors",
951+
)
952+
```
953+
954+
When ``fused_mlp`` is enabled, cache-dit applies two complementary passes:
955+
956+
| Pass | Targets | Fusion |
957+
|---|---|---|
958+
| ``fused_gelu_mlp`` | Standard ``FeedForward`` double blocks | fc1 + GELU + fc2 (qout path, no fp16 HBM write) |
959+
| ``fused_gelu_proj`` | Single-stream blocks with concat MLP | fc1 + GELU only (fp16 output, concat unchanged) |
960+
961+
Both passes use generic structural detection — they work with most diffusers transformers (FLUX, SD3, PixArt, HunyuanVideo, Wan, Cosmos, Bria, QwenImage, Chroma, Motif Video, and many more) without per-model code changes.

src/cache_dit/_utils/utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,13 @@ def get_args(parse: bool = True, ) -> argparse.ArgumentParser | argparse.Namespa
674674
default=False,
675675
help="Compile the transformer only after SVDQ few-shot runtime quantization completes.",
676676
)
677+
parser.add_argument(
678+
"--svdq-fused-mlp",
679+
action="store_true",
680+
default=False,
681+
help=
682+
"Fuse FeedForward GELU MLP blocks into a single fused kernel chain after SVDQ quantization.",
683+
)
677684
# Parallelism settings
678685
parser.add_argument(
679686
"--parallel-type",
@@ -1749,6 +1756,7 @@ def _resolve_cli_svdq_kwargs() -> Optional[Dict[str, Any]]:
17491756
"few_shot_relax_top_ratio": args.svdq_few_shot_relax_top_ratio,
17501757
"few_shot_relax_strategy": args.svdq_few_shot_relax_strategy,
17511758
"few_shot_auto_compile": few_shot_auto_compile,
1759+
"fused_mlp": bool(args.svdq_fused_mlp),
17521760
}
17531761

17541762
# Quantize transformer by default if quantization is enabled

src/cache_dit/logger.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def filter(self, record):
4747
_default_file_handler = None
4848
_inference_log_file_handler = {}
4949
_warning_once_messages: set[tuple[str, str]] = set()
50+
_info_once_messages: set[tuple[str, str]] = set()
5051

5152

5253
def _warning_once(self: logging.Logger, msg, *args, **kwargs) -> None:
@@ -66,7 +67,25 @@ def _warning_once(self: logging.Logger, msg, *args, **kwargs) -> None:
6667
self.warning(msg, *args, **kwargs)
6768

6869

70+
def _info_once(self: logging.Logger, msg, *args, **kwargs) -> None:
71+
message = logging.LogRecord(
72+
name=self.name,
73+
level=logging.INFO,
74+
pathname="",
75+
lineno=0,
76+
msg=msg,
77+
args=args,
78+
exc_info=None,
79+
).getMessage()
80+
key = (self.name, message)
81+
if key in _info_once_messages:
82+
return
83+
_info_once_messages.add(key)
84+
self.info(msg, *args, **kwargs)
85+
86+
6987
logging.Logger.warning_once = _warning_once # type: ignore[attr-defined]
88+
logging.Logger.info_once = _info_once # type: ignore[attr-defined]
7089

7190

7291
def _setup_logger():

src/cache_dit/quantization/config.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,15 @@
121121
# is needed, set this to ``False`` and compile manually after moving the
122122
# pipeline to CUDA.
123123
"few_shot_auto_compile": False,
124+
# When enabled, SVDQ fuses the first quantized linear layer, GELU activation,
125+
# and second quantized linear layer in standard diffusers ``FeedForward`` GELU
126+
# MLP blocks into a single kernel chain via ``svdq_gemm_w4a4_ext``. The
127+
# intermediate fp16 activation is never written to HBM — the first GEMM
128+
# directly produces 4-bit quantized output consumed by the second GEMM.
129+
# Requires the ``fused_gelu_mlp`` and ``fused_gelu_proj`` passes to be
130+
# active; has no effect on models that use GEGLU, SwiGLU, or custom
131+
# FeedForward structures.
132+
"fused_mlp": False,
124133
}
125134

126135

@@ -299,6 +308,7 @@ def _resolve_svdq_kwargs(svdq_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any
299308
"few_shot_relax_top_ratio": _resolve_svdq_ratio,
300309
"few_shot_relax_strategy": _resolve_svdq_few_shot_relax_strategy,
301310
"few_shot_auto_compile": _resolve_svdq_bool_kwarg,
311+
"fused_mlp": _resolve_svdq_bool_kwarg,
302312
}
303313
for key, value in svdq_kwargs.items():
304314
resolved[key] = validators[key](key, value)
@@ -589,15 +599,17 @@ def strify(self) -> str:
589599

590600
def _stringify_quant_type(quant_type: str) -> str:
591601
quant_type = quant_type.lower()
592-
if quant_type.startswith("svdq") and quant_type.endswith("_dq"):
602+
if quant_type.startswith("svdq"):
593603
svdq_kwargs = self.get_svdq_kwargs()
594-
smooth_strategy = svdq_kwargs.get("smooth_strategy", "identity")
595-
if smooth_strategy != "identity":
596-
quant_type = f"{quant_type}_{smooth_strategy}"
597-
if smooth_strategy == "few_shot":
598-
relax_strategy = svdq_kwargs.get("few_shot_relax_strategy", "auto")
599-
quant_type = f"{quant_type}_{relax_strategy}"
600-
return quant_type
604+
if quant_type.endswith("_dq"):
605+
smooth_strategy = svdq_kwargs.get("smooth_strategy", "identity")
606+
if smooth_strategy != "identity":
607+
quant_type = f"{quant_type}_{smooth_strategy}"
608+
if smooth_strategy == "few_shot":
609+
relax_strategy = svdq_kwargs.get("few_shot_relax_strategy", "auto")
610+
quant_type = f"{quant_type}_{relax_strategy}"
611+
if svdq_kwargs.get("fused_mlp", False):
612+
quant_type = f"{quant_type}_fused_mlp"
601613
return quant_type
602614

603615
if self.components_to_quantize is None or isinstance(self.components_to_quantize, list):

src/cache_dit/quantization/svdquant/__init__.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,32 @@
4040

4141
from ...kernels import svdq_extension_is_available as svdq_is_available
4242
from ...kernels import svdq_get_load_error
43+
from .fused import fused_gelu_mlp
44+
from .fused import fused_gelu_proj
4345
from .linear import SVDQW4A4Linear
46+
from .passes import apply_passes
47+
from .passes import BasePass
48+
from .passes import DEFAULT_FUSED_MLP_PASSES
49+
from .passes import FusedGeluMlpPass
50+
from .passes import FusedGeluProjPass
51+
from .passes import get_pass
52+
from .passes import register_pass
4453
from .quantizer import CalibrationInputs
4554
from .quantizer import compute_smooth_scale
4655
from .quantizer import quantize_linear_svdq_w4a4
4756
from .quantizer import standardize_calibration_activations
4857
from .quantizer import validate_svdq_linear_geometry
4958

5059
__all__ = [
60+
"apply_passes",
61+
"BasePass",
62+
"DEFAULT_FUSED_MLP_PASSES",
63+
"FusedGeluMlpPass",
64+
"fused_gelu_mlp",
65+
"fused_gelu_proj",
66+
"FusedGeluProjPass",
67+
"get_pass",
68+
"register_pass",
5169
"CalibrationInputs",
5270
"SVDQW4A4Linear",
5371
"compute_smooth_scale",

0 commit comments

Comments
 (0)