Skip to content

Commit 27e8ffe

Browse files
authored
[1/N] DP-refactor: move dp balance code into scheduler's mixin class (sgl-project#10004)
1 parent 4dbb34f commit 27e8ffe

File tree

2 files changed

+116
-106
lines changed

2 files changed

+116
-106
lines changed

python/sglang/srt/managers/scheduler.py

Lines changed: 3 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,7 @@ def __init__(
500500
# Init metrics stats
501501
self.init_metrics(tp_rank, pp_rank, dp_rank)
502502
self.init_kv_events(server_args.kv_events_config)
503+
self.init_dp_balance(dp_balance_meta)
503504

504505
# Init disaggregation
505506
self.disaggregation_mode = DisaggregationMode(
@@ -545,15 +546,6 @@ def __init__(
545546
]
546547
)
547548

548-
self.balance_meta = dp_balance_meta
549-
if (
550-
server_args.enable_dp_attention
551-
and server_args.load_balance_method == "minimum_tokens"
552-
):
553-
assert dp_balance_meta is not None
554-
555-
self.recv_dp_balance_id_this_term = []
556-
557549
def init_tokenizer(self):
558550
server_args = self.server_args
559551
self.is_generation = self.model_config.is_generation
@@ -1126,11 +1118,7 @@ def handle_generate_request(
11261118
self,
11271119
recv_req: TokenizedGenerateReqInput,
11281120
):
1129-
if (
1130-
self.server_args.enable_dp_attention
1131-
and self.server_args.load_balance_method == "minimum_tokens"
1132-
):
1133-
self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
1121+
self.maybe_update_dp_balance_data(recv_req)
11341122

11351123
# Create a new request
11361124
if (
@@ -1568,11 +1556,7 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
15681556

15691557
# Handle DP attention
15701558
if need_dp_attn_preparation:
1571-
if (
1572-
self.server_args.load_balance_method == "minimum_tokens"
1573-
and self.forward_ct % 40 == 0
1574-
):
1575-
self.handle_dp_balance_data(ret)
1559+
self.maybe_handle_dp_balance_data()
15761560
ret = self.prepare_mlp_sync_batch(ret)
15771561

15781562
return ret
@@ -1897,86 +1881,6 @@ def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch):
18971881
disable_overlap_schedule=self.server_args.disable_overlap_schedule,
18981882
)
18991883

