
[Feature] Framework-level torch.compile for S2-Pro codebook decoder Phase 1 #239

Open
yxs wants to merge 1 commit into sgl-project:main from yxs:feat/torch-compile-phase1

Conversation

@yxs (Collaborator) commented Mar 31, 2026

Modifications

Framework-level torch.compile support (Phase 1 of #172)

Benchmark & Profiling

Baseline (eager):

PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python -m sglang_omni.cli.cli serve --model-path /home/jobuser/models/s2-pro --port 8000 --config examples/configs/s2pro_tts.yaml

python -m benchmarks.performance.tts.benchmark_tts_speed --model fishaudio/s2-pro --port 8000 --testset /home/jobuser/data/seedtts_testset/en/meta.lst --no-ref-audio --output-dir results/exp_baseline_1088

============================================================
                    TTS Benchmark Result                    
============================================================
  Model:                         fishaudio/s2-pro
  Completed requests:            1079
  Failed requests:               9
------------------------------------------------------------
  Latency mean (s):              1.319
  Latency median (s):            1.247
  Latency p95 (s):               1.951
  Latency p99 (s):               3.841
  RTF mean:                      0.2833
  RTF median:                    0.2837
  Audio duration mean (s):       4.689
  Tok/s (per-req mean):          82.2
  Tok/s (per-req median):        81.3
  Tok/s (aggregate):             82.3
  Gen tokens (mean):             101
  Gen tokens (total):            108952
  Prompt tokens (mean):          29
  Prompt tokens (total):         31550
  Throughput (req/s):            0.758
============================================================
With partial compile:

PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python -m sglang_omni.cli.cli serve --model-path /home/jobuser/models/s2-pro --port 8000 --config examples/configs/s2pro_tts_compile.yaml

python -m benchmarks.performance.tts.benchmark_tts_speed --model fishaudio/s2-pro --port 8000 --testset /home/jobuser/data/seedtts_testset/en/meta.lst --no-ref-audio --output-dir results/exp_compile_1088

============================================================
                    TTS Benchmark Result                    
============================================================
  Model:                         fishaudio/s2-pro
  Completed requests:            1074
  Failed requests:               14
------------------------------------------------------------
  Latency mean (s):              0.984
  Latency median (s):            0.941
  Latency p95 (s):               1.418
  Latency p99 (s):               2.594
  RTF mean:                      0.2141
  RTF median:                    0.2128
  Audio duration mean (s):       4.633
  Tok/s (per-req mean):          111.6
  Tok/s (per-req median):        112.5
  Tok/s (aggregate):             111.8
  Gen tokens (mean):             100
  Gen tokens (total):            107152
  Prompt tokens (mean):          29
  Prompt tokens (total):         31420
  Throughput (req/s):            1.016
============================================================
| Config          | Tok/s (mean) | Tok/s (aggregate) | Failed |
|-----------------|--------------|-------------------|--------|
| Baseline        | 82.2         | 82.3              | 9      |
| Partial compile | 111.6        | 111.8             | 14     |
| Delta           | +36%         | +36%              |        |

…Phase 1)

  Implements Phase 1 of sgl-project#172: partial compile on the codebook decoder,
  +33% tok/s on S2-Pro TTS with fullgraph=True and zero graph breaks.

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@yxs yxs requested a review from FrankLeeeee as a code owner March 31, 2026 01:38
Copilot AI review requested due to automatic review settings March 31, 2026 01:38

Copilot AI left a comment


Pull request overview

Adds Phase 1 framework-level torch.compile support for S2-Pro by extracting compile-friendly decode logic and providing a reusable compile-target wiring utility in the omni engine layer.

Changes:

  • Extracts S2-Pro codebook decoding into a standalone function and routes decode through a swappable function pointer to enable compilation.
  • Introduces sglang_omni.engines.omni.compile.apply_compile_targets() to compile model-registered targets with fullgraph + fallback mode (a hedged sketch follows this list).
  • Plumbs a compile_level option through the S2-Pro pipeline/factory and adds an example YAML enabling partial compile.
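
For a concrete picture of that wiring, here is a minimal sketch of an apply_compile_targets-style helper. This is an illustration, not the PR's code: the getter name (get_compile_targets), the eager-restore behavior, and the DEFAULT_COMPILE_MODE value are assumptions inferred from the review comments below; only the setter convention (set_compiled_<name>_fn) and fullgraph=True are attested in this thread.

```python
import logging
from typing import Any

import torch

logger = logging.getLogger(__name__)

# Assumed default; the PR's actual DEFAULT_COMPILE_MODE value is not shown in this thread.
DEFAULT_COMPILE_MODE = "max-autotune-no-cudagraphs"


def apply_compile_targets(
    *models: Any,
    compile_mode: str = DEFAULT_COMPILE_MODE,
) -> list[str]:
    """Compile every target a model registers; fall back to eager on failure."""
    compiled: list[str] = []
    for model in models:
        getter = getattr(model, "get_compile_targets", None)  # assumed getter name
        if getter is None:
            continue
        for name, eager_fn in getter().items():
            setter = getattr(model, f"set_compiled_{name}_fn")
            try:
                setter(torch.compile(eager_fn, mode=compile_mode, fullgraph=True))
                compiled.append(name)
            except Exception:
                logger.warning("compile failed for %r; keeping eager", name, exc_info=True)
                setter(eager_fn)  # re-wire the original eager function
    return compiled
```

Under this shape, the factory's compile_level == "partial" branch reduces to a single apply_compile_targets(model) call after engine creation.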

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 5 comments.

| File | Description |
|------|-------------|
| sglang_omni/models/fishaudio_s2_pro/sglang_model.py | Extracts _decode_codebooks_impl and adds compile target getter/setter + function indirection. |
| sglang_omni/models/fishaudio_s2_pro/runtime/s2pro_sglang_ar.py | Removes in-class compile flag and exposes compile targets via getter/setter on the output processor. |
| sglang_omni/models/fishaudio_s2_pro/pipeline/stages.py | Adds compile_level arg and forwards it into engine creation. |
| sglang_omni/models/fishaudio_s2_pro/factory.py | Applies framework-level compile targets when compile_level == "partial". |
| sglang_omni/engines/omni/compile.py | New framework module that discovers/compiles registered targets with fallback behavior. |
| examples/configs/s2pro_tts_compile.yaml | New example pipeline config enabling partial compile for the TTS engine stage. |


"""Constrained semantic argmax + codebook generation.

Standalone function so torch.compile can trace it with fullgraph=True.
Returns (output_codes [bs, num_codebooks+1], semantic_ids [bs]).

Copilot AI Mar 31, 2026


_decode_codebooks_impl returns the selected semantic token id (vocab id), but the docstring and the caller name it semantic_ids. This is easy to confuse with sem_id (semantic_token - semantic_begin_id). Consider renaming the returned value / local variable to semantic_token (or semantic_token_ids) and updating the docstring accordingly to prevent accidental misuse later.

Suggested change
```diff
- Returns (output_codes [bs, num_codebooks+1], semantic_ids [bs]).
+ Returns (output_codes [bs, num_codebooks+1], semantic_token [bs]).
```
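
To make the distinction concrete, a small illustration of the two quantities (values and shapes here are invented; semantic_begin_id comes from the model config in the real code):

```python
import torch

# Illustrative values only.
semantic_begin_id = 32000          # start of the semantic token range in the vocab
logits = torch.randn(2, 40000)     # [bs, vocab]

# Constrained argmax over the semantic range yields an absolute vocab id;
# this is what _decode_codebooks_impl returns, hence "semantic_token".
semantic_token = logits[:, semantic_begin_id:].argmax(dim=-1) + semantic_begin_id

# sem_id is the codebook-relative index; the offset is the entire difference.
sem_id = semantic_token - semantic_begin_id
```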

Comment on lines 271 to 275
```python
    stream_stride: int = 5,
    stream_followup_stride: int = 100,
    stream_vocoder_device: str | None = None,
    compile_level: str = "none",
) -> EngineExecutor:
```

Copilot AI Mar 31, 2026


compile_level is user-facing config (YAML) but is currently not validated here. Consider validating allowed values ("none"/"partial" for now) and raising a clear ValueError (or warning) on invalid values so typos don’t silently disable compile.
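
A minimal sketch of that validation, assuming {"none", "partial"} is the Phase 1 value set (per this review); the helper name is illustrative:

```python
_VALID_COMPILE_LEVELS = {"none", "partial"}  # Phase 1; extend as later phases land


def validate_compile_level(compile_level: str) -> str:
    """Fail fast on typos instead of silently running eager."""
    if compile_level not in _VALID_COMPILE_LEVELS:
        raise ValueError(
            f"Unknown compile_level {compile_level!r}; "
            f"expected one of {sorted(_VALID_COMPILE_LEVELS)}"
        )
    return compile_level
```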

Comment on lines 96 to +100
```python
server_args.attention_backend = "fa3"

# Enable hidden state capture for unified decode
use_partial_compile = compile_level == "partial"

# Enable hidden state capture for unified decode or partial compile
```

Copilot AI Mar 31, 2026


compile_level is treated as a free-form string (only "partial" changes behavior; everything else silently behaves like "none"). To avoid misconfiguration being silently ignored, validate compile_level against the supported values (e.g., {"none","partial"} for phase 1) and raise a ValueError (or at least warn) on unknown values.

Comment on lines +27 to +35
```python
import torch._dynamo.config
import torch._inductor.config

torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True
torch._dynamo.config.accumulated_cache_size_limit = 1024
```

Copilot AI Mar 31, 2026


_set_inductor_config() sets several private torch._dynamo/_inductor config attributes. If any of these fields are missing/renamed in the current PyTorch build, enabling compile will raise during engine creation instead of gracefully continuing in eager mode. Consider wrapping these assignments in try/except (or hasattr checks) and logging exc_info=True.

Suggested change
```diff
-import torch._dynamo.config
-import torch._inductor.config
-
-torch._inductor.config.coordinate_descent_tuning = True
-torch._inductor.config.triton.unique_kernel_names = True
-torch._inductor.config.fx_graph_cache = True
-torch._dynamo.config.accumulated_cache_size_limit = 1024
+try:
+    import torch._dynamo.config
+    import torch._inductor.config
+
+    torch._inductor.config.coordinate_descent_tuning = True
+    torch._inductor.config.triton.unique_kernel_names = True
+    torch._inductor.config.fx_graph_cache = True
+    torch._dynamo.config.accumulated_cache_size_limit = 1024
+except Exception:
+    logger.warning(
+        "Failed to set torch._dynamo/torch._inductor compile configuration; "
+        "continuing without custom inductor settings.",
+        exc_info=True,
+    )
```

Comment on lines +36 to +40
```python
def apply_compile_targets(
    *models: Any,
    compile_mode: str = DEFAULT_COMPILE_MODE,
) -> list[str]:
    """Compile all registered targets on one or more models.
```

Copilot AI Mar 31, 2026


This new framework API (apply_compile_targets / setter naming convention / fallback behavior) isn't covered by unit tests. Consider adding a small test that stubs torch.compile and asserts targets are discovered and wired via set_compiled_<name>_fn, and that compile failure cleanly falls back to eager.
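
A hedged sketch of such a test (pytest), assuming apply_compile_targets swallows compile errors and leaves the eager function wired; StubModel and the test bodies are illustrative, not the PR's tests:

```python
import torch

from sglang_omni.engines.omni.compile import apply_compile_targets


class StubModel:
    """Minimal stand-in exposing the assumed compile-target convention."""

    def __init__(self):
        self.decode_fn = lambda x: x  # eager implementation

    def get_compile_targets(self):
        return {"decode": self.decode_fn}

    def set_compiled_decode_fn(self, fn):
        self.decode_fn = fn


def test_targets_discovered_and_wired(monkeypatch):
    sentinel = object()
    monkeypatch.setattr(torch, "compile", lambda fn, **kwargs: sentinel)
    model = StubModel()
    assert apply_compile_targets(model) == ["decode"]
    assert model.decode_fn is sentinel  # compiled fn wired via the setter


def test_compile_failure_falls_back_to_eager(monkeypatch):
    def broken_compile(fn, **kwargs):
        raise RuntimeError("simulated inductor failure")

    monkeypatch.setattr(torch, "compile", broken_compile)
    model = StubModel()
    eager = model.decode_fn
    assert apply_compile_targets(model) == []
    assert model.decode_fn is eager  # eager fn stays (or is restored)
```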
