|
121 | 121 | # is needed, set this to ``False`` and compile manually after moving the |
122 | 122 | # pipeline to CUDA. |
123 | 123 | "few_shot_auto_compile": False, |
| 124 | + # When enabled, SVDQ fuses the first quantized linear layer, GELU activation, |
| 125 | + # and second quantized linear layer in standard diffusers ``FeedForward`` GELU |
| 126 | + # MLP blocks into a single kernel chain via ``svdq_gemm_w4a4_ext``. The |
| 127 | + # intermediate fp16 activation is never written to HBM — the first GEMM |
| 128 | + # directly produces 4-bit quantized output consumed by the second GEMM. |
| 129 | + # Requires the ``fused_gelu_mlp`` and ``fused_gelu_proj`` passes to be |
| 130 | + # active; has no effect on models that use GEGLU, SwiGLU, or custom |
| 131 | + # FeedForward structures. |
| 132 | + "fused_mlp": False, |
124 | 133 | } |
125 | 134 |
|
126 | 135 |
|
@@ -299,6 +308,7 @@ def _resolve_svdq_kwargs(svdq_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any |
299 | 308 | "few_shot_relax_top_ratio": _resolve_svdq_ratio, |
300 | 309 | "few_shot_relax_strategy": _resolve_svdq_few_shot_relax_strategy, |
301 | 310 | "few_shot_auto_compile": _resolve_svdq_bool_kwarg, |
| 311 | + "fused_mlp": _resolve_svdq_bool_kwarg, |
302 | 312 | } |
303 | 313 | for key, value in svdq_kwargs.items(): |
304 | 314 | resolved[key] = validators[key](key, value) |
@@ -589,15 +599,17 @@ def strify(self) -> str: |
589 | 599 |
|
590 | 600 | def _stringify_quant_type(quant_type: str) -> str: |
591 | 601 | quant_type = quant_type.lower() |
592 | | - if quant_type.startswith("svdq") and quant_type.endswith("_dq"): |
| 602 | + if quant_type.startswith("svdq"): |
593 | 603 | svdq_kwargs = self.get_svdq_kwargs() |
594 | | - smooth_strategy = svdq_kwargs.get("smooth_strategy", "identity") |
595 | | - if smooth_strategy != "identity": |
596 | | - quant_type = f"{quant_type}_{smooth_strategy}" |
597 | | - if smooth_strategy == "few_shot": |
598 | | - relax_strategy = svdq_kwargs.get("few_shot_relax_strategy", "auto") |
599 | | - quant_type = f"{quant_type}_{relax_strategy}" |
600 | | - return quant_type |
| 604 | + if quant_type.endswith("_dq"): |
| 605 | + smooth_strategy = svdq_kwargs.get("smooth_strategy", "identity") |
| 606 | + if smooth_strategy != "identity": |
| 607 | + quant_type = f"{quant_type}_{smooth_strategy}" |
| 608 | + if smooth_strategy == "few_shot": |
| 609 | + relax_strategy = svdq_kwargs.get("few_shot_relax_strategy", "auto") |
| 610 | + quant_type = f"{quant_type}_{relax_strategy}" |
| 611 | + if svdq_kwargs.get("fused_mlp", False): |
| 612 | + quant_type = f"{quant_type}_fused_mlp" |
601 | 613 | return quant_type |
602 | 614 |
|
603 | 615 | if self.components_to_quantize is None or isinstance(self.components_to_quantize, list): |
|
0 commit comments