Skip to content

Commit 3bb29cb

Browse files
vmendelevclaude
andcommitted
Remove backend-specific args from serve_unified.py
Replace ~20 hard-coded backend-specific CLI arguments with a generic parse_extra_args() that converts unknown flags to a config dict. This makes serve_unified.py truly backend-agnostic — new backends no longer need to edit the server entrypoint. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 225e674 commit 3bb29cb

1 file changed

Lines changed: 87 additions & 70 deletions

File tree

nemo_skills/inference/server/serve_unified.py

Lines changed: 87 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,19 @@
3232
--codec_model /path/to/codec \\
3333
--port 8000
3434
35+
Backend-specific options are passed as extra CLI flags and forwarded to the
36+
backend's config dataclass automatically. For example:
37+
38+
--server_args "--backend magpie_tts --codec_model /path --use_cfg --cfg_scale 2.5"
39+
40+
Any flag not recognized by the server itself is parsed generically:
41+
--flag -> {"flag": True}
42+
--key value -> {"key": <auto-typed value>}
43+
--key=value -> {"key": <auto-typed value>}
44+
--no_flag -> {"flag": False}
45+
46+
See each backend's config class for available options (e.g. MagpieTTSConfig).
47+
3548
Example YAML config (backend_config.yaml):
3649
backend: magpie_tts
3750
model_path: /path/to/model
@@ -111,51 +124,63 @@ def load_yaml_config(config_path: str) -> dict:
111124
return yaml.safe_load(f)
112125

113126

114-
def build_config_from_args(args) -> dict:
115-
"""Build config dict from CLI arguments (backward-compatible mode)."""
116-
config_dict = {
117-
"model_path": args.model,
118-
"device": args.device,
119-
"dtype": args.dtype,
120-
"max_new_tokens": args.max_new_tokens,
121-
"temperature": args.temperature,
122-
"top_p": args.top_p,
123-
}
124-
125-
if args.codec_model:
126-
config_dict["codec_model_path"] = args.codec_model
127-
128-
# Pass through any backend-specific args that were set
129-
for key in [
130-
"top_k",
131-
"use_cfg",
132-
"cfg_scale",
133-
"decoder_only_model",
134-
"phoneme_input_type",
135-
"use_local_transformer",
136-
"hparams_file",
137-
"checkpoint_file",
138-
"legacy_codebooks",
139-
"legacy_text_conditioning",
140-
"hparams_from_wandb",
141-
"prompt_format",
142-
"ignore_system_prompt",
143-
"silence_padding_sec",
144-
"config_path",
145-
"llm_checkpoint_path",
146-
"tts_checkpoint_path",
147-
"speaker_reference",
148-
"num_frames_per_inference",
149-
]:
150-
val = getattr(args, key, None)
151-
if val is not None:
152-
config_dict[key] = val
153-
154-
# Handle store_true flags
155-
if getattr(args, "no_decode_audio", False):
156-
config_dict["decode_audio"] = False
157-
158-
return config_dict
127+
def _coerce_value(value: str):
128+
"""Try to coerce a string value to int, float, or bool."""
129+
try:
130+
return int(value)
131+
except ValueError:
132+
pass
133+
try:
134+
return float(value)
135+
except ValueError:
136+
pass
137+
if value.lower() == "true":
138+
return True
139+
if value.lower() == "false":
140+
return False
141+
return value
142+
143+
144+
def parse_extra_args(extra_args: list) -> dict:
145+
"""Convert unknown CLI args to a config dict.
146+
147+
Handles these patterns:
148+
--flag -> {"flag": True}
149+
--key value -> {"key": <auto-typed value>}
150+
--key=value -> {"key": <auto-typed value>}
151+
--no_flag -> {"flag": False} (strip no_ prefix)
152+
"""
153+
result = {}
154+
i = 0
155+
while i < len(extra_args):
156+
arg = extra_args[i]
157+
if not arg.startswith("--"):
158+
i += 1
159+
continue
160+
161+
# Handle --key=value
162+
if "=" in arg:
163+
key, value = arg[2:].split("=", 1)
164+
result[key] = _coerce_value(value)
165+
i += 1
166+
continue
167+
168+
key = arg[2:]
169+
170+
# Check if next token is a value (not another flag)
171+
if i + 1 < len(extra_args) and not extra_args[i + 1].startswith("--"):
172+
result[key] = _coerce_value(extra_args[i + 1])
173+
i += 2
174+
continue
175+
176+
# Bare flag: --no_X -> {X: False}, otherwise {key: True}
177+
if key.startswith("no_"):
178+
result[key[3:]] = False
179+
else:
180+
result[key] = True
181+
i += 1
182+
183+
return result
159184

