Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ license = { file = "LICENSE" }

dependencies = [
"fastapi~=0.116.1",
"flax~=0.12.0",
"flax==0.12.4",
"huggingface-hub~=0.34.3",
"jinja2~=3.1.6",
"llguidance~=1.3.0",
Expand Down
22 changes: 11 additions & 11 deletions python/sgl_jax/srt/multimodal/common/ServerArgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ class MultimodalServerArgs(ServerArgs):
text_encoder_precisions: tuple[str, ...] = field(default_factory=lambda: ("fp32",))
image_encoder_precision: str = "bf16"

vae_decode_precompile_width_height: list[str] | None = None
vae_decode_precompile_frame_paddings: list[int] | None = None
precompile_width_heights: list[str] | None = None
precompile_frame_paddings: list[int] | None = None

@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
Expand Down Expand Up @@ -95,14 +95,14 @@ def add_cli_args(parser: argparse.ArgumentParser):
)

parser.add_argument(
"--vae-decode-precompile-width-height",
"--precompile-width-heights",
type=str,
nargs="+",
help="Set the list of width and height for jax jit, format width*height",
)

parser.add_argument(
"--vae-decode-precompile-frame-paddings",
"--precompile-frame-paddings",
type=int,
nargs="+",
help="Set the frame count list for jax jit",
Expand All @@ -115,17 +115,17 @@ def __post_init__(self):
# manually.
super().__post_init__()

if self.vae_decode_precompile_width_height is not None:
for wh in self.vae_decode_precompile_width_height:
if self.precompile_width_heights is not None:
for wh in self.precompile_width_heights:
if len(wh.split("*")) < 2:
raise Exception("Width and height must be connected with an asterisk *.")
if self.vae_decode_precompile_frame_paddings is None:
self.vae_decode_precompile_frame_paddings = [1]
if self.precompile_frame_paddings is None:
self.precompile_frame_paddings = [1]
else:
self.vae_decode_precompile_frame_paddings.sort()
self.precompile_frame_paddings.sort()
else:
self.vae_decode_precompile_width_height = ["480*832"]
self.vae_decode_precompile_frame_paddings = [1]
self.precompile_width_heights = ["480*832"]
self.precompile_frame_paddings = [1]

@classmethod
def from_cli_args(cls, args: argparse.Namespace):
Expand Down
8 changes: 0 additions & 8 deletions python/sgl_jax/srt/multimodal/entrypoint/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,14 +554,6 @@ def _execute_multimodal_server_warmup(
],
"max_tokens": 3,
}
elif _is_wan_model(server_args.model_path):
request_endpoint = "/api/v1/images/generation"
json_data = {
"prompt": "warmup request",
"size": "480*832",
"num_inference_steps": 2,
"save_output": False,
}
elif "MiMo-Audio" in server_args.model_path:
request_endpoint = "/v1/audio/transcriptions"
# audio_url = "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/WhDJDIviAOg_120_10.mp3"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ def __init__(
self.aborted_rids: set[str] = set()
# Current request being processed (for abort checking during steps)
self._current_rid: str | None = None
if not server_args.disable_precompile:
logger.info("[Diffusion Scheduler] Begins to run diffusion worker precompile.")
self.diffusion_worker.run_precompile()
logger.info("[Diffusion Scheduler] Completes diffusion worker precompile.")

