Skip to content

Commit f5b607c

Browse files
[Feat] Expose Encoder Mem Reserve As --encoder-mem-reserve CLI Flag (#339)
1 parent 06afb6c commit f5b607c

7 files changed

Lines changed: 259 additions & 86 deletions

File tree

examples/run_qwen3_omni_server.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,28 @@ def parse_args() -> argparse.Namespace:
6969
"If omitted, SGLang chooses automatically."
7070
),
7171
)
72+
parser.add_argument(
73+
"--encoder-mem-reserve",
74+
type=float,
75+
default=None,
76+
help=(
77+
"GPU-memory fraction kept OUT of SGLang's static pool (model weights "
78+
"+ KV cache) and left free for the co-located vision/audio encoder's "
79+
"weights and activations on the thinker GPU.\n"
80+
"Behavior across the four flag combinations of --mem-fraction-static "
81+
"and --encoder-mem-reserve:\n"
82+
" (1) neither flag passed: SGLang auto-selects mem_fraction_static "
83+
"and the default reserve 0.05 is subtracted;\n"
84+
" (2) only --encoder-mem-reserve X: SGLang auto-selects "
85+
"mem_fraction_static and X is subtracted;\n"
86+
" (3) only --mem-fraction-static X: X is used verbatim and the "
87+
"default reserve is ignored;\n"
88+
" (4) both flags: rejected at CLI as mutually exclusive.\n"
89+
"Default 0.05 is tuned for single-request / short-video workloads; "
90+
"raise to 0.15-0.20 for high-concurrency long-video or long-audio "
91+
"workloads."
92+
),
93+
)
7294

7395
# Server
7496
parser.add_argument("--host", type=str, default="0.0.0.0")
@@ -83,14 +105,32 @@ def parse_args() -> argparse.Namespace:
83105
return parser.parse_args()
84106

85107

108+
def _check_mem_flag_mutex(
    mem_fraction_static: float | None,
    encoder_mem_reserve: float | None,
) -> None:
    """Fail fast when both memory-sizing flags were supplied together.

    Args:
        mem_fraction_static: Value of --mem-fraction-static, or None when
            the flag was not passed.
        encoder_mem_reserve: Value of --encoder-mem-reserve, or None when
            the flag was not passed.

    Raises:
        ValueError: If both arguments are not None. Presence is tested with
            ``is None``, so an explicit ``0.0`` still counts as "passed".
    """
    # Guard clause: at most one flag present is always acceptable.
    if mem_fraction_static is None or encoder_mem_reserve is None:
        return
    raise ValueError(
        "--mem-fraction-static and --encoder-mem-reserve are mutually "
        "exclusive: --mem-fraction-static pins the pool size directly "
        "and the reserve only subtracts from SGLang's auto-selected "
        "value. Pass only one."
    )
120+
121+
86122
def main() -> None:
87123
args = parse_args()
88124

125+
_check_mem_flag_mutex(args.mem_fraction_static, args.encoder_mem_reserve)
126+
89127
overrides = {}
90128
if args.thinker_max_seq_len is not None:
91129
overrides["thinker_max_seq_len"] = args.thinker_max_seq_len
92130
if args.cpu_offload_gb:
93131
overrides["cpu_offload_gb"] = args.cpu_offload_gb
132+
if args.encoder_mem_reserve is not None:
133+
overrides["encoder_mem_reserve"] = args.encoder_mem_reserve
94134

