
Commit f3cc742 (parent: 2b9aca6)

Revise the code based on the review feedback

3 files changed: 56 insertions(+), 65 deletions(-)

python/sglang/multimodal_gen/benchmarks/bench_serving.py (33 additions, 28 deletions)

```diff
@@ -17,6 +17,7 @@

 import argparse
 import asyncio
+import copy
 import glob
 import json
 import os
```

```diff
@@ -314,8 +315,8 @@ def __getitem__(self, idx: int) -> RequestFuncInput:
             height=self.args.height,
             num_frames=self.args.num_frames,
             fps=self.args.fps,
-            num_inference_steps=getattr(self.args, "num_inference_steps", None),
-            guidance_scale=getattr(self.args, "guidance_scale", None),
+            num_inference_steps=self.args.num_inference_steps,
+            guidance_scale=self.args.guidance_scale,
             image_paths=image_paths,
         )
```
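
Dropping the getattr fallbacks assumes the benchmark's argument parser always defines num_inference_steps and guidance_scale, so a missing option now fails loudly instead of silently passing None. A small illustration with a hypothetical parser (option names assumed to mirror the script's):

```python
import argparse

# Hypothetical parser mirroring the benchmark's options (names assumed).
parser = argparse.ArgumentParser()
parser.add_argument("--num-inference-steps", type=int, default=None)
parser.add_argument("--guidance-scale", type=float, default=None)
args = parser.parse_args([])

# argparse sets the attribute even when the flag is omitted, so direct
# access is safe and still yields the default:
print(args.num_inference_steps)  # None

# getattr(args, "num_inference_step", None) would silently mask this typo;
# direct access raises AttributeError immediately.
```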

```diff
@@ -374,7 +375,7 @@ async def async_request_image_sglang(
     data.add_field("guidance_scale", str(input.guidance_scale))

     # Add profiling and other extra parameters
-    extra_params = input.extra_body.copy()
+    extra_params = copy.deepcopy(input.extra_body)
     if extra_params.pop("profile", None):
         data.add_field("profile", "true")
     for key, value in extra_params.items():
```
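
The switch to copy.deepcopy matters when extra_body holds nested structures: dict.copy() is shallow, so mutating a nested value would still be visible through the caller's request object. A standalone illustration:

```python
import copy

extra_body = {"profile": True, "opts": {"seed": 42}}

shallow = extra_body.copy()
shallow["opts"]["seed"] = 0        # shallow copy shares the nested dict...
print(extra_body["opts"]["seed"])  # 0: the original was mutated

extra_body = {"profile": True, "opts": {"seed": 42}}
deep = copy.deepcopy(extra_body)
deep["opts"]["seed"] = 0           # deep copy owns its nested dict
print(extra_body["opts"]["seed"])  # 42: the original is untouched
```
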
```diff
@@ -766,9 +767,9 @@ async def limited_request_func(req, session, pbar):
             api_url=f"{args.base_url}/start_profile"
         )
         if profile_output.success:
-            print(f"Profiler started: {profile_output.message}")
+            logger.info(f"Profiler started: {profile_output.message}")
         else:
-            print(f"Warning: Failed to start profiler: {profile_output.error}")
+            logger.warning(f"Failed to start profiler: {profile_output.error}")

     # Run benchmark
     pbar = tqdm(total=len(requests_list), disable=args.disable_tqdm)
```
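
The print-to-logger migration routes benchmark output through the script's existing logger, picking up levels, handlers, and timestamps for free, and it drops the hand-written "Warning:" prefix in favor of logger.warning. A minimal stand-in, assuming init_logger wraps the stdlib logging module (the real sglang implementation may differ):

```python
import logging

def init_logger(name: str) -> logging.Logger:
    # Assumed shape of sglang's init_logger: a stdlib logger with a
    # stream handler and a sensible default level.
    logger = logging.getLogger(name)
    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter("[%(asctime)s] %(levelname)s %(name)s: %(message)s")
        )
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)
    return logger

logger = init_logger(__name__)
logger.info("Profiler started: ok")                   # replaces print(...)
logger.warning("Failed to start profiler: timeout")   # level carries the "warning"
```
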
```diff
@@ -792,77 +793,81 @@ async def limited_request_func(req, session, pbar):

     # Stop profiler if it was started
     if args.profile:
-        print("Stopping profiler and saving traces...")
+        logger.info("Stopping profiler and saving traces...")
         profile_output = await async_request_profile(
             api_url=f"{args.base_url}/stop_profile"
         )
         if profile_output.success:
-            print(f"Profiler stopped: {profile_output.message}")
+            logger.info(f"Profiler stopped: {profile_output.message}")
         else:
-            print(f"Warning: Failed to stop profiler: {profile_output.error}")
+            logger.warning(f"Failed to stop profiler: {profile_output.error}")

     # Calculate metrics
     metrics = calculate_metrics(outputs, total_duration)

-    print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=60, c="="))
+    logger.info("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=60, c="="))

     # Section 1: Configuration
-    print("{:<40} {:<15}".format("Task:", task_name))
-    print("{:<40} {:<15}".format("Model:", args.model))
-    print("{:<40} {:<15}".format("Dataset:", args.dataset))
+    logger.info("{:<40} {:<15}".format("Task:", task_name))
+    logger.info("{:<40} {:<15}".format("Model:", args.model))
+    logger.info("{:<40} {:<15}".format("Dataset:", args.dataset))

     # Section 2: Execution & Traffic
-    print(f"{'-' * 50}")
-    print("{:<40} {:<15.2f}".format("Benchmark duration (s):", metrics["duration"]))
-    print("{:<40} {:<15}".format("Request rate:", str(args.request_rate)))
-    print(
+    logger.info(f"{'-' * 50}")
+    logger.info(
+        "{:<40} {:<15.2f}".format("Benchmark duration (s):", metrics["duration"])
+    )
+    logger.info("{:<40} {:<15}".format("Request rate:", str(args.request_rate)))
+    logger.info(
         "{:<40} {:<15}".format(
             "Max request concurrency:",
             str(args.max_concurrency) if args.max_concurrency else "not set",
         )
     )
-    print(
+    logger.info(
         "{:<40} {}/{:<15}".format(
             "Successful requests:", metrics["completed_requests"], len(requests_list)
         )
     )

     # Section 3: Performance Metrics
-    print(f"{'-' * 50}")
+    logger.info(f"{'-' * 50}")

-    print(
+    logger.info(
         "{:<40} {:<15.2f}".format(
             "Request throughput (req/s):", metrics["throughput_qps"]
         )
     )
-    print("{:<40} {:<15.4f}".format("Latency Mean (s):", metrics["latency_mean"]))
-    print("{:<40} {:<15.4f}".format("Latency Median (s):", metrics["latency_median"]))
-    print("{:<40} {:<15.4f}".format("Latency P99 (s):", metrics["latency_p99"]))
+    logger.info("{:<40} {:<15.4f}".format("Latency Mean (s):", metrics["latency_mean"]))
+    logger.info(
+        "{:<40} {:<15.4f}".format("Latency Median (s):", metrics["latency_median"])
+    )
+    logger.info("{:<40} {:<15.4f}".format("Latency P99 (s):", metrics["latency_p99"]))

     if metrics["peak_memory_mb_max"] > 0:
-        print(f"{'-' * 50}")
-        print(
+        logger.info(f"{'-' * 50}")
+        logger.info(
             "{:<40} {:<15.2f}".format(
                 "Peak Memory Max (MB):", metrics["peak_memory_mb_max"]
             )
         )
-        print(
+        logger.info(
             "{:<40} {:<15.2f}".format(
                 "Peak Memory Mean (MB):", metrics["peak_memory_mb_mean"]
             )
         )
-        print(
+        logger.info(
             "{:<40} {:<15.2f}".format(
                 "Peak Memory Median (MB):", metrics["peak_memory_mb_median"]
             )
         )

-    print("=" * 60)
+    logger.info("=" * 60)

     if args.output_file:
         with open(args.output_file, "w") as f:
             json.dump(metrics, f, indent=2)
-        print(f"Metrics saved to {args.output_file}")
+        logger.info(f"Metrics saved to {args.output_file}")


 if __name__ == "__main__":
```
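
One detail in the banner line that survives the migration: str.format accepts nested replacement fields inside the format spec, so the fill character and width are themselves parameters. A standalone repro:

```python
# "{s:{c}^{n}}" centers s in a field n columns wide, padded with fill char c.
banner = "{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=60, c="=")
print(banner)  # the title centered in a 60-column rule of '='
```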

python/sglang/multimodal_gen/runtime/entrypoints/http_server.py (13 additions, 27 deletions)

```diff
@@ -13,19 +13,23 @@
 from pydantic import BaseModel

 from sglang.multimodal_gen.configs.sample.sampling_params import SamplingParams
+from sglang.multimodal_gen.runtime.distributed.parallel_state import get_world_rank
 from sglang.multimodal_gen.runtime.entrypoints.openai import image_api, video_api
 from sglang.multimodal_gen.runtime.entrypoints.openai.protocol import (
     VertexGenerateReqInput,
 )
+from sglang.multimodal_gen.runtime.entrypoints.openai.utils import (
+    StartProfileReq,
+    StopProfileReq,
+)
 from sglang.multimodal_gen.runtime.entrypoints.utils import (
     post_process_sample,
     prepare_request,
 )
 from sglang.multimodal_gen.runtime.scheduler_client import async_scheduler_client
 from sglang.multimodal_gen.runtime.server_args import ServerArgs, get_global_server_args
-from sglang.multimodal_gen.runtime.utils.common import get_bool_env_var
 from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
-from sglang.multimodal_gen.runtime.utils.profiler import SGLDiffusionProfiler
+from sglang.srt.environ import envs

 logger = init_logger(__name__)
```

```diff
@@ -146,30 +150,22 @@ async def start_profile(request: Request, obj: Optional[ProfileReqInput] = None)
     if obj is None:
         obj = ProfileReqInput()

-    output_dir = obj.output_dir or os.getenv("SGLANG_TORCH_PROFILER_DIR", "./logs")
+    output_dir = obj.output_dir or envs.SGLANG_TORCH_PROFILER_DIR.get()

-    # Generate unified profile_id (similar to LLM implementation)
     profile_id = str(int(time_module.time()))

-    # Read env vars for with_stack and record_shapes
-    env_with_stack = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "false")
-    env_record_shapes = get_bool_env_var("SGLANG_PROFILE_RECORD_SHAPES", "false")
-
-    with_stack = obj.with_stack if obj.with_stack is not None else env_with_stack
-    record_shapes = (
-        obj.record_shapes if obj.record_shapes is not None else env_record_shapes
-    )
-
-    # 1. Start profiler in HTTP Server process
-    from sglang.multimodal_gen.runtime.distributed.parallel_state import (
-        get_world_rank,
-    )
+    with_stack = obj.with_stack or envs.SGLANG_PROFILE_WITH_STACK.get()
+    record_shapes = obj.record_shapes or envs.SGLANG_PROFILE_RECORD_SHAPES.get()

     try:
         rank = get_world_rank()
     except Exception:
+        logger.warning("Failed to get world rank, defaulting to 0")
         rank = 0

+    # Lazy import to reduce import time (see issue #10492)
+    from sglang.multimodal_gen.runtime.utils.profiler import SGLDiffusionProfiler
+
     http_profiler = SGLDiffusionProfiler(
         request_id=profile_id,
         rank=rank,
```
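
Two notes on this hunk. First, `obj.with_stack or envs.SGLANG_PROFILE_WITH_STACK.get()` treats an explicit `with_stack=False` in the request the same as leaving it unset; the previous `is not None` checks distinguished the two, so a client can no longer force-disable stack capture while the env var is on. Second, `envs` from `sglang.srt.environ` replaces ad-hoc `os.getenv`/`get_bool_env_var` calls with a typed per-variable accessor. A minimal sketch of such a registry, assuming only its general shape (the variable names match the diff; the implementation is illustrative, not sglang's actual one):

```python
import os

class EnvStr:
    """String env var with a default; .get() reads the live environment."""

    def __init__(self, name: str, default: str):
        self.name, self.default = name, default

    def get(self) -> str:
        return os.getenv(self.name, self.default)

class EnvBool:
    """Boolean env var; '1'/'true'/'yes' (any case) parse as True."""

    def __init__(self, name: str, default: bool):
        self.name, self.default = name, default

    def get(self) -> bool:
        raw = os.getenv(self.name)
        if raw is None:
            return self.default
        return raw.strip().lower() in ("1", "true", "yes")

class _Envs:
    SGLANG_TORCH_PROFILER_DIR = EnvStr("SGLANG_TORCH_PROFILER_DIR", "./logs")
    SGLANG_PROFILE_WITH_STACK = EnvBool("SGLANG_PROFILE_WITH_STACK", False)
    SGLANG_PROFILE_RECORD_SHAPES = EnvBool("SGLANG_PROFILE_RECORD_SHAPES", False)

envs = _Envs()
print(envs.SGLANG_TORCH_PROFILER_DIR.get())  # "./logs" unless overridden
print(envs.SGLANG_PROFILE_WITH_STACK.get())  # False unless the env var is set
```
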
```diff
@@ -185,11 +181,6 @@ async def start_profile(request: Request, obj: Optional[ProfileReqInput] = None)
     _global_profiler_state["profiler"] = http_profiler
     _global_profiler_state["profile_id"] = profile_id

-    # 2. Start profiler in GPU Worker process via ZMQ
-    from sglang.multimodal_gen.runtime.entrypoints.openai.utils import (
-        StartProfileReq,
-    )
-
     start_req = StartProfileReq(
         output_dir=output_dir,
         profile_id=profile_id,
```

```diff
@@ -241,11 +232,6 @@ async def stop_profile():
     if profiler is not None:
         profiler.stop(export_trace=True, dump_rank=None)  # Save for all ranks

-    # 2. Stop profiler in GPU Worker process via ZMQ
-    from sglang.multimodal_gen.runtime.entrypoints.openai.utils import (
-        StopProfileReq,
-    )
-
     stop_req = StopProfileReq(export_trace=True)
     try:
         response = await async_scheduler_client.forward(stop_req)
```
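
Note the two complementary moves in this file: the lightweight StartProfileReq/StopProfileReq imports are hoisted to module scope, while the heavy SGLDiffusionProfiler import becomes local to start_profile (per the "Lazy import to reduce import time" comment). A self-contained sketch of the lazy-import pattern, using a stdlib module as a stand-in for the heavy dependency:

```python
import sys
import time

def handler() -> float:
    # Lazy import: the module is loaded on the first call only; afterwards
    # Python serves it from the sys.modules cache at near-zero cost.
    t0 = time.perf_counter()
    import xml.dom.minidom  # noqa: F401  # stand-in for a heavy dependency
    return time.perf_counter() - t0

first = handler()   # pays the one-time import cost
second = handler()  # cache hit
print(f"first={first:.6f}s, second={second:.6f}s")
print("xml.dom.minidom" in sys.modules)  # True once handler() has run
```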

python/sglang/multimodal_gen/runtime/utils/profiler.py (10 additions, 10 deletions)

```diff
@@ -1,10 +1,12 @@
 import gzip
 import os
+from pathlib import Path

 import torch

+from sglang.multimodal_gen.runtime.platforms import current_platform
 from sglang.multimodal_gen.runtime.utils.logging_utils import CYAN, RESET, init_logger
-from sglang.srt.utils import get_bool_env_var
+from sglang.srt.environ import envs

 logger = init_logger(__name__)
```

```diff
@@ -38,12 +40,10 @@ def __init__(
         self.full_profile = full_profile
         self.is_host = is_host

-        # Use environment variables with fallback to parameters
-        self.log_dir = log_dir or os.getenv("SGLANG_TORCH_PROFILER_DIR", "./logs")
+        self.log_dir = log_dir or envs.SGLANG_TORCH_PROFILER_DIR.get()

-        # Read from environment variables, allow parameter override
-        env_with_stack = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "false")
-        env_record_shapes = get_bool_env_var("SGLANG_PROFILE_RECORD_SHAPES", "false")
+        env_with_stack = envs.SGLANG_PROFILE_WITH_STACK.get()
+        env_record_shapes = envs.SGLANG_PROFILE_RECORD_SHAPES.get()

         self.with_stack = with_stack if with_stack is not None else env_with_stack
         self.record_shapes = (
```

```diff
@@ -107,7 +107,7 @@ def _resolve_activities(

     def _default() -> list[torch.profiler.ProfilerActivity]:
         ret = [torch.profiler.ProfilerActivity.CPU]
-        if torch.cuda.is_available():
+        if current_platform.is_cuda_alike():
             ret.append(torch.profiler.ProfilerActivity.CUDA)
         return ret
```

```diff
@@ -123,7 +123,7 @@ def _default() -> list[torch.profiler.ProfilerActivity]:
         if s == "cpu":
             use_cpu = True
         elif s in ("gpu", "cuda"):
-            if torch.cuda.is_available():
+            if current_platform.is_cuda_alike():
                 use_cuda = True
             else:
                 logger.warning(
```

```diff
@@ -169,7 +169,7 @@ def stop(self, export_trace: bool = True, dump_rank: int | None = None):
             return
         self.has_stopped = True
         logger.info("Stopping Profiler...")
-        if torch.cuda.is_available():
+        if current_platform.is_cuda_alike():
             torch.cuda.synchronize()
         self.profiler.stop()
```
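
Replacing torch.cuda.is_available() with current_platform.is_cuda_alike() funnels device detection through a single platform object; "cuda-alike" conventionally covers both NVIDIA CUDA and AMD ROCm builds. A minimal sketch of the pattern, illustrative rather than the actual sglang.multimodal_gen.runtime.platforms implementation:

```python
import torch

class CudaPlatform:
    def is_cuda_alike(self) -> bool:
        # PyTorch exposes ROCm through the torch.cuda namespace as well,
        # so this single check covers both NVIDIA and AMD builds.
        return True

class CpuPlatform:
    def is_cuda_alike(self) -> bool:
        return False

# Resolved once at import time; every call site stays device-agnostic.
current_platform = CudaPlatform() if torch.cuda.is_available() else CpuPlatform()

if current_platform.is_cuda_alike():
    torch.cuda.synchronize()
```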

```diff
@@ -194,7 +194,7 @@ def _export_trace(self):
         else:
             filename = f"{self.request_id}-rank-{self.rank}.trace.json.gz"

-        trace_path = os.path.abspath(os.path.join(self.log_dir, filename))
+        trace_path = str(Path(self.log_dir, filename).resolve())
         self.profiler.export_chrome_trace(trace_path)

         if self._check_trace_integrity(trace_path):
```
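
Path(self.log_dir, filename).resolve() is the pathlib spelling of os.path.abspath(os.path.join(...)), with one nuance: resolve() also follows symlinks, while abspath() merely normalizes. A quick equivalence check:

```python
import os
from pathlib import Path

log_dir, filename = "./logs", "req-1-rank-0.trace.json.gz"

old_path = os.path.abspath(os.path.join(log_dir, filename))
new_path = str(Path(log_dir, filename).resolve())

# Identical whenever no symlink sits on the path.
print(old_path == new_path)
```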
