Skip to content

Commit 57f6ff6

Browse files
whzwhz
authored and committed
precompile for diffusion
1 parent 8655866 commit 57f6ff6

8 files changed

Lines changed: 76 additions & 30 deletions

File tree

python/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ license = { file = "LICENSE" }
1212

1313
dependencies = [
1414
"fastapi~=0.116.1",
15-
"flax~=0.12.0",
15+
"flax==0.12.4",
1616
"huggingface-hub~=0.34.3",
1717
"jinja2~=3.1.6",
1818
"llguidance~=1.3.0",

python/sgl_jax/srt/multimodal/common/ServerArgs.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ class MultimodalServerArgs(ServerArgs):
1919
text_encoder_precisions: tuple[str, ...] = field(default_factory=lambda: ("fp32",))
2020
image_encoder_precision: str = "bf16"
2121

22-
vae_decode_precompile_width_height: list[str] | None = None
23-
vae_decode_precompile_frame_paddings: list[int] | None = None
22+
precompile_width_heights: list[str] | None = None
23+
precompile_frame_paddings: list[int] | None = None
2424

2525
@staticmethod
2626
def add_cli_args(parser: argparse.ArgumentParser):
@@ -95,14 +95,14 @@ def add_cli_args(parser: argparse.ArgumentParser):
9595
)
9696

9797
parser.add_argument(
98-
"--vae-decode-precompile-width-height",
98+
"--precompile-width-heights",
9999
type=str,
100100
nargs="+",
101101
help="Set the list of width and height for jax jit, format width*height",
102102
)
103103

104104
parser.add_argument(
105-
"--vae-decode-precompile-frame-paddings",
105+
"--precompile-frame-paddings",
106106
type=int,
107107
nargs="+",
108108
help="Set the frame count list for jax jit",
@@ -115,17 +115,17 @@ def __post_init__(self):
115115
# manually.
116116
super().__post_init__()
117117

118-
if self.vae_decode_precompile_width_height is not None:
119-
for wh in self.vae_decode_precompile_width_height:
118+
if self.precompile_width_heights is not None:
119+
for wh in self.precompile_width_heights:
120120
if len(wh.split("*")) < 2:
121121
raise Exception("Width and height must be connected with an asterisk *.")
122-
if self.vae_decode_precompile_frame_paddings is None:
123-
self.vae_decode_precompile_frame_paddings = [1]
122+
if self.precompile_frame_paddings is None:
123+
self.precompile_frame_paddings = [1]
124124
else:
125-
self.vae_decode_precompile_frame_paddings.sort()
125+
self.precompile_frame_paddings.sort()
126126
else:
127-
self.vae_decode_precompile_width_height = ["480*832"]
128-
self.vae_decode_precompile_frame_paddings = [1]
127+
self.precompile_width_heights = ["480*832"]
128+
self.precompile_frame_paddings = [1]
129129

130130
@classmethod
131131
def from_cli_args(cls, args: argparse.Namespace):

python/sgl_jax/srt/multimodal/entrypoint/http_server.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -554,14 +554,6 @@ def _execute_multimodal_server_warmup(
554554
],
555555
"max_tokens": 3,
556556
}
557-
elif _is_wan_model(server_args.model_path):
558-
request_endpoint = "/api/v1/images/generation"
559-
json_data = {
560-
"prompt": "warmup request",
561-
"size": "480*832",
562-
"num_inference_steps": 2,
563-
"save_output": False,
564-
}
565557
elif "MiMo-Audio" in server_args.model_path:
566558
request_endpoint = "/v1/audio/transcriptions"
567559
# audio_url = "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/WhDJDIviAOg_120_10.mp3"

