1+ from __future__ import annotations
2+
13import logging
24import time
35from collections import defaultdict
4- from typing import List , Optional
6+ from typing import TYPE_CHECKING , Dict , List , Optional , Union
7+
8+ import torch
59
610from sglang .srt .disaggregation .kv_events import EventPublisherFactory , KVEventBatch
711from sglang .srt .disaggregation .utils import DisaggregationMode
12+ from sglang .srt .managers .io_struct import TokenizedGenerateReqInput
813from sglang .srt .managers .schedule_policy import PrefillAdder
914from sglang .srt .managers .scheduler import Req , ScheduleBatch
15+ from sglang .srt .managers .utils import DPBalanceMeta
1016from sglang .srt .metrics .collector import SchedulerMetricsCollector , SchedulerStats
1117from sglang .srt .utils import get_bool_env_var
1218
19+ if TYPE_CHECKING :
20+ from sglang .srt .managers .scheduler import Scheduler
21+
1322logger = logging .getLogger (__name__ )
1423
1524RECORD_STEP_TIME = get_bool_env_var ("SGLANG_RECORD_STEP_TIME" )
@@ -28,7 +37,9 @@ def __init__(self):
2837
2938
3039class SchedulerMetricsMixin :
31- def init_metrics (self , tp_rank : int , pp_rank : int , dp_rank : Optional [int ]):
40+ def init_metrics (
41+ self : Scheduler , tp_rank : int , pp_rank : int , dp_rank : Optional [int ]
42+ ):
3243 self .last_gen_throughput : float = 0.0
3344 self .last_input_throughput : float = 0.0
3445 self .step_time_dict = defaultdict (list ) # Dict[batch size -> step time]
@@ -50,14 +61,24 @@ def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]):
5061 labels ["dp_rank" ] = dp_rank
5162 self .metrics_collector = SchedulerMetricsCollector (labels = labels )
5263
53- def init_kv_events (self , kv_events_config : Optional [str ]):
def init_dp_balance(self: Scheduler, dp_balance_meta: Optional[DPBalanceMeta]):
    """Store the DP balance metadata and reset the received-id list.

    Args:
        dp_balance_meta: Balance metadata object, stored on the instance as
            ``balance_meta``. It must not be ``None`` when DP attention is
            enabled and the load-balance method is ``"minimum_tokens"``.

    Raises:
        AssertionError: If the metadata is required by the configuration
            above but ``None`` was passed.
    """
    # Stored before the check, so the attribute exists even if the check fails.
    self.balance_meta = dp_balance_meta

    args = self.server_args
    requires_meta = (
        args.enable_dp_attention and args.load_balance_method == "minimum_tokens"
    )
    if requires_meta:
        assert dp_balance_meta is not None

    # NOTE(review): the list is initialised unconditionally because the other
    # DP-balance methods read it without checking for existence -- confirm
    # this matches the intended nesting.
    self.recv_dp_balance_id_this_term = []
73+
def init_kv_events(self: Scheduler, kv_events_config: Optional[str]):
    """Create the KV event publisher when KV cache events are enabled.

    Args:
        kv_events_config: Configuration value forwarded unchanged to
            ``EventPublisherFactory.create``.
    """
    if not self.enable_kv_cache_events:
        # Events are disabled: no publisher attribute is created.
        return

    publisher = EventPublisherFactory.create(kv_events_config, self.attn_dp_rank)
    self.kv_event_publisher = publisher