95135
config = Qwen3OmniPipelineConfig(
96136
model_path=args.model_path,

sglang_omni/engines/ar/sglang_backend/server_args_builder.py

Lines changed: 26 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,6 @@
66

77
from sglang.srt.server_args import ServerArgs
88

9-
# Note (Ratish, Chenyang):
10-
11-
# SGLang's VLM auto-sizing applies a dynamic 0.95 * factor reserve
12-
# (roughly [0.8, 1.05]); Qwen3-Omni nests vision/audio configs under
13-
# `thinker_config` so SGLang's VLM path never triggers for us. 0.05
14-
# is a conservative linear lower-bound of that dynamic reserve; we
15-
# subtract it after auto-sizing when the thinker GPU also hosts encoder
16-
# stages. User-pinned mem_fraction_static bypasses this reserve.
17-
18-
OMNI_ENCODER_MEM_FRACTION_STATIC_RESERVE = 0.05
19-
209

2110
def build_sglang_server_args(
2211
model_path: str,
@@ -26,10 +15,9 @@ def build_sglang_server_args(
2615
max_prefill_tokens: int = 4096,
2716
max_running_requests: int = 16,
2817
mem_fraction_static: float | None = None,
29-
auto_mem_fraction_static_reserve: float | None = None,
3018
**overrides: Any,
3119
) -> ServerArgs:
32-
"""Build ServerArgs with shared defaults for all SGLang AR engines."""
20+
"""Build a SGLang ServerArgs with shared defaults for AR engines."""
3321
kwargs: dict[str, Any] = {
3422
"model_path": model_path,
3523
"trust_remote_code": True,
@@ -45,30 +33,36 @@ def build_sglang_server_args(
4533
if mem_fraction_static is not None:
4634
kwargs["mem_fraction_static"] = mem_fraction_static
4735
kwargs.update(overrides)
48-
server_args = ServerArgs(**kwargs)
49-
_apply_auto_mem_fraction_static_reserve(
50-
server_args,
51-
enabled=auto_mem_fraction_static_reserve is not None,
52-
user_mem_fraction_static=mem_fraction_static,
53-
reserve=auto_mem_fraction_static_reserve or 0.0,
54-
)
55-
return server_args
36+
return ServerArgs(**kwargs)
5637

5738

58-
def _apply_auto_mem_fraction_static_reserve(
39+
def apply_encoder_mem_reserve(
    server_args: ServerArgs,
    encoder_mem_reserve: float,
) -> None:
    """Subtract encoder_mem_reserve from SGLang's auto-picked mem_fraction_static.

    Note (Chenyang):
    Call this only when SGLang auto-selected mem_fraction_static —
    i.e. the caller did NOT pin --mem-fraction-static. When the caller
    pinned, that value is the whole budget and the reserve value is ignored.

    The update is applied in place on ``server_args``. Nothing is changed
    when ``encoder_mem_reserve`` is zero or negative, or when
    ``server_args.mem_fraction_static`` is None.

    Args:
        server_args: Object whose ``mem_fraction_static`` attribute is read
            and overwritten.
        encoder_mem_reserve: Fraction to subtract from the current value.

    Raises:
        ValueError: When the (rounded) result would drop below 0.1 — below
            that, SGLang's KV allocator fails deep in the scheduler with a
            confusing traceback (empirically crashes ~0.08 on H200 for
            Qwen3-Omni-30B), so surface it at build time instead.
    """
    if encoder_mem_reserve <= 0:
        return
    current = server_args.mem_fraction_static
    if current is None:
        return
    # Round BEFORE comparing against the floor: binary float subtraction can
    # land one ulp short (0.15 - 0.05 == 0.09999999999999999), which would
    # otherwise reject a result that is stored as exactly 0.1.
    new_value = round(current - encoder_mem_reserve, 3)
    if new_value < 0.1:
        raise ValueError(
            f"auto mem_fraction_static {current:.3f} minus encoder_mem_reserve "
            f"{encoder_mem_reserve:.3f} = {new_value:.3f} is below the safe "
            f"floor 0.1; lower encoder_mem_reserve or pin "
            f"--mem-fraction-static explicitly."
        )
    server_args.mem_fraction_static = new_value

sglang_omni/models/ming_omni/pipeline/stages.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Any
77

88
from sglang_omni.engines.ar.sglang_backend.server_args_builder import (
9-
OMNI_ENCODER_MEM_FRACTION_STATIC_RESERVE,
9+
apply_encoder_mem_reserve,
1010
build_sglang_server_args,
1111
)
1212
from sglang_omni.engines.omni import create_sglang_ar_engine, create_single_pass_engine
@@ -338,9 +338,10 @@ def create_sglang_thinker_executor_from_config(
338338
server_args = build_sglang_server_args(
339339
local_path,
340340
context_length=thinker_max_seq_len,
341-
auto_mem_fraction_static_reserve=OMNI_ENCODER_MEM_FRACTION_STATIC_RESERVE,
342341
**overrides,
343342
)
343+
if "mem_fraction_static" not in overrides:
344+
apply_encoder_mem_reserve(server_args, 0.05)
344345
pre_load_mem = (
345346
f", pre_load_avail_mem={pre_load_avail_mem:.2f} GB"
346347
if pre_load_avail_mem is not None

sglang_omni/models/qwen3_omni/config.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -113,27 +113,42 @@ def apply_server_args_overrides(
113113
and overrides["tp_size"] > 1
114114
):
115115
raise NotImplementedError("Qwen3-Omni TP is not supported yet.")
116-
remaining = _route_thinker_max_seq_len(self.stages, stage_name, overrides)
116+
remaining = _route_thinker_executor_args(self.stages, stage_name, overrides)
117117
if remaining:
118118
super().apply_server_args_overrides(
119119
stage_name=stage_name,
120120
overrides=remaining,
121121
)
122122

123123

124-
def _route_thinker_max_seq_len(
124+
def _route_thinker_executor_args(
    stages: list[StageConfig],
    stage_name: str,
    overrides: dict[str, Any],
) -> dict[str, Any]:
    """Pop thinker-factory kwargs onto the thinker stage; return the rest.

    ``thinker_max_seq_len`` and ``encoder_mem_reserve`` are removed from the
    returned dict for EVERY stage, so they never reach the generic override
    handler the caller forwards the remainder to. They are written to the
    thinker stage's executor args only when ``stage_name`` is the thinker
    stage. The input ``overrides`` dict is not mutated.

    Args:
        stages: Pipeline stage configs; the first one named THINKER_STAGE
            receives the routed values.
        stage_name: Stage the overrides were addressed to.
        overrides: Raw override mapping.

    Returns:
        A copy of ``overrides`` without the two thinker-factory keys.

    Raises:
        ValueError: If ``encoder_mem_reserve`` is routed to the thinker
            stage and is not in [0, 1).
    """
    remaining = dict(overrides)
    # Pop before the stage check: these keys are executor-factory kwargs,
    # not server args, so they must not leak through for other stages.
    seq_len = remaining.pop("thinker_max_seq_len", None)
    reserve = remaining.pop("encoder_mem_reserve", None)
    if stage_name != THINKER_STAGE:
        return remaining

    casted: dict[str, Any] = {}

    if seq_len is not None:
        casted["thinker_max_seq_len"] = int(seq_len)

    if reserve is not None:
        reserve = float(reserve)
        # The negated chained comparison also rejects NaN.
        if not 0.0 <= reserve < 1.0:
            raise ValueError(f"encoder_mem_reserve must be in [0, 1), got {reserve}")
        casted["encoder_mem_reserve"] = reserve

    if casted:
        for stage in stages:
            if stage.name == THINKER_STAGE:
                stage.executor.args.update(casted)
                break
    return remaining
138153

139154

@@ -302,7 +317,7 @@ def apply_server_args_overrides(
302317
)
303318
if tp_size > 1:
304319
raise NotImplementedError("Qwen3-Omni TP is not supported yet.")
305-
remaining = _route_thinker_max_seq_len(self.stages, stage_name, overrides)
320+
remaining = _route_thinker_executor_args(self.stages, stage_name, overrides)
306321
if remaining:
307322
super().apply_server_args_overrides(
308323
stage_name=stage_name,

sglang_omni/models/qwen3_omni/pipeline/stages.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from transformers import AutoTokenizer
1111

1212
from sglang_omni.engines.ar.sglang_backend.server_args_builder import (
13-
OMNI_ENCODER_MEM_FRACTION_STATIC_RESERVE,
13+
apply_encoder_mem_reserve,
1414
build_sglang_server_args,
1515
)
1616
from sglang_omni.engines.omni import (
@@ -354,21 +354,20 @@ def create_sglang_thinker_executor_from_config(
354354
*,
355355
gpu_id: int = 0,
356356
thinker_max_seq_len: int = 8192,
357+
encoder_mem_reserve: float = 0.05,
357358
server_args_overrides: dict[str, Any] | None = None,
358359
speech_enabled: bool = False,
359360
) -> EngineExecutor:
360-
"""Create a SGLang thinker executor from JSON-serializable config args.
361-
362-
This keeps pipeline config args plain dict types while still constructing
363-
a typed ServerArgs object internally.
364-
"""
361+
"""Create a SGLang thinker executor from JSON-serializable config args."""
365362
pre_load_avail_mem = avail_gpu_mem(gpu_id)
363+
overrides = server_args_overrides or {}
366364
server_args = build_sglang_server_args(
367365
model_path,
368366
context_length=thinker_max_seq_len,
369-
auto_mem_fraction_static_reserve=OMNI_ENCODER_MEM_FRACTION_STATIC_RESERVE,
370-
**(server_args_overrides or {}),
367+
**overrides,
371368
)
369+
if "mem_fraction_static" not in overrides:
370+
apply_encoder_mem_reserve(server_args, encoder_mem_reserve)
372371
pre_load_mem = (
373372
f" pre_load_avail_mem={pre_load_avail_mem:.2f} GB"
374373
if pre_load_avail_mem is not None

0 commit comments

Comments
 (0)