This guide covers how to use SGL-JAX for multimodal inference with models like Wan 2.1.
SGL-JAX provides a unified, high-performance inference framework for multimodal models. The framework supports heterogeneous compute patterns—integrating Auto-Regressive (AR) decoding with Diffusion denoising—within a single pipeline.
For architecture details, see the RFC: Multimodal Architecture Design.
| Model | Description |
|---|---|
| Wan-AI/Wan2.1-T2V-1.3B-Diffusers | Video generation model supporting text-to-video generation |
| Wan-AI/Wan2.1-T2V-14B-Diffusers | Video generation model supporting text-to-video generation |
| Wan-AI/Wan2.2-T2V-A14B-Diffusers | Video generation model supporting text-to-video generation |
| Qwen/Qwen2.5-VL | Vision-language model available in 3B/7B/32B/72B parameter sizes |
> **Note:** Still under development.
SGL-JAX provides an OpenAI-compatible API for online inference.
```shell
uv run python3 -u -m sgl_jax.launch_server \
  --multimodal \
  --model-path=Wan-AI/Wan2.1-T2V-14B-Diffusers \
  --log-requests
```

Image generation:

```shell
curl http://localhost:30000/api/v1/images/generation \
  -H "Content-Type: application/json" \
  -d '{"prompt": "A curious raccoon", "size": "480*832"}'
```

Video generation:

```shell
curl http://localhost:30000/api/v1/videos/generation \
  -H "Content-Type: application/json" \
  -d '{"prompt": "A curious raccoon", "size": "480*832", "num_frames": 41}'
```

Multimodal models are composed of multiple stages (e.g., ViT, Diffusion, AR). Each stage can be configured independently.
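The same generation endpoints can be driven from Python. A minimal sketch using only the standard library, mirroring the curl examples (the endpoint paths, JSON fields, and default port 30000 are taken from those examples; the response is assumed to be JSON, and the helper functions here are illustrative, not part of SGL-JAX):

```python
import json
from urllib import request

BASE_URL = "http://localhost:30000"  # default port used in the curl examples

def _post(path: str, payload: dict) -> dict:
    """POST a JSON payload to the server and decode the JSON response."""
    req = request.Request(
        BASE_URL + path,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with request.urlopen(req) as resp:
        return json.loads(resp.read())

def generate_image(prompt: str, size: str = "480*832") -> dict:
    return _post("/api/v1/images/generation", {"prompt": prompt, "size": size})

def generate_video(prompt: str, size: str = "480*832", num_frames: int = 41) -> dict:
    return _post(
        "/api/v1/videos/generation",
        {"prompt": prompt, "size": size, "num_frames": num_frames},
    )

# Example (requires a running server):
#   video = generate_video("A curious raccoon")
```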
If not provided, the default config from `python/sgl_jax/srt/multimodal/models/static_configs` will be used.
```yaml
stage_args:
  - stage_id: 0
    run_time:
      num_tpus: 2
      sharding_spec: ["tensor"]
    launch_args:
      attention_backend: fa
      tp_size: 2
    input_type: image
    output_type: tensor
```

- Independent Scheduler: Each stage has its own scheduler to maximize TPU utilization
- Stage Overlap: The framework automatically overlaps computation across different stages
- Memory Management: Each stage maintains its own memory pool for efficient cache management
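To make the per-stage config shape concrete, here is a hypothetical Python sketch that parses one entry of the `stage_args` list into typed objects. The field names are taken from the YAML example above; the dataclasses and `parse_stage` helper are illustrative only and not part of SGL-JAX:

```python
from dataclasses import dataclass

@dataclass
class RunTime:
    num_tpus: int
    sharding_spec: list  # e.g. ["tensor"]

@dataclass
class LaunchArgs:
    attention_backend: str  # e.g. "fa"
    tp_size: int

@dataclass
class StageArgs:
    stage_id: int
    run_time: RunTime
    launch_args: LaunchArgs
    input_type: str   # e.g. "image"
    output_type: str  # e.g. "tensor"

def parse_stage(cfg: dict) -> StageArgs:
    """Build a StageArgs from one decoded `stage_args` entry."""
    return StageArgs(
        stage_id=cfg["stage_id"],
        run_time=RunTime(**cfg["run_time"]),
        launch_args=LaunchArgs(**cfg["launch_args"]),
        input_type=cfg["input_type"],
        output_type=cfg["output_type"],
    )

# Mirrors the YAML example above
stage = parse_stage({
    "stage_id": 0,
    "run_time": {"num_tpus": 2, "sharding_spec": ["tensor"]},
    "launch_args": {"attention_backend": "fa", "tp_size": 2},
    "input_type": "image",
    "output_type": "tensor",
})
```

Because each stage is parsed independently, a pipeline can mix stages with different TPU counts and sharding specs, which is what allows the scheduler to overlap work across stages.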