|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | +""" |
| 17 | +CLI entrypoint for the Unified NeMo Inference Server. |
| 18 | +
|
| 19 | +Configuration is YAML-based: provide a config file with backend type and |
| 20 | +all backend-specific parameters. The config is validated against the |
| 21 | +backend's config class. |
| 22 | +
|
| 23 | +Usage: |
| 24 | + python -m nemo_skills.inference.server.serve_unified \\ |
| 25 | + --config /path/to/backend_config.yaml \\ |
| 26 | + --port 8000 |
| 27 | +
|
| 28 | + # Or with --model for nemo-skills pipeline compatibility: |
| 29 | + python -m nemo_skills.inference.server.serve_unified \\ |
| 30 | + --model /path/to/model \\ |
| 31 | + --backend magpie_tts \\ |
| 32 | + --codec_model /path/to/codec \\ |
| 33 | + --port 8000 |
| 34 | +
|
| 35 | +Example YAML config (backend_config.yaml): |
| 36 | + backend: magpie_tts |
| 37 | + model_path: /path/to/model |
| 38 | + codec_model_path: /path/to/codec |
| 39 | + device: cuda |
| 40 | + dtype: bfloat16 |
| 41 | + temperature: 0.6 |
| 42 | + top_k: 80 |
| 43 | + use_cfg: true |
| 44 | + cfg_scale: 2.5 |
| 45 | +""" |
| 46 | + |
| 47 | +import argparse |
| 48 | +import inspect |
| 49 | +import os |
| 50 | +import shutil |
| 51 | +import sys |
| 52 | +from typing import Optional |
| 53 | + |
| 54 | + |
| 55 | +def setup_pythonpath(code_path: Optional[str] = None): |
| 56 | + """Set up PYTHONPATH for NeMo and the unified server. |
| 57 | +
|
| 58 | + Args: |
| 59 | + code_path: Single path or colon-separated paths to add to PYTHONPATH |
| 60 | + """ |
| 61 | + paths_to_add = [] |
| 62 | + |
| 63 | + if code_path: |
| 64 | + for path in code_path.split(":"): |
| 65 | + if path and path not in paths_to_add: |
| 66 | + paths_to_add.append(path) |
| 67 | + |
| 68 | + # Add recipes path for unified server imports |
| 69 | + this_dir = os.path.dirname(os.path.abspath(__file__)) |
| 70 | + ns_eval_root = os.path.dirname(os.path.dirname(os.path.dirname(this_dir))) |
| 71 | + if os.path.exists(os.path.join(ns_eval_root, "recipes")): |
| 72 | + paths_to_add.append(ns_eval_root) |
| 73 | + |
| 74 | + # Container pattern |
| 75 | + if os.path.exists("/nemo_run/code"): |
| 76 | + paths_to_add.append("/nemo_run/code") |
| 77 | + |
| 78 | + current_path = os.environ.get("PYTHONPATH", "") |
| 79 | + for path in paths_to_add: |
| 80 | + if path not in current_path.split(":"): |
| 81 | + current_path = f"{path}:{current_path}" if current_path else path |
| 82 | + |
| 83 | + os.environ["PYTHONPATH"] = current_path |
| 84 | + |
| 85 | + for path in paths_to_add: |
| 86 | + if path not in sys.path: |
| 87 | + sys.path.insert(0, path) |
| 88 | + |
| 89 | + |
| 90 | +def apply_safetensors_patch(hack_path: Optional[str]): |
| 91 | + """Apply safetensors patch if provided (for some NeMo models).""" |
| 92 | + if not hack_path or not os.path.exists(hack_path): |
| 93 | + return |
| 94 | + |
| 95 | + try: |
| 96 | + import safetensors.torch as st_torch |
| 97 | + |
| 98 | + dest_path = inspect.getfile(st_torch) |
| 99 | + os.makedirs(os.path.dirname(dest_path), exist_ok=True) |
| 100 | + shutil.copyfile(hack_path, dest_path) |
| 101 | + print(f"[serve_unified] Applied safetensors patch: {hack_path} -> {dest_path}") |
| 102 | + except Exception as e: |
| 103 | + print(f"[serve_unified] Warning: Failed to apply safetensors patch: {e}") |
| 104 | + |
| 105 | + |
| 106 | +def load_yaml_config(config_path: str) -> dict: |
| 107 | + """Load YAML config file.""" |
| 108 | + import yaml |
| 109 | + |
| 110 | + with open(config_path) as f: |
| 111 | + return yaml.safe_load(f) |
| 112 | + |
| 113 | + |
| 114 | +def build_config_from_args(args) -> dict: |
| 115 | + """Build config dict from CLI arguments (backward-compatible mode).""" |
| 116 | + config_dict = { |
| 117 | + "model_path": args.model, |
| 118 | + "device": args.device, |
| 119 | + "dtype": args.dtype, |
| 120 | + "max_new_tokens": args.max_new_tokens, |
| 121 | + "temperature": args.temperature, |
| 122 | + "top_p": args.top_p, |
| 123 | + } |
| 124 | + |
| 125 | + if args.codec_model: |
| 126 | + config_dict["codec_model_path"] = args.codec_model |
| 127 | + |
| 128 | + # Pass through any backend-specific args that were set |
| 129 | + for key in [ |
| 130 | + "top_k", |
| 131 | + "use_cfg", |
| 132 | + "cfg_scale", |
| 133 | + "decoder_only_model", |
| 134 | + "phoneme_input_type", |
| 135 | + "use_local_transformer", |
| 136 | + "hparams_file", |
| 137 | + "checkpoint_file", |
| 138 | + "legacy_codebooks", |
| 139 | + "legacy_text_conditioning", |
| 140 | + "hparams_from_wandb", |
| 141 | + "prompt_format", |
| 142 | + "ignore_system_prompt", |
| 143 | + "silence_padding_sec", |
| 144 | + "config_path", |
| 145 | + "llm_checkpoint_path", |
| 146 | + "tts_checkpoint_path", |
| 147 | + "speaker_reference", |
| 148 | + "num_frames_per_inference", |
| 149 | + ]: |
| 150 | + val = getattr(args, key, None) |
| 151 | + if val is not None: |
| 152 | + config_dict[key] = val |
| 153 | + |
| 154 | + # Handle store_true flags |
| 155 | + if getattr(args, "no_decode_audio", False): |
| 156 | + config_dict["decode_audio"] = False |
| 157 | + |
| 158 | + return config_dict |
| 159 | + |
| 160 | + |
| 161 | +def main(): |
| 162 | + parser = argparse.ArgumentParser( |
| 163 | + description="Unified NeMo Inference Server", |
| 164 | + formatter_class=argparse.RawDescriptionHelpFormatter, |
| 165 | + ) |
| 166 | + |
| 167 | + # Primary: YAML config |
| 168 | + parser.add_argument("--config", default=None, help="Path to YAML config file") |
| 169 | + |
| 170 | + # Standard args for nemo-skills pipeline compatibility |
| 171 | + parser.add_argument("--model", default=None, help="Path to the model") |
| 172 | + parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use") |
| 173 | + parser.add_argument("--port", type=int, default=8000, help="Server port") |
| 174 | + parser.add_argument("--host", default="0.0.0.0", help="Server host") |
| 175 | + parser.add_argument("--backend", default="magpie_tts", help="Backend type") |
| 176 | + |
| 177 | + # Server configuration |
| 178 | + parser.add_argument("--batch_size", type=int, default=8, help="Maximum batch size") |
| 179 | + parser.add_argument("--batch_timeout", type=float, default=0.1, help="Batch timeout in seconds") |
| 180 | + |
| 181 | + # Generation defaults |
| 182 | + parser.add_argument("--max_new_tokens", type=int, default=512, help="Max tokens to generate") |
| 183 | + parser.add_argument("--temperature", type=float, default=1.0, help="Generation temperature") |
| 184 | + parser.add_argument("--top_p", type=float, default=1.0, help="Top-p sampling") |
| 185 | + |
| 186 | + # Model configuration |
| 187 | + parser.add_argument("--device", default="cuda", help="Device to use") |
| 188 | + parser.add_argument("--dtype", default="bfloat16", help="Model dtype") |
| 189 | + |
| 190 | + # Backend-specific options (for CLI backward compatibility) |
| 191 | + parser.add_argument("--codec_model", default=None, help="Path to codec model") |
| 192 | + parser.add_argument("--prompt_format", default=None, help="Prompt format") |
| 193 | + parser.add_argument("--phoneme_input_type", default=None, help="Phoneme input type") |
| 194 | + parser.add_argument("--decoder_only_model", action="store_true", default=None, help="Decoder-only model") |
| 195 | + parser.add_argument("--use_local_transformer", action="store_true", default=None, help="Local transformer") |
| 196 | + parser.add_argument("--top_k", type=int, default=None, help="Top-k sampling") |
| 197 | + parser.add_argument("--use_cfg", action="store_true", default=None, help="Classifier-free guidance") |
| 198 | + parser.add_argument("--cfg_scale", type=float, default=None, help="CFG scale") |
| 199 | + parser.add_argument("--hparams_file", default=None, help="Path to hparams.yaml") |
| 200 | + parser.add_argument("--checkpoint_file", default=None, help="Path to .ckpt checkpoint") |
| 201 | + parser.add_argument("--legacy_codebooks", action="store_true", default=None, help="Legacy codebook indices") |
| 202 | + parser.add_argument("--legacy_text_conditioning", action="store_true", default=None, help="Legacy text conditioning") |
| 203 | + parser.add_argument("--hparams_from_wandb", action="store_true", default=None, help="hparams from wandb") |
| 204 | + parser.add_argument("--ignore_system_prompt", action="store_true", default=None, help="Ignore system prompts") |
| 205 | + parser.add_argument("--silence_padding_sec", type=float, default=None, help="Silence padding seconds") |
| 206 | + parser.add_argument("--config_path", default=None, help="Backend YAML config path") |
| 207 | + parser.add_argument("--llm_checkpoint_path", default=None, help="LLM checkpoint path") |
| 208 | + parser.add_argument("--tts_checkpoint_path", default=None, help="TTS checkpoint path") |
| 209 | + parser.add_argument("--speaker_reference", default=None, help="Speaker reference audio path") |
| 210 | + parser.add_argument("--num_frames_per_inference", type=int, default=None, help="Frames per inference") |
| 211 | + parser.add_argument("--no_decode_audio", action="store_true", help="Disable audio output") |
| 212 | + |
| 213 | + # Environment setup |
| 214 | + parser.add_argument("--code_path", default=None, help="Path to add to PYTHONPATH") |
| 215 | + parser.add_argument("--hack_path", default=None, help="Path to safetensors patch") |
| 216 | + |
| 217 | + # Debug |
| 218 | + parser.add_argument("--debug", action="store_true", help="Enable debug mode") |
| 219 | + |
| 220 | + args, extra_args = parser.parse_known_args() |
| 221 | + |
| 222 | + # Setup environment |
| 223 | + setup_pythonpath(args.code_path) |
| 224 | + apply_safetensors_patch(args.hack_path) |
| 225 | + |
| 226 | + if args.code_path: |
| 227 | + os.environ["UNIFIED_SERVER_CODE_PATH"] = args.code_path |
| 228 | + |
| 229 | + if args.debug: |
| 230 | + os.environ["DEBUG"] = "1" |
| 231 | + |
| 232 | + # Set CUDA devices |
| 233 | + if "CUDA_VISIBLE_DEVICES" not in os.environ: |
| 234 | + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in range(args.num_gpus)) |
| 235 | + |
| 236 | + # Build configuration |
| 237 | + if args.config: |
| 238 | + # YAML config mode |
| 239 | + config_dict = load_yaml_config(args.config) |
| 240 | + backend_type = config_dict.pop("backend", args.backend) |
| 241 | + # CLI overrides |
| 242 | + if args.model: |
| 243 | + config_dict["model_path"] = args.model |
| 244 | + else: |
| 245 | + # CLI args mode (backward compatible) |
| 246 | + if not args.model: |
| 247 | + parser.error("--model is required when not using --config") |
| 248 | + backend_type = args.backend |
| 249 | + config_dict = build_config_from_args(args) |
| 250 | + |
| 251 | + # Print configuration |
| 252 | + print("=" * 60) |
| 253 | + print("[serve_unified] Starting Unified NeMo Inference Server") |
| 254 | + print("=" * 60) |
| 255 | + print(f" Backend: {backend_type}") |
| 256 | + print(f" Model: {config_dict.get('model_path', 'N/A')}") |
| 257 | + print(f" Port: {args.port}") |
| 258 | + print(f" GPUs: {args.num_gpus}") |
| 259 | + print(f" Batch Size: {args.batch_size}") |
| 260 | + print(f" Batch Timeout: {args.batch_timeout}s") |
| 261 | + if args.config: |
| 262 | + print(f" Config: {args.config}") |
| 263 | + print("=" * 60) |
| 264 | + |
| 265 | + # Import and run |
| 266 | + try: |
| 267 | + import uvicorn |
| 268 | + |
| 269 | + from recipes.multimodal.server.unified_server import create_app |
| 270 | + |
| 271 | + app = create_app( |
| 272 | + backend_type=backend_type, |
| 273 | + config_dict=config_dict, |
| 274 | + batch_size=args.batch_size, |
| 275 | + batch_timeout=args.batch_timeout, |
| 276 | + ) |
| 277 | + |
| 278 | + uvicorn.run(app, host=args.host, port=args.port, log_level="info") |
| 279 | + |
| 280 | + except ImportError as e: |
| 281 | + print(f"[serve_unified] Error: Failed to import unified server: {e}") |
| 282 | + print("[serve_unified] Make sure the recipes.multimodal.server package is in PYTHONPATH") |
| 283 | + sys.exit(1) |
| 284 | + except Exception as e: |
| 285 | + print(f"[serve_unified] Error: {e}") |
| 286 | + import traceback |
| 287 | + |
| 288 | + traceback.print_exc() |
| 289 | + sys.exit(1) |
| 290 | + |
| 291 | + |
| 292 | +if __name__ == "__main__": |
| 293 | + main() |
0 commit comments