1+ import logging
2+ import time
13from collections .abc import Callable
24
35import jax
6+ import numpy as np
7+ from jax .sharding import NamedSharding
8+ from jax .sharding import PartitionSpec as P
9+ from tqdm import tqdm
410
511from sgl_jax .srt .multimodal .common .ServerArgs import MultimodalServerArgs
12+ from sgl_jax .srt .multimodal .configs .config_registry import get_diffusion_config
613from sgl_jax .srt .multimodal .manager .schedule_batch import Req
714from sgl_jax .srt .multimodal .model_executor .diffusion .diffusion_model_runner import (
815 DiffusionModelRunner ,
916)
17+ from sgl_jax .srt .utils .jax_utils import device_array
18+
19+ logger = logging .getLogger (__name__ )
1020
1121
1222class DiffusionModelWorker :
@@ -22,6 +32,9 @@ def __init__(
2232 server_args , self .mesh , model_class = model_class , stage_sub_dir = stage_sub_dir
2333 )
2434 self .initialize ()
35+ self .precompile_width_heights = server_args .precompile_width_heights
36+ self .precompile_frame_paddings = server_args .precompile_frame_paddings
37+ self .model_config = get_diffusion_config (server_args .model_path )
2538
def initialize(self):
    """Do nothing.

    The visible constructor code calls ``self.initialize()`` during
    construction; this base implementation performs no work and returns
    ``None``.
    NOTE(review): presumably an extension hook for subclasses - confirm.
    """
    return None
@@ -50,3 +63,38 @@ def forward(
5063 return self .model_runner .forward (
5164 batch , mesh , abort_checker = abort_checker , step_callback = step_callback
5265 )
66+
def run_precompile(self):
    """Entry point that triggers precompilation.

    Delegates to :meth:`precompile` and discards whatever it returns, so
    this method always returns ``None``.
    """
    # Plain delegation; the result is intentionally not propagated.
    self.precompile()
    return None
69+
def precompile(self):
    """Run one dummy single-step forward pass per configured shape.

    Every ``"WIDTH*HEIGHT"`` entry of ``self.precompile_width_heights`` is
    combined with every frame count in ``self.precompile_frame_paddings``.
    For each combination a ``Req`` is built from random text embeddings
    (classifier-free guidance enabled, ``num_inference_steps=1``) and passed
    to ``self.model_runner.forward``. Start and total elapsed time are
    logged.

    All size entries are parsed and validated *before* the first forward
    pass, so a malformed entry is reported immediately instead of after
    earlier shapes have already been run. Options that are ``None`` are
    treated as empty (nothing to precompile).

    Raises:
        ValueError: If a size entry is not of the form ``"WIDTH*HEIGHT"``
            with integer parts, or if the width or height is not a multiple
            of ``self.model_config.scale_factor_spatial``.
    """
    start_time = time.perf_counter()
    logger.info(
        "[DIFFUSION] Begin to precompile width*height=%s",
        self.precompile_width_heights,
    )

    # Treat unset options as "no shapes" instead of failing with a TypeError
    # when iterating over None.
    width_heights = self.precompile_width_heights
    if width_heights is None:
        width_heights = ()
    frame_paddings = self.precompile_frame_paddings
    if frame_paddings is None:
        frame_paddings = ()

    # Validate everything up front. Explicit raises are used rather than
    # ``assert`` because assertions are removed when Python runs with -O.
    sizes = []
    for wh in width_heights:
        try:
            width_text, height_text = wh.split("*")
            width, height = int(width_text), int(height_text)
        except ValueError as err:
            raise ValueError(
                f"Invalid precompile size {wh!r}: expected 'WIDTH*HEIGHT' "
                "with integer width and height"
            ) from err
        scale = self.model_config.scale_factor_spatial
        if width % scale != 0 or height % scale != 0:
            raise ValueError(
                f"Invalid precompile size {wh!r}: width and height must be "
                f"multiples of scale_factor_spatial={scale}"
            )
        sizes.append((wh, width, height))

    with tqdm(sizes, desc="[DIFFUSION] PRECOMPILE", leave=False) as pbar:
        for wh, width, height in pbar:
            for t in frame_paddings:
                pbar.set_postfix(wh=wh, t=t)
                # Row 0 becomes the prompt embedding, row 1 the negative one.
                # NOTE(review): 512 is presumably the text sequence length
                # expected by the model - confirm.
                embeds = np.random.random((2, 512, self.model_config.text_dim))
                # An empty PartitionSpec places the array replicated over the mesh.
                embeds = device_array(embeds, sharding=NamedSharding(self.mesh, P()))
                req = Req(
                    prompt_embeds=embeds[0],
                    negative_prompt_embeds=embeds[1],
                    do_classifier_free_guidance=True,
                    width=width,
                    height=height,
                    num_frames=t,
                    num_inference_steps=1,
                )
                self.model_runner.forward(req, self.mesh)
    end_time = time.perf_counter()
    logger.info("[DIFFUSION] Precompile finished in %.0f secs", end_time - start_time)
0 commit comments