def event_loop_normal(self):
"""Blocking event loop for processing incoming diffusion requests.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,10 @@ def preprocess(self, req):
req.latents += self.model_config.shift_factor
req.latents = jax.device_get(req.latents)
latents_t_padding = 0
if self.server_args.vae_decode_precompile_frame_paddings is not None and hasattr(
if self.server_args.precompile_frame_paddings is not None and hasattr(
self.model_config, "scale_factor_temporal"
):
for n_frame in self.server_args.vae_decode_precompile_frame_paddings:
for n_frame in self.server_args.precompile_frame_paddings:
latents_t = (n_frame - 1) // self.model_config.scale_factor_temporal + 1
if latents_t >= req.latents.shape[1]:
latents_t_padding = latents_t - req.latents.shape[1]
Expand All @@ -153,7 +153,8 @@ def run_vae_batch(self, batch: list[Req]):

for req in batch:
output, cache_miss = self.vae_worker.forward(req)
logger.info("VAE forward pass cache miss: %s", cache_miss)
if cache_miss > 0:
logger.info("VAE forward pass cache miss: %s", cache_miss)
req.output = jax.device_get(output[:, : req.num_frames, :, :, :])
req.latents = None
self.forward_ct += 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,8 @@ def pad_to_512(embeds):
encoder_hidden_states_image=None,
guidance_scale=None,
)
logger.info("diffusion cache miss count: %d", count())
if count() > 0:
logger.info("diffusion cache miss count: %d", count())
if do_classifier_free_guidance:
bsz = latents.shape[0] // 2
noise_uncond = noise_pred[bsz:]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
import logging
import time
from collections.abc import Callable

import jax
import numpy as np
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P
from tqdm import tqdm

from sgl_jax.srt.multimodal.common.ServerArgs import MultimodalServerArgs
from sgl_jax.srt.multimodal.configs.config_registry import get_diffusion_config
from sgl_jax.srt.multimodal.manager.schedule_batch import Req
from sgl_jax.srt.multimodal.model_executor.diffusion.diffusion_model_runner import (
DiffusionModelRunner,
)
from sgl_jax.srt.utils.jax_utils import device_array

logger = logging.getLogger(__name__)


class DiffusionModelWorker:
Expand All @@ -22,6 +32,9 @@ def __init__(
server_args, self.mesh, model_class=model_class, stage_sub_dir=stage_sub_dir
)
self.initialize()
self.precompile_width_heights = server_args.precompile_width_heights
self.precompile_frame_paddings = server_args.precompile_frame_paddings
self.model_config = get_diffusion_config(server_args.model_path)

def initialize(self):
pass
Expand Down Expand Up @@ -50,3 +63,38 @@ def forward(
return self.model_runner.forward(
batch, mesh, abort_checker=abort_checker, step_callback=step_callback
)

def run_precompile(self):
self.precompile()

def precompile(self):
start_time = time.perf_counter()
logger.info(
"[DIFFUSION] Begin to precompile width*height=%s",
self.precompile_width_heights,
)

with tqdm(
self.precompile_width_heights, desc="[DIFFUSION] PRECOMPILE", leave=False
) as pbar:
for wh in pbar:
whs = wh.split("*")
width, height = int(whs[0]), int(whs[1])
assert width % self.model_config.scale_factor_spatial == 0
assert height % self.model_config.scale_factor_spatial == 0
for t in self.precompile_frame_paddings:
pbar.set_postfix(wh=wh, t=t)
embeds = np.random.random((2, 512, self.model_config.text_dim))
embeds = device_array(embeds, sharding=NamedSharding(self.mesh, P()))
req = Req(
prompt_embeds=embeds[0],
negative_prompt_embeds=embeds[1],
do_classifier_free_guidance=True,
width=width,
height=height,
num_frames=t,
num_inference_steps=1,
)
self.model_runner.forward(req, self.mesh)
end_time = time.perf_counter()
logger.info("[DIFFUSION] Precompile finished in %.0f secs", end_time - start_time)
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ def __init__(
server_args, mesh, model_class=model_class, stage_sub_dir=stage_sub_dir
)
self.server_args = server_args
self.vae_decode_precompile_width_height = server_args.vae_decode_precompile_width_height
self.vae_decode_precompile_frame_paddings = server_args.vae_decode_precompile_frame_paddings
self.precompile_width_heights = server_args.precompile_width_heights
self.precompile_frame_paddings = server_args.precompile_frame_paddings
self.model_config = get_vae_config(self.server_args.model_path)
# Initialize model here based on model_config

Expand All @@ -46,23 +46,23 @@ def decode_precompile(self):
start_time = time.perf_counter()
logger.info(
"[VAE DECODE] Begin to precompile width*height=%s",
self.vae_decode_precompile_width_height,
self.precompile_width_heights,
)

with tqdm(
self.vae_decode_precompile_width_height, desc="[VAE DECODE] PRECOMPILE", leave=False
self.precompile_width_heights, desc="[VAE DECODE] PRECOMPILE", leave=False
) as pbar:
for wh in pbar:
whs = wh.split("*")
width, height = int(whs[0]), int(whs[1])
assert width % self.model_config.scale_factor_spatial == 0
assert height % self.model_config.scale_factor_spatial == 0
for t in self.vae_decode_precompile_frame_paddings:
for t in self.precompile_frame_paddings:
pbar.set_postfix(wh=wh, t=t)
latents_cpu = np.random.random(
(
1,
t // self.model_config.scale_factor_temporal + 1,
(t - 1) // self.model_config.scale_factor_temporal + 1,
height // self.model_config.scale_factor_spatial,
width // self.model_config.scale_factor_spatial,
self.model_config.z_dim,
Expand Down