[Feature] Framework-level torch.compile for S2-Pro codebook decoder Phase 1#239
yxs wants to merge 1 commit into sgl-project:main
Conversation
…Phase 1) Implements Phase 1 of sgl-project#172: partial compile on the codebook decoder, +33% tok/s on S2-Pro TTS with fullgraph=True and zero graph breaks.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Pull request overview
Adds Phase 1 framework-level torch.compile support for S2-Pro by extracting compile-friendly decode logic and providing a reusable compile-target wiring utility in the omni engine layer.
Changes:
- Extracts S2-Pro codebook decoding into a standalone function and routes decode through a swappable function pointer to enable compilation (see the sketch after this list).
- Introduces `sglang_omni.engines.omni.compile.apply_compile_targets()` to compile model-registered targets with fullgraph + fallback mode.
- Plumbs a `compile_level` option through the S2-Pro pipeline/factory and adds an example YAML enabling partial compile.
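For context, the "swappable function pointer" pattern usually looks like the following. This is a minimal illustrative sketch: `_decode_codebooks_impl` is the extracted function named in this PR, but the class, setter name, and decode body are assumptions, not the PR's actual code.

```python
import torch

def _decode_codebooks_impl(logits: torch.Tensor) -> torch.Tensor:
    # Stand-in for the real constrained argmax + codebook generation:
    # pure tensor ops with no data-dependent Python control flow, so
    # torch.compile(fullgraph=True) can trace it without graph breaks.
    return logits.argmax(dim=-1)

class OutputProcessor:
    """Illustrative holder for the swappable decode function."""

    def __init__(self) -> None:
        self._decode_fn = _decode_codebooks_impl  # start in eager mode

    def set_decode_fn(self, fn) -> None:
        # Hypothetical setter: the framework wires the compiled
        # function in here instead of compiling inside the model.
        self._decode_fn = fn

    def decode(self, logits: torch.Tensor) -> torch.Tensor:
        return self._decode_fn(logits)

proc = OutputProcessor()
# Framework-side wiring: swap the eager pointer for the compiled one.
proc.set_decode_fn(torch.compile(_decode_codebooks_impl, fullgraph=True))
out = proc.decode(torch.randn(2, 1024))  # first call triggers compilation
```

Keeping the compiled artifact behind a plain attribute means the model code never imports compile machinery, and falling back to eager is just leaving the original pointer in place.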
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| `sglang_omni/models/fishaudio_s2_pro/sglang_model.py` | Extracts `_decode_codebooks_impl` and adds compile target getter/setter + function indirection. |
| `sglang_omni/models/fishaudio_s2_pro/runtime/s2pro_sglang_ar.py` | Removes in-class compile flag and exposes compile targets via getter/setter on output processor. |
| `sglang_omni/models/fishaudio_s2_pro/pipeline/stages.py` | Adds `compile_level` arg and forwards it into engine creation. |
| `sglang_omni/models/fishaudio_s2_pro/factory.py` | Applies framework-level compile targets when `compile_level == "partial"`. |
| `sglang_omni/engines/omni/compile.py` | New framework module that discovers/compiles registered targets with fallback behavior. |
| `examples/configs/s2pro_tts_compile.yaml` | New example pipeline config enabling partial compile for the TTS engine stage. |
| """Constrained semantic argmax + codebook generation. | ||
|
|
||
| Standalone function so torch.compile can trace it with fullgraph=True. | ||
| Returns (output_codes [bs, num_codebooks+1], semantic_ids [bs]). |
_decode_codebooks_impl returns the selected semantic token id (vocab id), but the docstring and the caller name it semantic_ids. This is easy to confuse with sem_id (semantic_token - semantic_begin_id). Consider renaming the returned value / local variable to semantic_token (or semantic_token_ids) and updating the docstring accordingly to prevent accidental misuse later.
```diff
- Returns (output_codes [bs, num_codebooks+1], semantic_ids [bs]).
+ Returns (output_codes [bs, num_codebooks+1], semantic_token [bs]).
```
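To make the distinction the reviewer is drawing concrete, an illustrative snippet (names from the comment above; all numeric values made up):

```python
# Illustrative only — values invented for the example.
semantic_begin_id = 32_000   # first semantic token id in the vocabulary
semantic_token = 32_417      # what _decode_codebooks_impl returns: a raw vocab id
sem_id = semantic_token - semantic_begin_id  # 417: index within the semantic range
```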
```python
    stream_stride: int = 5,
    stream_followup_stride: int = 100,
    stream_vocoder_device: str | None = None,
    compile_level: str = "none",
) -> EngineExecutor:
```
compile_level is user-facing config (YAML) but is currently not validated here. Consider validating allowed values ("none"/"partial" for now) and raising a clear ValueError (or warning) on invalid values so typos don’t silently disable compile.
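A minimal sketch of the validation requested here (and in a similar comment below); the `_ALLOWED_COMPILE_LEVELS` constant and helper name are hypothetical:

```python
_ALLOWED_COMPILE_LEVELS = {"none", "partial"}  # Phase 1 supports only these

def _validate_compile_level(compile_level: str) -> str:
    # Fail loudly on typos like "partail" instead of silently
    # falling back to eager mode.
    if compile_level not in _ALLOWED_COMPILE_LEVELS:
        raise ValueError(
            f"Unknown compile_level {compile_level!r}; "
            f"expected one of {sorted(_ALLOWED_COMPILE_LEVELS)}"
        )
    return compile_level
```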
```diff
  server_args.attention_backend = "fa3"

- # Enable hidden state capture for unified decode
+ use_partial_compile = compile_level == "partial"
+
+ # Enable hidden state capture for unified decode or partial compile
```
compile_level is treated as a free-form string (only "partial" changes behavior; everything else silently behaves like "none"). To avoid misconfiguration being silently ignored, validate compile_level against the supported values (e.g., {"none","partial"} for phase 1) and raise a ValueError (or at least warn) on unknown values.
```python
import torch._dynamo.config
import torch._inductor.config

torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True
torch._dynamo.config.accumulated_cache_size_limit = 1024
```
_set_inductor_config() sets several private torch._dynamo/_inductor config attributes. If any of these fields are missing or renamed in the current PyTorch build, enabling compile will raise during engine creation instead of gracefully continuing in eager mode. Consider wrapping these assignments in try/except (or hasattr checks) and logging with exc_info=True.
```diff
- import torch._dynamo.config
- import torch._inductor.config
-
- torch._inductor.config.coordinate_descent_tuning = True
- torch._inductor.config.triton.unique_kernel_names = True
- torch._inductor.config.fx_graph_cache = True
- torch._dynamo.config.accumulated_cache_size_limit = 1024
+ try:
+     import torch._dynamo.config
+     import torch._inductor.config
+
+     torch._inductor.config.coordinate_descent_tuning = True
+     torch._inductor.config.triton.unique_kernel_names = True
+     torch._inductor.config.fx_graph_cache = True
+     torch._dynamo.config.accumulated_cache_size_limit = 1024
+ except Exception:
+     logger.warning(
+         "Failed to set torch._dynamo/torch._inductor compile configuration; "
+         "continuing without custom inductor settings.",
+         exc_info=True,
+     )
```
```python
def apply_compile_targets(
    *models: Any,
    compile_mode: str = DEFAULT_COMPILE_MODE,
) -> list[str]:
    """Compile all registered targets on one or more models.
```
This new framework API (apply_compile_targets, the setter naming convention, and the fallback behavior) isn’t covered by unit tests. Consider adding a small test that stubs torch.compile and asserts that targets are discovered and wired via the set_compiled_<name>_fn setters, and that a compile failure cleanly falls back to eager.
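A rough sketch of such a test under pytest, stubbing torch.compile via monkeypatch. `FakeModel` and the `get_compile_targets` / `set_compiled_decode_fn` convention are assumptions about the PR's discovery API, not its confirmed shape, and the asserted return value of `apply_compile_targets` is likewise assumed:

```python
import pytest
import torch

class FakeModel:
    """Stand-in model following the assumed compile-target convention."""

    def __init__(self) -> None:
        self.decode_fn = lambda x: x  # eager implementation

    def get_compile_targets(self):
        # Hypothetical discovery hook: target name -> function to compile.
        return {"decode": self.decode_fn}

    def set_compiled_decode_fn(self, fn) -> None:
        self.decode_fn = fn

def test_targets_are_compiled_and_wired(monkeypatch):
    from sglang_omni.engines.omni import compile as omni_compile

    compiled = []
    # Stub torch.compile so the test stays fast and deterministic.
    monkeypatch.setattr(
        torch, "compile", lambda fn, **kw: compiled.append(fn) or fn
    )
    model = FakeModel()
    names = omni_compile.apply_compile_targets(model)
    assert names == ["decode"]
    assert len(compiled) == 1

def test_compile_failure_falls_back_to_eager(monkeypatch):
    from sglang_omni.engines.omni import compile as omni_compile

    def boom(fn, **kw):
        raise RuntimeError("no backend")

    monkeypatch.setattr(torch, "compile", boom)
    model = FakeModel()
    eager = model.decode_fn
    omni_compile.apply_compile_targets(model)  # should not raise
    assert model.decode_fn is eager            # pointer left untouched
```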
Modifications
Framework-level torch.compile support (Phase 1 of #172)
Benchmark & Profiling