1900-
def handle_dp_balance_data(self, local_batch: ScheduleBatch):
1901-
def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]:
1902-
"""gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
1903-
recv_list = self.recv_dp_balance_id_this_term
1904-
assert len(recv_list) <= 511, (
1905-
"The number of requests received this round is too large. "
1906-
"Please increase gather_tensor_size and onfly_info_size."
1907-
)
1908-
# The maximum size of the tensor used for gathering data from all workers.
1909-
gather_tensor_size = 512
1910-
1911-
# recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
1912-
recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
1913-
recv_tensor[0] = holding_tokens_list
1914-
recv_tensor[1] = len(
1915-
recv_list
1916-
) # The first element is the length of the list.
1917-
recv_tensor[2 : len(recv_list) + 2] = torch.tensor(
1918-
recv_list, dtype=torch.int32
1919-
)
1920-
1921-
if self.tp_rank == 0:
1922-
gathered_list = [
1923-
torch.zeros(gather_tensor_size, dtype=torch.int32)
1924-
for _ in range(self.balance_meta.num_workers)
1925-
]
1926-
else:
1927-
gathered_list = None
1928-
1929-
torch.distributed.gather(
1930-
recv_tensor, gathered_list, group=self.tp_cpu_group
1931-
)
1932-
1933-
gathered_id_list_per_worker = None
1934-
if self.tp_rank == 0:
1935-
gathered_id_list_per_worker = []
1936-
holding_tokens_list = []
1937-
for tensor in gathered_list:
1938-
holding_tokens_list.append(tensor[0].item())
1939-
list_length = tensor[1].item()
1940-
gathered_id_list_per_worker.append(
1941-
tensor[2 : list_length + 2].tolist()
1942-
)
1943-
1944-
return gathered_id_list_per_worker, holding_tokens_list
1945-
1946-
def write_shared_dp_balance_info(new_recv_rid_lists, local_tokens):
1947-
meta = self.balance_meta
1948-
1949-
with meta.mutex:
1950-
onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
1951-
assert len(new_recv_rid_lists) == len(
1952-
onfly_list
1953-
), "num_worker not equal"
1954-
# 1.Check if the rid received by each worker this round is present in onfly.
1955-
# If it is, remove the corresponding onfly item.
1956-
worker_id = 0
1957-
for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
1958-
for new_recv_rid in new_recv_rids:
1959-
assert (
1960-
new_recv_rid in on_fly_reqs
1961-
), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
1962-
del on_fly_reqs[new_recv_rid]
1963-
worker_id += 1
1964-
# 2. Atomically write local_tokens and onfly into shm under the mutex
1965-
meta.set_shared_onfly_info(onfly_list)
1966-
meta.set_shared_local_tokens(local_tokens)
1967-
1968-
holding_tokens = self.get_load()
1969-
1970-
new_recv_dp_balance_id_list, holding_token_list = gather_dp_balance_info(
1971-
holding_tokens
1972-
)
1973-
1974-
self.recv_dp_balance_id_this_term.clear()
1975-
if self.tp_rank == 0: # only first worker write info
1976-
write_shared_dp_balance_info(
1977-
new_recv_dp_balance_id_list, holding_token_list
1978-
)
1979-
19801884
@staticmethod
19811885
def prepare_mlp_sync_batch_raw(
19821886
local_batch: ScheduleBatch,

python/sglang/srt/managers/scheduler_metrics_mixin.py

Lines changed: 113 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,24 @@
1+
from __future__ import annotations
2+
13
import logging
24
import time
35
from collections import defaultdict
4-
from typing import List, Optional
6+
from typing import TYPE_CHECKING, Dict, List, Optional, Union
7+
8+
import torch
59

610
from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
711
from sglang.srt.disaggregation.utils import DisaggregationMode
12+
from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
813
from sglang.srt.managers.schedule_policy import PrefillAdder
914
from sglang.srt.managers.scheduler import Req, ScheduleBatch
15+
from sglang.srt.managers.utils import DPBalanceMeta
1016
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
1117
from sglang.srt.utils import get_bool_env_var
1218

19+
if TYPE_CHECKING:
20+
from sglang.srt.managers.scheduler import Scheduler
21+
1322
logger = logging.getLogger(__name__)
1423

1524
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
@@ -28,7 +37,9 @@ def __init__(self):
2837

2938

3039
class SchedulerMetricsMixin:
31-
def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]):
40+
def init_metrics(
41+
self: Scheduler, tp_rank: int, pp_rank: int, dp_rank: Optional[int]
42+
):
3243
self.last_gen_throughput: float = 0.0
3344
self.last_input_throughput: float = 0.0
3445
self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
@@ -50,14 +61,24 @@ def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]):
5061
labels["dp_rank"] = dp_rank
5162
self.metrics_collector = SchedulerMetricsCollector(labels=labels)
5263

53-
def init_kv_events(self, kv_events_config: Optional[str]):
64+
def init_dp_balance(self: Scheduler, dp_balance_meta: Optional[DPBalanceMeta]):
65+
self.balance_meta = dp_balance_meta
66+
if (
67+
self.server_args.enable_dp_attention
68+
and self.server_args.load_balance_method == "minimum_tokens"
69+
):
70+
assert dp_balance_meta is not None
71+
72+
self.recv_dp_balance_id_this_term = []
73+
74+
def init_kv_events(self: Scheduler, kv_events_config: Optional[str]):
5475
if self.enable_kv_cache_events:
5576
self.kv_event_publisher = EventPublisherFactory.create(
5677
kv_events_config, self.attn_dp_rank
5778
)
5879

5980
def log_prefill_stats(
60-
self,
81+
self: Scheduler,
6182
adder: PrefillAdder,
6283
can_run_list: List[Req],
6384
running_bs: int,
@@ -138,7 +159,7 @@ def log_prefill_stats(
138159
self._publish_kv_events()
139160

140161
def log_decode_stats(
141-
self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
162+
self: Scheduler, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
142163
):
143164
batch = running_batch or self.running_batch
144165

@@ -220,7 +241,7 @@ def log_decode_stats(
220241
self._emit_kv_metrics()
221242
self._publish_kv_events()
222243

223-
def _emit_kv_metrics(self):
244+
def _emit_kv_metrics(self: Scheduler):
224245
kv_metrics = KvMetrics()
225246
kv_metrics.request_active_slots = self.stats.num_running_reqs
226247
kv_metrics.request_total_slots = self.max_running_requests
@@ -236,9 +257,94 @@ def _emit_kv_metrics(self):
236257
if not self.send_metrics_from_scheduler.closed:
237258
self.send_metrics_from_scheduler.send_pyobj(kv_metrics)
238259

239-
def _publish_kv_events(self):
260+
def _publish_kv_events(self: Scheduler):
240261
if self.enable_kv_cache_events:
241262
events = self.tree_cache.take_events()
242263
if events:
243264
batch = KVEventBatch(ts=time.time(), events=events)
244265
self.kv_event_publisher.publish(batch)
266+
267+
def maybe_update_dp_balance_data(
268+
self: Scheduler, recv_req: TokenizedGenerateReqInput
269+
):
270+
if (
271+
self.server_args.enable_dp_attention
272+
and self.server_args.load_balance_method == "minimum_tokens"
273+
):
274+
self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
275+
276+
def maybe_handle_dp_balance_data(self: Scheduler):
277+
if (
278+
self.server_args.load_balance_method == "minimum_tokens"
279+
and self.forward_ct % 40 == 0
280+
):
281+
holding_tokens = self.get_load()
282+
283+
new_recv_dp_balance_id_list, holding_token_list = (
284+
self.gather_dp_balance_info(holding_tokens)
285+
)
286+
287+
self.recv_dp_balance_id_this_term.clear()
288+
if self.tp_rank == 0: # only first worker write info
289+
self.write_shared_dp_balance_info(
290+
new_recv_dp_balance_id_list, holding_token_list
291+
)
292+
293+
def gather_dp_balance_info(
294+
self: Scheduler, holding_tokens_list
295+
) -> Union[None, List[List[int]]]:
296+
"""gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
297+
recv_list = self.recv_dp_balance_id_this_term
298+
assert len(recv_list) <= 511, (
299+
"The number of requests received this round is too large. "
300+
"Please increase gather_tensor_size and onfly_info_size."
301+
)
302+
# The maximum size of the tensor used for gathering data from all workers.
303+
gather_tensor_size = 512
304+
305+
# recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
306+
recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
307+
recv_tensor[0] = holding_tokens_list
308+
recv_tensor[1] = len(recv_list) # The first element is the length of the list.
309+
recv_tensor[2 : len(recv_list) + 2] = torch.tensor(recv_list, dtype=torch.int32)
310+
311+
if self.tp_rank == 0:
312+
gathered_list = [
313+
torch.zeros(gather_tensor_size, dtype=torch.int32)
314+
for _ in range(self.balance_meta.num_workers)
315+
]
316+
else:
317+
gathered_list = None
318+
319+
torch.distributed.gather(recv_tensor, gathered_list, group=self.tp_cpu_group)
320+
321+
gathered_id_list_per_worker = None
322+
if self.tp_rank == 0:
323+
gathered_id_list_per_worker = []
324+
holding_tokens_list = []
325+
for tensor in gathered_list:
326+
holding_tokens_list.append(tensor[0].item())
327+
list_length = tensor[1].item()
328+
gathered_id_list_per_worker.append(tensor[2 : list_length + 2].tolist())
329+
330+
return gathered_id_list_per_worker, holding_tokens_list
331+
332+
def write_shared_dp_balance_info(self: Scheduler, new_recv_rid_lists, local_tokens):
333+
meta = self.balance_meta
334+
335+
with meta.mutex:
336+
onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
337+
assert len(new_recv_rid_lists) == len(onfly_list), "num_worker not equal"
338+
# 1.Check if the rid received by each worker this round is present in onfly.
339+
# If it is, remove the corresponding onfly item.
340+
worker_id = 0
341+
for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
342+
for new_recv_rid in new_recv_rids:
343+
assert (
344+
new_recv_rid in on_fly_reqs
345+
), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
346+
del on_fly_reqs[new_recv_rid]
347+
worker_id += 1
348+
# 2. Atomically write local_tokens and onfly into shm under the mutex
349+
meta.set_shared_onfly_info(onfly_list)
350+
meta.set_shared_local_tokens(local_tokens)

0 commit comments

Comments
 (0)