Commit 18da436

add

1 parent 38ee749 commit 18da436

6 files changed: +346 −55 lines

python/sglang/multimodal_gen/runtime/entrypoints/http_server.py

Lines changed: 8 additions & 0 deletions
@@ -25,6 +25,10 @@
 from sglang.multimodal_gen.runtime.scheduler_client import async_scheduler_client
 from sglang.multimodal_gen.runtime.server_args import ServerArgs, get_global_server_args
 from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
+from sglang.srt.utils import (
+    add_prometheus_middleware,
+    add_prometheus_track_response_middleware,
+)
 
 if TYPE_CHECKING:
     from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req
@@ -207,6 +211,10 @@ def create_app(server_args: ServerArgs):
     """
     app = FastAPI(lifespan=lifespan)
 
+    if server_args.enable_metrics:
+        add_prometheus_middleware(app)
+        add_prometheus_track_response_middleware(app)
+
     app.include_router(health_router)
     app.include_router(vertex_router)
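Both helpers come from sglang.srt.utils and are attached only when --enable-metrics is set, so the default request path is unchanged. For orientation, mounting a multiprocess-aware Prometheus endpoint on a FastAPI app typically looks like the sketch below; this is an illustration of the pattern, not the sglang implementation, and it assumes prometheus_client is installed and PROMETHEUS_MULTIPROC_DIR is already set:

    from fastapi import FastAPI
    from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess

    def mount_metrics_endpoint(app: FastAPI) -> None:  # hypothetical helper
        # Aggregate samples written by all worker processes into the shared
        # PROMETHEUS_MULTIPROC_DIR, then expose them in text format at /metrics.
        registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(registry)
        app.mount("/metrics", make_asgi_app(registry=registry))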

python/sglang/multimodal_gen/runtime/launch_server.py

Lines changed: 4 additions & 0 deletions
@@ -17,6 +17,7 @@
     set_global_server_args,
 )
 from sglang.multimodal_gen.runtime.utils.logging_utils import configure_logger, logger
+from sglang.srt.utils import set_prometheus_multiproc_dir
 
 
 def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None):
@@ -68,6 +69,9 @@ def launch_server(server_args: ServerArgs, launch_http_server: bool = True):
     # Start a new server with multiple worker processes
     logger.info("Starting server...")
 
+    if server_args.enable_metrics:
+        set_prometheus_multiproc_dir()
+
     num_gpus = server_args.num_gpus
     processes = []
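Calling set_prometheus_multiproc_dir() before the workers are spawned matters because prometheus_client's multiprocess mode reads the PROMETHEUS_MULTIPROC_DIR environment variable when metric objects are created. A minimal sketch of what such a helper typically does (assumed behavior, not the sglang source):

    import os
    import tempfile

    def set_multiproc_dir() -> None:  # hypothetical stand-in
        # Each worker writes its samples to a .db file in this directory;
        # it must exist and be inherited before any child process starts.
        os.environ.setdefault(
            "PROMETHEUS_MULTIPROC_DIR", tempfile.mkdtemp(prefix="prometheus_")
        )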

python/sglang/multimodal_gen/runtime/managers/gpu_worker.py

Lines changed: 37 additions & 4 deletions
@@ -43,7 +43,10 @@
 from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import OutputBatch
 from sglang.multimodal_gen.runtime.platforms import current_platform
 from sglang.multimodal_gen.runtime.server_args import PortArgs, ServerArgs
-from sglang.multimodal_gen.runtime.utils.common import set_cuda_arch
+from sglang.multimodal_gen.runtime.utils.common import (
+    get_diffusion_metrics_collector,
+    set_cuda_arch,
+)
 from sglang.multimodal_gen.runtime.utils.layerwise_offload import (
     OffloadableDiTMixin,
     iter_materialized_weights,
@@ -79,6 +82,11 @@ def __init__(
         # FIXME: should we use tcp as distribute init method?
         self.server_args = server_args
         self.pipeline: ComposedPipelineBase = None
+        self.metrics_collector = (
+            get_diffusion_metrics_collector(server_args)
+            if server_args.enable_metrics and rank == 0
+            else None
+        )
 
         self.init_device_and_model()
         self.sp_group = get_sp_group()
@@ -89,6 +97,8 @@
         self.cfg_group = get_cfg_group()
         self.cfg_cpu_group = self.cfg_group.cpu_group
 
+        self._update_lora_metrics()
+
     def init_device_and_model(self) -> None:
         """Initialize the device and load the model."""
         torch.get_device_module().set_device(self.local_rank)
@@ -199,7 +209,7 @@ def do_mem_analysis(self, output_batch: OutputBatch):
         logger.info(
             f"Peak GPU memory: {peak_reserved_gb:.2f} GB, "
             f"Peak allocated: {peak_allocated_gb:.2f} GB, "
-            f"Memory pool overhead: {pool_overhead_gb:.2f} GB ({pool_overhead_gb/peak_reserved_gb*100:.1f}%), "
+            f"Memory pool overhead: {pool_overhead_gb:.2f} GB ({pool_overhead_gb / peak_reserved_gb * 100:.1f}%), "
             f"Remaining GPU memory at peak: {remaining_gpu_mem_gb:.2f} GB. "
             f"Components that could stay resident (based on the last request workload): {can_stay_resident}. "
             f"Related offload server args to disable: {suggested_args_str}"
@@ -212,12 +222,12 @@ def execute_forward(self, batch: List[Req]) -> OutputBatch:
         assert self.pipeline is not None
         req = batch[0]
         output_batch = None
+        status = "success"
+        start_time = time.monotonic()
         try:
             if self.rank == 0:
                 torch.get_device_module().reset_peak_memory_stats()
 
-            start_time = time.monotonic()
-
             # capture memory baseline before forward
             if self.rank == 0 and req.metrics:
                 baseline_snapshot = capture_memory_snapshot()
@@ -274,7 +284,10 @@ def execute_forward(self, batch: List[Req]) -> OutputBatch:
             # Avoid logging warmup perf records that share the same request_id.
             if not req.is_warmup:
                 PerformanceLogger.log_request_summary(metrics=output_batch.metrics)
+            if output_batch is not None and output_batch.error is not None:
+                status = "error"
         except Exception as e:
+            status = "error"
             logger.error(
                 f"Error executing request {req.request_id}: {e}", exc_info=True
             )
@@ -283,8 +296,25 @@ def execute_forward(self, batch: List[Req]) -> OutputBatch:
             if output_batch is None:
                 output_batch = OutputBatch()
             output_batch.error = f"Error executing request {req.request_id}: {e}"
+        finally:
+            if self.metrics_collector is not None:
+                self.metrics_collector.observe_request(
+                    status=status,
+                    is_warmup=req.is_warmup,
+                    latency_s=time.monotonic() - start_time,
+                )
         return output_batch
 
+    def _update_lora_metrics(self):
+        if self.metrics_collector is None:
+            return
+
+        if not isinstance(self.pipeline, LoRAPipeline):
+            self.metrics_collector.clear_lora_status()
+            return
+
+        self.metrics_collector.update_lora_status(self.pipeline.get_lora_status())
+
     def get_can_stay_resident_components(
         self, remaining_gpu_mem_gb: float
     ) -> List[str]:
@@ -339,6 +369,7 @@ def set_lora(
         if not isinstance(self.pipeline, LoRAPipeline):
             return OutputBatch(error="Lora is not enabled")
         self.pipeline.set_lora(lora_nickname, lora_path, target, strength)
+        self._update_lora_metrics()
         return OutputBatch()
 
     def merge_lora_weights(
@@ -354,6 +385,7 @@
         if not isinstance(self.pipeline, LoRAPipeline):
             return OutputBatch(error="Lora is not enabled")
         self.pipeline.merge_lora_weights(target, strength)
+        self._update_lora_metrics()
        return OutputBatch()
 
     def unmerge_lora_weights(self, target: str = "all") -> OutputBatch:
@@ -366,6 +398,7 @@ def unmerge_lora_weights(self, target: str = "all") -> OutputBatch:
         if not isinstance(self.pipeline, LoRAPipeline):
             return OutputBatch(error="Lora is not enabled")
         self.pipeline.unmerge_lora_weights(target)
+        self._update_lora_metrics()
         return OutputBatch()
 
     def list_loras(self) -> OutputBatch:
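The collector returned by get_diffusion_metrics_collector is defined outside the hunks shown here, so its exact metric names are not visible in this diff. Based on the call sites above (observe_request, update_lora_status, clear_lora_status), a minimal prometheus_client-based sketch of that interface could look like the following; the metric names and label sets are assumptions:

    from prometheus_client import Counter, Gauge, Histogram

    class DiffusionMetricsCollectorSketch:  # hypothetical shape
        def __init__(self) -> None:
            self.requests_total = Counter(
                "sglang:diffusion_requests_total",  # assumed name
                "Generation requests processed by the GPU worker.",
                labelnames=["status", "is_warmup"],
            )
            self.request_latency = Histogram(
                "sglang:diffusion_request_latency_seconds",  # assumed name
                "End-to-end forward latency per request.",
                labelnames=["status"],
            )
            self.active_loras = Gauge(
                "sglang:diffusion_active_loras",  # assumed name
                "LoRA adapters currently attached to the pipeline.",
            )

        def observe_request(self, status: str, is_warmup: bool, latency_s: float) -> None:
            self.requests_total.labels(status=status, is_warmup=str(is_warmup)).inc()
            self.request_latency.labels(status=status).observe(latency_s)

        def update_lora_status(self, lora_status) -> None:
            self.active_loras.set(len(lora_status))

        def clear_lora_status(self) -> None:
            self.active_loras.set(0)

Note that the worker only builds a collector when rank == 0, so each request is sampled exactly once per batch even when the model is sharded across multiple ranks.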

python/sglang/multimodal_gen/runtime/managers/scheduler.py

Lines changed: 106 additions & 51 deletions
@@ -4,6 +4,7 @@
 import asyncio
 import os
 import pickle
+import time
 from collections import deque
 from copy import deepcopy
 from typing import Any, List
@@ -34,7 +35,10 @@
     ServerArgs,
     set_global_server_args,
 )
-from sglang.multimodal_gen.runtime.utils.common import get_zmq_socket
+from sglang.multimodal_gen.runtime.utils.common import (
+    get_diffusion_metrics_collector,
+    get_zmq_socket,
+)
 from sglang.multimodal_gen.runtime.utils.distributed import broadcast_pyobj
 from sglang.multimodal_gen.runtime.utils.logging_utils import GREEN, RESET, init_logger
 
@@ -101,6 +105,8 @@ def __init__(
 
         # FIFO, new reqs are appended
         self.waiting_queue: deque[tuple[bytes, Req]] = deque()
+        self._generation_waiting_count = 0
+        self._generation_enqueue_timestamps: dict[int, float] = {}
 
         # whether we've send the necessary warmup reqs
         self.warmed_up = False
@@ -110,6 +116,15 @@
 
         self.prepare_server_warmup_reqs()
 
+        self.metrics_collector = (
+            get_diffusion_metrics_collector(server_args)
+            if server_args.enable_metrics and gpu_id == 0
+            else None
+        )
+        if self.metrics_collector is not None:
+            self.metrics_collector.set_queue_depth(self._generation_waiting_count)
+            self.metrics_collector.set_running_reqs(0)
+
         # Maximum consecutive errors before terminating the event loop
         self._max_consecutive_errors = 3
         self._consecutive_error_count = 0
@@ -187,9 +202,32 @@ def get_next_batch_to_run(self) -> list[tuple[bytes, Req]] | None:
 
         # pop the first (earliest)
         item = self.waiting_queue.popleft()
+        self._on_req_dequeued(item[1])
+        if self.metrics_collector is not None:
+            self.metrics_collector.set_queue_depth(self._generation_waiting_count)
 
         return [item]
 
+    def _on_req_enqueued(self, req: Any) -> None:
+        if not isinstance(req, Req):
+            return
+        self._generation_waiting_count += 1
+        self._generation_enqueue_timestamps[id(req)] = time.monotonic()
+
+    def _on_req_dequeued(self, req: Any) -> None:
+        if not isinstance(req, Req):
+            return
+        if self._generation_waiting_count > 0:
+            self._generation_waiting_count -= 1
+        enqueue_ts = self._generation_enqueue_timestamps.pop(id(req), None)
+        if enqueue_ts is not None and self.metrics_collector is not None:
+            self.metrics_collector.observe_queue_time(time.monotonic() - enqueue_ts)
+
+    def _enqueue_received_reqs(self, new_reqs: list[tuple[bytes, Any]]) -> None:
+        self.waiting_queue.extend(new_reqs)
+        for _, req in new_reqs:
+            self._on_req_enqueued(req)
+
     def prepare_server_warmup_reqs(self):
         if (
             self.server_args.warmup
@@ -235,6 +273,7 @@ def prepare_server_warmup_reqs(self):
             )
             req.set_as_warmup()
             self.waiting_queue.append((None, req))
+            self._on_req_enqueued(req)
         # if server is warmed-up, set this flag to avoid req-based warmup
         self.warmed_up = True
 
@@ -334,7 +373,11 @@ def event_loop(self) -> None:
             try:
                 new_reqs = self.recv_reqs()
                 new_reqs = self.process_received_reqs_with_req_based_warmup(new_reqs)
-                self.waiting_queue.extend(new_reqs)
+                self._enqueue_received_reqs(new_reqs)
+                if self.metrics_collector is not None:
+                    self.metrics_collector.set_queue_depth(
+                        self._generation_waiting_count
+                    )
                 # Reset error count on success
                 self._consecutive_error_count = 0
             except Exception as e:
@@ -362,60 +405,72 @@ def event_loop(self) -> None:
 
             identities = [item[0] for item in items]
             reqs = [item[1] for item in items]
+            generation_running_reqs = sum(1 for req in reqs if isinstance(req, Req))
+            if self.metrics_collector is not None:
+                self.metrics_collector.set_running_reqs(generation_running_reqs)
 
             try:
-                processed_req = reqs[0]
-                handler = self.request_handlers.get(type(processed_req))
-                if handler:
-                    output_batch = handler(reqs)
-                else:
-                    output_batch = OutputBatch(
-                        error=f"Unknown request type: {type(processed_req)}"
+                try:
+                    processed_req = reqs[0]
+                    handler = self.request_handlers.get(type(processed_req))
+                    if handler:
+                        output_batch = handler(reqs)
+                    else:
+                        output_batch = OutputBatch(
+                            error=f"Unknown request type: {type(processed_req)}"
+                        )
+                except Exception as e:
+                    logger.error(
+                        f"Error executing request in scheduler event loop: {e}",
+                        exc_info=True,
+                    )
+                    # Determine appropriate error response format
+                    output_batch = (
+                        OutputBatch(error=str(e))
+                        if reqs and isinstance(reqs[0], Req)
+                        else OutputBatch(error=str(e))
                     )
-            except Exception as e:
-                logger.error(
-                    f"Error executing request in scheduler event loop: {e}",
-                    exc_info=True,
-                )
-                # Determine appropriate error response format
-                output_batch = (
-                    OutputBatch(error=str(e))
-                    if reqs and isinstance(reqs[0], Req)
-                    else OutputBatch(error=str(e))
-                )
 
-            # 3. return results
-            try:
-                # log warmup info
-                is_warmup = (
-                    processed_req.is_warmup if isinstance(processed_req, Req) else False
-                )
-                if is_warmup:
-                    if output_batch.error is None:
-                        if self._warmup_total > 0:
-                            logger.info(
-                                f"Warmup req ({self._warmup_processed}/{self._warmup_total}) processed in {GREEN}%.2f{RESET} seconds",
-                                output_batch.metrics.total_duration_s,
-                            )
-                        else:
-                            logger.info(
-                                f"Warmup req processed in {GREEN}%.2f{RESET} seconds",
-                                output_batch.metrics.total_duration_s,
-                            )
-                    else:
-                        if self._warmup_total > 0:
-                            logger.info(
-                                f"Warmup req ({self._warmup_processed}/{self._warmup_total}) processing failed"
-                            )
+                # 3. return results
+                try:
+                    # log warmup info
+                    is_warmup = (
+                        processed_req.is_warmup
+                        if isinstance(processed_req, Req)
+                        else False
+                    )
+                    if is_warmup:
+                        if output_batch.error is None:
+                            if self._warmup_total > 0:
+                                logger.info(
+                                    f"Warmup req ({self._warmup_processed}/{self._warmup_total}) processed in {GREEN}%.2f{RESET} seconds",
+                                    output_batch.metrics.total_duration_s,
+                                )
+                            else:
+                                logger.info(
+                                    f"Warmup req processed in {GREEN}%.2f{RESET} seconds",
+                                    output_batch.metrics.total_duration_s,
+                                )
                         else:
-                            logger.info(f"Warmup req processing failed")
-
-            # TODO: Support sending back to multiple identities if batched
-            self.return_result(output_batch, identities[0], is_warmup=is_warmup)
-        except zmq.ZMQError as e:
-            # Reply failed; log and keep loop alive to accept future requests
-            logger.error(f"ZMQ error sending reply: {e}")
-            continue
+                            if self._warmup_total > 0:
+                                logger.info(
+                                    f"Warmup req ({self._warmup_processed}/{self._warmup_total}) processing failed"
+                                )
+                            else:
+                                logger.info(f"Warmup req processing failed")
+
+                    # TODO: Support sending back to multiple identities if batched
+                    self.return_result(output_batch, identities[0], is_warmup=is_warmup)
+                except zmq.ZMQError as e:
+                    # Reply failed; log and keep loop alive to accept future requests
+                    logger.error(f"ZMQ error sending reply: {e}")
+                    continue
+            finally:
+                if self.metrics_collector is not None:
+                    self.metrics_collector.set_running_reqs(0)
+                    self.metrics_collector.set_queue_depth(
+                        self._generation_waiting_count
+                    )
 
         if self.receiver is not None:
             self.receiver.close()
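On the scheduler side the collector is driven through three calls: set_queue_depth, set_running_reqs, and observe_queue_time. Queue wait is measured by stamping each Req with time.monotonic() at enqueue, keyed by id(req), and popping the stamp at dequeue; keeping the map keyed by id(req) avoids requiring Req to be hashable, and because the entry is removed on dequeue it never outlives the request. A plausible prometheus_client backing for those three calls, with assumed metric names:

    from prometheus_client import Gauge, Histogram

    class SchedulerMetricsSketch:  # hypothetical shape
        def __init__(self) -> None:
            self.queue_depth = Gauge(
                "sglang:diffusion_queue_depth",  # assumed name
                "Generation requests waiting in the scheduler FIFO.",
            )
            self.running_reqs = Gauge(
                "sglang:diffusion_running_reqs",  # assumed name
                "Generation requests currently being executed.",
            )
            self.queue_time = Histogram(
                "sglang:diffusion_queue_time_seconds",  # assumed name
                "Seconds a request waited between enqueue and dequeue.",
            )

        def set_queue_depth(self, depth: int) -> None:
            self.queue_depth.set(depth)

        def set_running_reqs(self, count: int) -> None:
            self.running_reqs.set(count)

        def observe_queue_time(self, seconds: float) -> None:
            self.queue_time.observe(seconds)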

python/sglang/multimodal_gen/runtime/server_args.py

Lines changed: 7 additions & 0 deletions
@@ -327,6 +327,7 @@ class ServerArgs:
     # http server endpoint config
     host: str | None = "127.0.0.1"
     port: int | None = 30000
+    enable_metrics: bool = False
 
     # TODO: webui and their endpoint, check if webui_port is available.
     webui: bool = False
@@ -858,6 +859,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         default=ServerArgs.port,
         help="Port for the HTTP API server.",
     )
+    parser.add_argument(
+        "--enable-metrics",
+        action=StoreBoolean,
+        default=ServerArgs.enable_metrics,
+        help="Enable Prometheus metrics endpoint at /metrics.",
+    )
     parser.add_argument(
         "--webui",
         action=StoreBoolean,
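With the flag on, the HTTP server mounts the /metrics endpoint, so any Prometheus scraper (or a plain HTTP GET) can read the counters. A quick smoke test, assuming a server launched with --enable-metrics on the default 127.0.0.1:30000 and that the requests package is available:

    import requests

    # Server launched beforehand with something like:
    #   python -m sglang.multimodal_gen.runtime.launch_server --enable-metrics
    resp = requests.get("http://127.0.0.1:30000/metrics", timeout=5)
    resp.raise_for_status()
    # Prometheus text exposition format; filter for the diffusion metrics.
    print("\n".join(line for line in resp.text.splitlines() if "diffusion" in line))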