160185

161186
def main():
@@ -187,37 +212,16 @@ def main():
187212
parser.add_argument("--device", default="cuda", help="Device to use")
188213
parser.add_argument("--dtype", default="bfloat16", help="Model dtype")
189214

190-
# Backend-specific options (for CLI backward compatibility)
191-
parser.add_argument("--codec_model", default=None, help="Path to codec model")
192-
parser.add_argument("--prompt_format", default=None, help="Prompt format")
193-
parser.add_argument("--phoneme_input_type", default=None, help="Phoneme input type")
194-
parser.add_argument("--decoder_only_model", action="store_true", default=None, help="Decoder-only model")
195-
parser.add_argument("--use_local_transformer", action="store_true", default=None, help="Local transformer")
196-
parser.add_argument("--top_k", type=int, default=None, help="Top-k sampling")
197-
parser.add_argument("--use_cfg", action="store_true", default=None, help="Classifier-free guidance")
198-
parser.add_argument("--cfg_scale", type=float, default=None, help="CFG scale")
199-
parser.add_argument("--hparams_file", default=None, help="Path to hparams.yaml")
200-
parser.add_argument("--checkpoint_file", default=None, help="Path to .ckpt checkpoint")
201-
parser.add_argument("--legacy_codebooks", action="store_true", default=None, help="Legacy codebook indices")
202-
parser.add_argument("--legacy_text_conditioning", action="store_true", default=None, help="Legacy text conditioning")
203-
parser.add_argument("--hparams_from_wandb", action="store_true", default=None, help="hparams from wandb")
204-
parser.add_argument("--ignore_system_prompt", action="store_true", default=None, help="Ignore system prompts")
205-
parser.add_argument("--silence_padding_sec", type=float, default=None, help="Silence padding seconds")
206-
parser.add_argument("--config_path", default=None, help="Backend YAML config path")
207-
parser.add_argument("--llm_checkpoint_path", default=None, help="LLM checkpoint path")
208-
parser.add_argument("--tts_checkpoint_path", default=None, help="TTS checkpoint path")
209-
parser.add_argument("--speaker_reference", default=None, help="Speaker reference audio path")
210-
parser.add_argument("--num_frames_per_inference", type=int, default=None, help="Frames per inference")
211-
parser.add_argument("--no_decode_audio", action="store_true", help="Disable audio output")
212-
213215
# Environment setup
214216
parser.add_argument("--code_path", default=None, help="Path to add to PYTHONPATH")
215217
parser.add_argument("--hack_path", default=None, help="Path to safetensors patch")
216218

217219
# Debug
218220
parser.add_argument("--debug", action="store_true", help="Enable debug mode")
219221

220-
args, extra_args = parser.parse_known_args()
222+
# Parse known args; everything else is backend-specific
223+
args, unknown = parser.parse_known_args()
224+
extra_config = parse_extra_args(unknown)
221225

222226
# Setup environment
223227
setup_pythonpath(args.code_path)
@@ -241,12 +245,23 @@ def main():
241245
# CLI overrides
242246
if args.model:
243247
config_dict["model_path"] = args.model
248+
# Merge any extra CLI args into YAML config (CLI wins)
249+
config_dict.update(extra_config)
244250
else:
245251
# CLI args mode (backward compatible)
246252
if not args.model:
247253
parser.error("--model is required when not using --config")
248254
backend_type = args.backend
249-
config_dict = build_config_from_args(args)
255+
config_dict = {
256+
"model_path": args.model,
257+
"device": args.device,
258+
"dtype": args.dtype,
259+
"max_new_tokens": args.max_new_tokens,
260+
"temperature": args.temperature,
261+
"top_p": args.top_p,
262+
}
263+
# Merge backend-specific args from extra CLI flags
264+
config_dict.update(extra_config)
250265

251266
# Print configuration
252267
print("=" * 60)
@@ -260,6 +275,8 @@ def main():
260275
print(f" Batch Timeout: {args.batch_timeout}s")
261276
if args.config:
262277
print(f" Config: {args.config}")
278+
if extra_config:
279+
print(f" Extra CLI Config: {extra_config}")
263280
print("=" * 60)
264281

265282
# Import and run

0 commit comments

Comments
 (0)