5879
5980 def log_prefill_stats (
60- self ,
81+ self : Scheduler ,
6182 adder : PrefillAdder ,
6283 can_run_list : List [Req ],
6384 running_bs : int ,
@@ -138,7 +159,7 @@ def log_prefill_stats(
138159 self ._publish_kv_events ()
139160
140161 def log_decode_stats (
141- self , can_run_cuda_graph : bool , running_batch : ScheduleBatch = None
162+ self : Scheduler , can_run_cuda_graph : bool , running_batch : ScheduleBatch = None
142163 ):
143164 batch = running_batch or self .running_batch
144165
@@ -220,7 +241,7 @@ def log_decode_stats(
220241 self ._emit_kv_metrics ()
221242 self ._publish_kv_events ()
222243
223- def _emit_kv_metrics (self ):
244+ def _emit_kv_metrics (self : Scheduler ):
224245 kv_metrics = KvMetrics ()
225246 kv_metrics .request_active_slots = self .stats .num_running_reqs
226247 kv_metrics .request_total_slots = self .max_running_requests
@@ -236,9 +257,94 @@ def _emit_kv_metrics(self):
236257 if not self .send_metrics_from_scheduler .closed :
237258 self .send_metrics_from_scheduler .send_pyobj (kv_metrics )
238259
239- def _publish_kv_events (self ):
260+ def _publish_kv_events (self : Scheduler ):
240261 if self .enable_kv_cache_events :
241262 events = self .tree_cache .take_events ()
242263 if events :
243264 batch = KVEventBatch (ts = time .time (), events = events )
244265 self .kv_event_publisher .publish (batch )
266+
def maybe_update_dp_balance_data(
    self: Scheduler, recv_req: TokenizedGenerateReqInput
):
    """Record the request's balance id when minimum-tokens balancing is on.

    Args:
        recv_req: Received request; its ``dp_balance_id`` is appended to
            ``recv_dp_balance_id_this_term`` when DP attention is enabled and
            the load-balance method is ``"minimum_tokens"``.
    """
    args = self.server_args
    if not args.enable_dp_attention:
        return
    if args.load_balance_method != "minimum_tokens":
        return

    self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
275+
def maybe_handle_dp_balance_data(self: Scheduler):
    """Periodically gather DP balance data and write it to shared state.

    The work runs only when DP attention is enabled, the load-balance method
    is ``"minimum_tokens"``, and ``self.forward_ct`` is a multiple of 40.
    When it runs, it:

    1. reads the current load via ``self.get_load()``;
    2. gathers every worker's received balance ids and load via
       ``self.gather_dp_balance_info``;
    3. clears ``self.recv_dp_balance_id_this_term``;
    4. on ``tp_rank == 0`` only, writes the gathered data via
       ``self.write_shared_dp_balance_info``.
    """
    args = self.server_args
    # Use the same gate as init_dp_balance / maybe_update_dp_balance_data:
    # balance_meta is only guaranteed non-None, and balance ids are only
    # recorded, when both conditions hold.
    balancing_enabled = (
        args.enable_dp_attention and args.load_balance_method == "minimum_tokens"
    )
    if not balancing_enabled:
        return
    if self.forward_ct % 40 != 0:
        return

    holding_tokens = self.get_load()

    new_recv_dp_balance_id_list, holding_token_list = self.gather_dp_balance_info(
        holding_tokens
    )

    # The ids recorded so far have been handed to the gather; start afresh.
    self.recv_dp_balance_id_this_term.clear()
    if self.tp_rank == 0:  # only the first worker writes the shared info
        self.write_shared_dp_balance_info(
            new_recv_dp_balance_id_list, holding_token_list
        )
292+
def gather_dp_balance_info(
    self: Scheduler, holding_tokens_list
) -> tuple[Optional[List[List[int]]], Union[int, List[int]]]:
    """Gather received balance ids and held-token counts for DP balance.

    Each worker packs its data into a fixed-size int32 tensor laid out as
    ``| holding_tokens | len(recv_ids) | recv_ids... | zero padding |`` and
    the tensors are gathered over ``self.tp_cpu_group``.

    Args:
        holding_tokens_list: This worker's own value, written into slot 0 of
            the packed tensor (a single number despite the name).

    Returns:
        A 2-tuple. When ``self.tp_rank == 0`` it is ``(ids_per_worker,
        holding_tokens_per_worker)``, each with one entry per gathered
        tensor. On every other rank it is ``(None, holding_tokens_list)``,
        i.e. the input value is handed back unchanged.

    Raises:
        AssertionError: If more ids were received this term than fit into
            the packed tensor after its two header slots.
    """
    recv_list = self.recv_dp_balance_id_this_term

    # Size of the tensor used for gathering data from all workers.
    gather_tensor_size = 512
    # Slot 0 holds the token count and slot 1 the id count, so only the
    # remaining slots can carry ids.
    header_size = 2
    max_recv_ids = gather_tensor_size - header_size
    assert len(recv_list) <= max_recv_ids, (
        "The number of requests received this round is too large. "
        "Please increase gather_tensor_size and onfly_info_size."
    )

    # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
    recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
    recv_tensor[0] = holding_tokens_list
    recv_tensor[1] = len(recv_list)  # The second slot is the length of the id list.
    recv_tensor[header_size : header_size + len(recv_list)] = torch.tensor(
        recv_list, dtype=torch.int32
    )

    is_root = self.tp_rank == 0
    if is_root:
        # Only the destination rank supplies receive buffers.
        gathered_list = [
            torch.zeros(gather_tensor_size, dtype=torch.int32)
            for _ in range(self.balance_meta.num_workers)
        ]
    else:
        gathered_list = None

    torch.distributed.gather(recv_tensor, gathered_list, group=self.tp_cpu_group)

    if not is_root:
        return None, holding_tokens_list

    gathered_id_list_per_worker = []
    holding_tokens_per_worker = []
    for tensor in gathered_list:
        holding_tokens_per_worker.append(tensor[0].item())
        list_length = tensor[1].item()
        # Drop the zero padding by slicing to the recorded id count.
        gathered_id_list_per_worker.append(
            tensor[header_size : header_size + list_length].tolist()
        )

    return gathered_id_list_per_worker, holding_tokens_per_worker
331+
def write_shared_dp_balance_info(self: Scheduler, new_recv_rid_lists, local_tokens):
    """Remove newly received ids from the on-fly records and write them back.

    Everything happens while holding ``self.balance_meta.mutex``.

    Args:
        new_recv_rid_lists: One list of received ids per worker; its length
            must equal the number of on-fly records.
        local_tokens: Value passed to ``set_shared_local_tokens``.

    Raises:
        AssertionError: If the worker counts differ, or a received id is not
            present in the matching worker's on-fly record.
    """
    shared = self.balance_meta

    with shared.mutex:
        onfly_list: List[Dict[int, int]] = shared.get_shared_onfly()
        assert len(new_recv_rid_lists) == len(onfly_list), "num_worker not equal"

        # Step 1: every id a worker received this round must be on-fly for
        # that worker; drop it from the on-fly record.
        per_worker = zip(new_recv_rid_lists, onfly_list)
        for worker_id, (new_recv_rids, on_fly_reqs) in enumerate(per_worker):
            for new_recv_rid in new_recv_rids:
                assert (
                    new_recv_rid in on_fly_reqs
                ), f"{ new_recv_rid = } not in { worker_id = } { on_fly_reqs = } , data consistency is wrong"
                del on_fly_reqs[new_recv_rid]

        # Step 2: write the on-fly records and local tokens back while the
        # mutex is still held.
        shared.set_shared_onfly_info(onfly_list)
        shared.set_shared_local_tokens(local_tokens)
0 commit comments