@@ -191,6 +191,7 @@ def __init__(
191191
192192 # Each decode stage's output ids
193193 self .output_ids = []
194+ # self.next_token_ids_device: jax.Array = None # store it and call device_get to fetch it when needed
194195 # fill_ids = origin_input_ids + output_ids. Updated if chunked.
195196 self .fill_ids = []
196197
@@ -407,6 +408,7 @@ def init_incremental_detokenize(self):
407408 self .surr_offset = max (self .read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET , 0 )
408409
409410 all_ids = self .origin_input_ids_unpadded + self .output_ids
411+ # all_ids = self.origin_input_ids_unpadded + self.output_ids[:-1] if self.next_token_ids_device else self.output_ids
410412 return all_ids [self .surr_offset :], self .read_offset - self .surr_offset
411413
412414 def check_finished (self , new_accepted_len : int = 1 ):
@@ -1031,6 +1033,24 @@ def prepare_for_decode(self):
10311033 if self .sampling_info .penalizer_orchestrator .is_required :
10321034 if self .enable_overlap :
10331035 # TODO: this can be slow, optimize this.
1036+ # tmp=[]
1037+ # for req in self.reqs:
1038+ # if req.next_token_ids_device:
1039+ # output_id=jax.device_get(req.next_token_ids_device).tolist()[0]
1040+ # req.next_token_ids_device=None
1041+ # req.output_ids[-1]=output_id
1042+ # print(f"[next_token_ids_device] {req.output_ids=}",flush=True)
1043+ # elif len(req.output_ids):
1044+ # output_id = req.output_ids[-1]
1045+ # print(f"[len(req.output_ids):] {req.output_ids=}",flush=True)
1046+ # else:
1047+ # output_id = req.origin_input_ids[-1]
1048+ # print(f"[other]] {req.output_ids=}",flush=True)
1049+ # tmp.append(output_id)
1050+
1051+ # delayed_output_ids = np.array(tmp,dtype=np.int64)
1052+ # print(f"[prepare_for_decode] {tmp=}",flush=True)
1053+
10341054 delayed_output_ids = np .array (
10351055 [
10361056 (req .output_ids [- 1 ] if len (req .output_ids ) else req .origin_input_ids [- 1 ])
@@ -1042,8 +1062,25 @@ def prepare_for_decode(self):
10421062 else :
10431063 self .sampling_info .penalizer_orchestrator .cumulate_output_tokens (self .output_ids )
10441064
1065+ # output_ids = self.output_ids
1066+ # if self.enable_overlap:
1067+ # print(f"=======overlap========",flush=True)
1068+ # valid_output_ids = []
1069+ # for req in self.reqs:
1070+ # print(f"[for req] {req.next_token_ids_device=}, {req.output_ids=}",flush=True)
1071+ # if req.next_token_ids_device:
1072+ # output_id=jax.device_get(req.next_token_ids_device).tolist()[0]
1073+ # req.next_token_ids_device=None
1074+ # req.output_ids[-1]=output_id
1075+ # valid_output_ids.append(output_id)
1076+ # print(f"[next_token_ids_device 1] {req.output_ids=}",flush=True)
1077+ # output_ids = np.concat(valid_output_ids,dtype=np.int32)
1078+
10451079 # Update fields
10461080 self .input_ids = self .output_ids
1081+ # self.input_ids = output_ids
1082+
1083+ # print(f"[prepare_for_decode] {self.input_ids=}",flush=True)
10471084
10481085 self .output_ids = None
10491086
@@ -1215,7 +1252,9 @@ def get_model_worker_batch(
12151252 seq_lens_cpu = self .seq_lens
12161253 real_bs = len (seq_lens_cpu )
12171254 req_pool_indices_cpu = self .req_pool_indices
1218- token_indices_with_all_reqs = self .req_to_token_pool .req_to_token [self .req_pool_indices ]
1255+ token_indices_with_all_reqs = self .req_to_token_pool .req_to_token [
1256+ self .req_pool_indices
1257+ ] # cost in Pathways: 23ms
12191258
12201259 # padding seq
12211260 # extend & decode: input_ids, positions, out_cache_loc, cache_loc
@@ -1313,6 +1352,8 @@ def get_model_worker_batch(
13131352
13141353 # Fill the array efficiently
13151354 offset = 0
1355+ ######################### cost in Pathways 10ms#####################
1356+ #####concurrency=256,tp=4,page_size=256,max_running_requests=256
13161357 for i , (seq_idx , seq_len , aligned_len ) in enumerate (
13171358 zip (valid_indices , valid_seq_lens , aligned_lengths )
13181359 ):
@@ -1322,6 +1363,7 @@ def get_model_worker_batch(
13221363 ]
13231364 # Padding is already zero from initialization
13241365 offset += aligned_len
1366+ ######################### cost in Pathways#####################
13251367
13261368 offset = np .sum (seq_lens_cpu [seq_lens_cpu > 0 ]) if len (seq_lens_cpu ) > 0 else 0
13271369
@@ -1335,6 +1377,7 @@ def get_model_worker_batch(
13351377 if len (cache_loc_flat ) < total_cache_loc_size :
13361378 cache_loc_cpu [len (cache_loc_flat ) :] = 0
13371379
1380+ ####################cost in Pathways 22ms######################
13381381 if bs_padding_size > 0 :
13391382 invalid_req_pool_indices = np .array (
13401383 [- 1 ] * bs_padding_size , dtype = req_pool_indices_cpu .dtype
@@ -1365,6 +1408,8 @@ def get_model_worker_batch(
13651408 [extend_logprob_start_lens , invalid_extend_logprob_start_lens ], axis = 0
13661409 )
13671410
1411+ ############################################################################
1412+
13681413 sampling_info = self .sampling_info
13691414 if self .sampling_info :
13701415 new_temperatures = np .concatenate (
0 commit comments