python/sgl_jax/srt/multimodal/manager/scheduler/diffusion_scheduler.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ def __init__(
5959
self.aborted_rids: set[str] = set()
6060
# Current request being processed (for abort checking during steps)
6161
self._current_rid: str | None = None
62+
if not server_args.disable_precompile:
63+
logger.info("[Diffusion Scheduler] Begins to run diffusion worker precompile.")
64+
self.diffusion_worker.run_precompile()
65+
logger.info("[Diffusion Scheduler] Completes diffusion worker precompile.")
6266

6367
def event_loop_normal(self):
6468
"""Blocking event loop for processing incoming diffusion requests.

python/sgl_jax/srt/multimodal/manager/scheduler/vae_scheduler.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,10 @@ def preprocess(self, req):
127127
req.latents += self.model_config.shift_factor
128128
req.latents = jax.device_get(req.latents)
129129
latents_t_padding = 0
130-
if self.server_args.vae_decode_precompile_frame_paddings is not None and hasattr(
130+
if self.server_args.precompile_frame_paddings is not None and hasattr(
131131
self.model_config, "scale_factor_temporal"
132132
):
133-
for n_frame in self.server_args.vae_decode_precompile_frame_paddings:
133+
for n_frame in self.server_args.precompile_frame_paddings:
134134
latents_t = (n_frame - 1) // self.model_config.scale_factor_temporal + 1
135135
if latents_t >= req.latents.shape[1]:
136136
latents_t_padding = latents_t - req.latents.shape[1]
@@ -153,7 +153,8 @@ def run_vae_batch(self, batch: list[Req]):
153153

154154
for req in batch:
155155
output, cache_miss = self.vae_worker.forward(req)
156-
logger.info("VAE forward pass cache miss: %s", cache_miss)
156+
if cache_miss > 0:
157+
logger.info("VAE forward pass cache miss: %s", cache_miss)
157158
req.output = jax.device_get(output[:, : req.num_frames, :, :, :])
158159
req.latents = None
159160
self.forward_ct += 1

python/sgl_jax/srt/multimodal/model_executor/diffusion/diffusion_model_runner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,8 @@ def pad_to_512(embeds):
222222
encoder_hidden_states_image=None,
223223
guidance_scale=None,
224224
)
225-
logger.info("diffusion cache miss count: %d", count())
225+
if count() > 0:
226+
logger.info("diffusion cache miss count: %d", count())
226227
if do_classifier_free_guidance:
227228
bsz = latents.shape[0] // 2
228229
noise_uncond = noise_pred[bsz:]

python/sgl_jax/srt/multimodal/model_executor/diffusion/diffusion_model_worker.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,22 @@
1+
import logging
2+
import time
13
from collections.abc import Callable
24

35
import jax
6+
import numpy as np
7+
from jax.sharding import NamedSharding
8+
from jax.sharding import PartitionSpec as P
9+
from tqdm import tqdm
410

511
from sgl_jax.srt.multimodal.common.ServerArgs import MultimodalServerArgs
12+
from sgl_jax.srt.multimodal.configs.config_registry import get_diffusion_config
613
from sgl_jax.srt.multimodal.manager.schedule_batch import Req
714
from sgl_jax.srt.multimodal.model_executor.diffusion.diffusion_model_runner import (
815
DiffusionModelRunner,
916
)
17+
from sgl_jax.srt.utils.jax_utils import device_array
18+
19+
logger = logging.getLogger(__name__)
1020

1121

1222
class DiffusionModelWorker:
@@ -22,6 +32,9 @@ def __init__(
2232
server_args, self.mesh, model_class=model_class, stage_sub_dir=stage_sub_dir
2333
)
2434
self.initialize()
35+
self.precompile_width_heights = server_args.precompile_width_heights
36+
self.precompile_frame_paddings = server_args.precompile_frame_paddings
37+
self.model_config = get_diffusion_config(server_args.model_path)
2538

2639
def initialize(self):
2740
pass
@@ -50,3 +63,38 @@ def forward(
5063
return self.model_runner.forward(
5164
batch, mesh, abort_checker=abort_checker, step_callback=step_callback
5265
)
66+
67+
def run_precompile(self):
68+
self.precompile()
69+
70+
def precompile(self):
71+
start_time = time.perf_counter()
72+
logger.info(
73+
"[DIFFUSION] Begin to precompile width*height=%s",
74+
self.precompile_width_heights,
75+
)
76+
77+
with tqdm(
78+
self.precompile_width_heights, desc="[DIFFUSION] PRECOMPILE", leave=False
79+
) as pbar:
80+
for wh in pbar:
81+
whs = wh.split("*")
82+
width, height = int(whs[0]), int(whs[1])
83+
assert width % self.model_config.scale_factor_spatial == 0
84+
assert height % self.model_config.scale_factor_spatial == 0
85+
for t in self.precompile_frame_paddings:
86+
pbar.set_postfix(wh=wh, t=t)
87+
embeds = np.random.random((2, 512, self.model_config.text_dim))
88+
embeds = device_array(embeds, sharding=NamedSharding(self.mesh, P()))
89+
req = Req(
90+
prompt_embeds=embeds[0],
91+
negative_prompt_embeds=embeds[1],
92+
do_classifier_free_guidance=True,
93+
width=width,
94+
height=height,
95+
num_frames=t,
96+
num_inference_steps=1,
97+
)
98+
self.model_runner.forward(req, self.mesh)
99+
end_time = time.perf_counter()
100+
logger.info("[DIFFUSION] Precompile finished in %.0f secs", end_time - start_time)

python/sgl_jax/srt/multimodal/model_executor/vae/vae_model_worker.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ def __init__(
2828
server_args, mesh, model_class=model_class, stage_sub_dir=stage_sub_dir
2929
)
3030
self.server_args = server_args
31-
self.vae_decode_precompile_width_height = server_args.vae_decode_precompile_width_height
32-
self.vae_decode_precompile_frame_paddings = server_args.vae_decode_precompile_frame_paddings
31+
self.precompile_width_heights = server_args.precompile_width_heights
32+
self.precompile_frame_paddings = server_args.precompile_frame_paddings
3333
self.model_config = get_vae_config(self.server_args.model_path)
3434
# Initialize model here based on model_config
3535

@@ -46,23 +46,23 @@ def decode_precompile(self):
4646
start_time = time.perf_counter()
4747
logger.info(
4848
"[VAE DECODE] Begin to precompile width*height=%s",
49-
self.vae_decode_precompile_width_height,
49+
self.precompile_width_heights,
5050
)
5151

5252
with tqdm(
53-
self.vae_decode_precompile_width_height, desc="[VAE DECODE] PRECOMPILE", leave=False
53+
self.precompile_width_heights, desc="[VAE DECODE] PRECOMPILE", leave=False
5454
) as pbar:
5555
for wh in pbar:
5656
whs = wh.split("*")
5757
width, height = int(whs[0]), int(whs[1])
5858
assert width % self.model_config.scale_factor_spatial == 0
5959
assert height % self.model_config.scale_factor_spatial == 0
60-
for t in self.vae_decode_precompile_frame_paddings:
60+
for t in self.precompile_frame_paddings:
6161
pbar.set_postfix(wh=wh, t=t)
6262
latents_cpu = np.random.random(
6363
(
6464
1,
65-
t // self.model_config.scale_factor_temporal + 1,
65+
(t - 1) // self.model_config.scale_factor_temporal + 1,
6666
height // self.model_config.scale_factor_spatial,
6767
width // self.model_config.scale_factor_spatial,
6868
self.model_config.z_dim,

0 commit comments

Comments
 (0)