Skip to content

Commit 6da7daa

Browse files
committed
enable simple tune and change forward_batch, sampling_metadata, forward_metadata initialization position
1 parent 8655866 commit 6da7daa

6 files changed

Lines changed: 342 additions & 248 deletions

File tree

python/sgl_jax/srt/kernels/ragged_paged_attention/tuned_block_sizes.py

Lines changed: 224 additions & 221 deletions
Large diffs are not rendered by default.

python/sgl_jax/srt/managers/schedule_batch.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ def __init__(
191191

192192
# Each decode stage's output ids
193193
self.output_ids = []
194+
# self.next_token_ids_device: jax.Array = None # store it and use device_get to get when need it
194195
# fill_ids = origin_input_ids + output_ids. Updated if chunked.
195196
self.fill_ids = []
196197

@@ -407,6 +408,7 @@ def init_incremental_detokenize(self):
407408
self.surr_offset = max(self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
408409

409410
all_ids = self.origin_input_ids_unpadded + self.output_ids
411+
# all_ids = self.origin_input_ids_unpadded + self.output_ids[:-1] if self.next_token_ids_device else self.output_ids
410412
return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
411413

412414
def check_finished(self, new_accepted_len: int = 1):
@@ -1031,6 +1033,24 @@ def prepare_for_decode(self):
10311033
if self.sampling_info.penalizer_orchestrator.is_required:
10321034
if self.enable_overlap:
10331035
# TODO: this can be slow, optimize this.
1036+
# tmp=[]
1037+
# for req in self.reqs:
1038+
# if req.next_token_ids_device:
1039+
# output_id=jax.device_get(req.next_token_ids_device).tolist()[0]
1040+
# req.next_token_ids_device=None
1041+
# req.output_ids[-1]=output_id
1042+
# print(f"[next_token_ids_device] {req.output_ids=}",flush=True)
1043+
# elif len(req.output_ids):
1044+
# output_id = req.output_ids[-1]
1045+
# print(f"[len(req.output_ids):] {req.output_ids=}",flush=True)
1046+
# else:
1047+
# output_id = req.origin_input_ids[-1]
1048+
# print(f"[other]] {req.output_ids=}",flush=True)
1049+
# tmp.append(output_id)
1050+
1051+
# delayed_output_ids = np.array(tmp,dtype=np.int64)
1052+
# print(f"[prepare_for_decode] {tmp=}",flush=True)
1053+
10341054
delayed_output_ids = np.array(
10351055
[
10361056
(req.output_ids[-1] if len(req.output_ids) else req.origin_input_ids[-1])
@@ -1042,8 +1062,25 @@ def prepare_for_decode(self):
10421062
else:
10431063
self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.output_ids)
10441064

1065+
# output_ids = self.output_ids
1066+
# if self.enable_overlap:
1067+
# print(f"=======overlap========",flush=True)
1068+
# valid_output_ids = []
1069+
# for req in self.reqs:
1070+
# print(f"[for req] {req.next_token_ids_device=}, {req.output_ids=}",flush=True)
1071+
# if req.next_token_ids_device:
1072+
# output_id=jax.device_get(req.next_token_ids_device).tolist()[0]
1073+
# req.next_token_ids_device=None
1074+
# req.output_ids[-1]=output_id
1075+
# valid_output_ids.append(output_id)
1076+
# print(f"[next_token_ids_device 1] {req.output_ids=}",flush=True)
1077+
# output_ids = np.concat(valid_output_ids,dtype=np.int32)
1078+
10451079
# Update fields
10461080
self.input_ids = self.output_ids
1081+
# self.input_ids = output_ids
1082+
1083+
# print(f"[prepare_for_decode] {self.input_ids=}",flush=True)
10471084

10481085
self.output_ids = None
10491086

@@ -1215,7 +1252,9 @@ def get_model_worker_batch(
12151252
seq_lens_cpu = self.seq_lens
12161253
real_bs = len(seq_lens_cpu)
12171254
req_pool_indices_cpu = self.req_pool_indices
1218-
token_indices_with_all_reqs = self.req_to_token_pool.req_to_token[self.req_pool_indices]
1255+
token_indices_with_all_reqs = self.req_to_token_pool.req_to_token[
1256+
self.req_pool_indices
1257+
] # cost in pathways, 23ms
12191258

12201259
# padding seq
12211260
# extend & decode: input_ids, positions, out_cache_loc, cache_loc
@@ -1313,6 +1352,8 @@ def get_model_worker_batch(
13131352

13141353
# Fill the array efficiently
13151354
offset = 0
1355+
######################### cost in Pathways 10ms#####################
1356+
#####concurrency=256,tp=4,page_size=256,max_running_requests=256
13161357
for i, (seq_idx, seq_len, aligned_len) in enumerate(
13171358
zip(valid_indices, valid_seq_lens, aligned_lengths)
13181359
):
@@ -1322,6 +1363,7 @@ def get_model_worker_batch(
13221363
]
13231364
# Padding is already zero from initialization
13241365
offset += aligned_len
1366+
######################### cost in Pathways#####################
13251367

13261368
offset = np.sum(seq_lens_cpu[seq_lens_cpu > 0]) if len(seq_lens_cpu) > 0 else 0
13271369

@@ -1335,6 +1377,7 @@ def get_model_worker_batch(
13351377
if len(cache_loc_flat) < total_cache_loc_size:
13361378
cache_loc_cpu[len(cache_loc_flat) :] = 0
13371379

1380+
####################cost in Pathways 22ms######################
13381381
if bs_padding_size > 0:
13391382
invalid_req_pool_indices = np.array(
13401383
[-1] * bs_padding_size, dtype=req_pool_indices_cpu.dtype
@@ -1365,6 +1408,8 @@ def get_model_worker_batch(
13651408
[extend_logprob_start_lens, invalid_extend_logprob_start_lens], axis=0
13661409
)
13671410

1411+
############################################################################
1412+
13681413
sampling_info = self.sampling_info
13691414
if self.sampling_info:
13701415
new_temperatures = np.concatenate(

python/sgl_jax/srt/managers/scheduler_output_processor_mixin.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,23 +89,40 @@ def process_batch_result_prefill(
8989
hidden_state_offset = 0
9090
# Check finish conditions
9191
logprob_pt = 0
92+
# print(f"[process_batch_result_prefill] {len(batch.reqs)=}",flush=True)
9293
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
94+
# for i, req in enumerate(batch.reqs):
9395
if req.is_retracted:
9496
continue
9597

9698
req.latest_bid = batch.bid
9799

100+
# if self.enable_overlap:
101+
# next_token_id=next_token_ids[i]
102+
98103
if self.is_mixed_chunk and self.enable_overlap and req.finished():
99104
j = len(batch.out_cache_loc) - len(batch.reqs) + i
100105
self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
101106
continue
102107

108+
# print(f"[process_batch_result_prefill] {req.is_chunked=}, {next_token_id=}",flush=True)
103109
if req.is_chunked <= 0:
104110
# req output_ids are set here
105111
req.output_ids.append(next_token_id)
112+
# if self.enable_overlap:
113+
# req.next_token_ids_device=next_token_id
114+
# req.output_ids.append(0)
115+
# else:
116+
# req.output_ids.append(next_token_id)
117+
118+
# print(f"============[process_batch_result_prefill] {req.output_ids=}, {req.next_token_ids_device=}",flush=True)
119+
106120
req.check_finished()
107121

108122
if req.finished():
123+
# if req.next_token_ids_device:
124+
# req.output_ids[-1] = jax.device_get(req.next_token_ids_device).tolist()[0]
125+
# req.next_token_ids_device=None
109126
self.maybe_collect_routed_experts(req)
110127
if precision_tracer.get_trace_active():
111128
precision_tracer.set_request_status_to_completed(req.rid)
@@ -275,6 +292,7 @@ def process_batch_result_decode(
275292

276293
# Check finish condition
277294
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
295+
# for i, req in enumerate(batch.reqs):
278296
req: Req
279297
if req.is_retracted:
280298
continue
@@ -292,15 +310,25 @@ def process_batch_result_decode(
292310
self.token_to_kv_pool_allocator.free(indices_to_free)
293311
continue
294312

313+
# next_token_id = next_token_ids[i]
314+
295315
new_accepted_len = 1
296316
if batch.spec_algorithm is None or batch.spec_algorithm.is_none():
297317
req.output_ids.append(next_token_id)
318+
# if self.enable_overlap:
319+
# req.output_ids.append(0)
320+
# req.next_token_ids_device=next_token_id
321+
# else:
322+
# req.output_ids.append(next_token_id)
298323
elif self.spec_algorithm.is_eagle():
299324
req.output_ids.extend(next_token_id)
300325
new_accepted_len = len(next_token_id)
301326

302327
req.check_finished(new_accepted_len)
303328
if req.finished():
329+
# if req.next_token_ids_device:
330+
# req.output_ids[-1] = jax.device_get(req.next_token_ids_device).tolist()[0]
331+
# req.next_token_ids_device=None
304332
self.maybe_collect_routed_experts(req)
305333
if batch.spec_algorithm is not None and batch.spec_algorithm.is_eagle():
306334
cur_allocate_len = batch.spec_info.allocate_lens[i]

python/sgl_jax/srt/managers/scheduler_profiler_mixing.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,17 @@ def start_profile(
3838
if output_dir is None:
3939
output_dir = os.getenv("SGLANG_JAX_PROFILER_DIR", "/tmp")
4040

41-
# check permission for output_dir
42-
tmp_output_dir = output_dir
43-
while not os.path.exists(tmp_output_dir):
44-
tmp_output_dir = os.path.dirname(tmp_output_dir)
45-
if not os.access(tmp_output_dir, os.W_OK):
46-
return ProfileReqOutput(
47-
success=False,
48-
message=f"no permission to write the {output_dir}",
49-
)
41+
if not output_dir.startswith("gs"):
42+
# gs prefix is used in Pathways, skip check in Pathways
43+
# check permission for output_dir
44+
tmp_output_dir = output_dir
45+
while not os.path.exists(tmp_output_dir):
46+
tmp_output_dir = os.path.dirname(tmp_output_dir)
47+
if not os.access(tmp_output_dir, os.W_OK):
48+
return ProfileReqOutput(
49+
success=False,
50+
message=f"no permission to write the {output_dir}",
51+
)
5052

5153
self.profiler_output_dir = output_dir
5254
self.profile_id = profile_id

python/sgl_jax/srt/managers/tp_worker.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,7 @@ def forward_batch_generation(
606606
sampling_metadata,
607607
)
608608
cache_miss_count += count()
609+
next_token_ids_device = jax.copy_to_host_async(next_token_ids_device)
609610
if model_worker_batch.return_output_logprob_only:
610611
logprobs = self.model_runner.compute_logprobs(token_logprobs, next_token_ids_device)
611612
logits_output.next_token_logprobs = logprobs[: model_worker_batch.real_bs]

python/sgl_jax/srt/managers/tp_worker_overlap_thread.py

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,22 @@ def forward_thread_func_(self):
102102
if not model_worker_batch:
103103
break
104104

105+
# initialize forward_metadata, sampling_metadata and forward_batch for DeepSeek-R1-Distill-Qwen-1.5B
106+
if sampling_metadata is None:
107+
sampling_metadata = SamplingMetadata.from_model_worker_batch(
108+
model_worker_batch,
109+
len(model_worker_batch.seq_lens) - model_worker_batch.real_bs,
110+
self.mesh,
111+
self.worker.model_config.vocab_size,
112+
)
113+
114+
forward_metadata = self.worker.model_runner.attn_backend.get_forward_metadata(
115+
model_worker_batch
116+
)
117+
model_worker_batch.forward_batch = ForwardBatch.init_new(
118+
model_worker_batch, self.worker.get_model_runner()
119+
)
120+
105121
# Resolve future tokens in the input
106122
input_ids = model_worker_batch.forward_batch.input_ids
107123
model_worker_batch.forward_batch.input_ids = resolve_future_token_ids(
@@ -118,7 +134,6 @@ def forward_thread_func_(self):
118134
forward_metadata=forward_metadata,
119135
)
120136
)
121-
122137
# Update the future token ids map
123138
self.future_token_ids_map = set_future_token_ids(
124139
self.future_token_ids_map,
@@ -143,7 +158,7 @@ def resolve_last_batch_result(self, launch_done: threading.Event | None = None):
143158
).tolist()
144159
if logits_output.hidden_states is not None:
145160
logits_output.hidden_states = jax.device_get(logits_output.hidden_states)
146-
next_token_ids = jax.device_get(next_token_ids).tolist()
161+
# next_token_ids = jax.device_get(next_token_ids).tolist()
147162

148163
if launch_done is not None:
149164
launch_done.wait()
@@ -164,33 +179,33 @@ def forward_batch_generation(
164179
penalizer_orchestrator=None,
165180
)
166181

167-
if sampling_metadata is None:
168-
sampling_metadata = SamplingMetadata.from_model_worker_batch(
169-
model_worker_batch,
170-
len(model_worker_batch.seq_lens) - model_worker_batch.real_bs,
171-
self.mesh,
172-
self.worker.model_config.vocab_size,
173-
)
182+
# if sampling_metadata is None:
183+
# sampling_metadata = SamplingMetadata.from_model_worker_batch(
184+
# model_worker_batch,
185+
# len(model_worker_batch.seq_lens) - model_worker_batch.real_bs,
186+
# self.mesh,
187+
# self.worker.model_config.vocab_size,
188+
# )
174189

175-
forward_metadata = self.worker.model_runner.attn_backend.get_forward_metadata(
176-
model_worker_batch
177-
)
190+
# forward_metadata = self.worker.model_runner.attn_backend.get_forward_metadata(
191+
# model_worker_batch
192+
# )
178193

179194
# Prepare LoRA batch if LoRA is enabled
180195
if self.worker.server_args.enable_lora:
181196
self.worker.prepare_lora_batch(model_worker_batch)
182197

183-
model_worker_batch.forward_batch = ForwardBatch.init_new(
184-
model_worker_batch, self.worker.get_model_runner()
185-
)
198+
# model_worker_batch.forward_batch = ForwardBatch.init_new(
199+
# model_worker_batch, self.worker.get_model_runner()
200+
# )
186201

187202
# Push a new batch to the queue (JAX handles synchronization automatically)
188203
self.input_queue.put(
189204
(
190205
model_worker_batch,
191206
self.future_token_ids_ct,
192-
sampling_metadata,
193-
forward_metadata,
207+
None, # sampling_metadata
208+
None, # forward_metadata
194209
)
195210
)
196211

0 commit comments

Comments
 (0)