@@ -76,11 +76,13 @@ def _apply_prefill_sampling_state(
7676 * ,
7777 sampling_params : SamplingParams ,
7878 prompt_tokens : torch .Tensor | None ,
79+ empty_slots : list [int ],
7980):
80- sampling_module = getattr (model_instance , "sampling_prefill " , None )
81+ sampling_module = getattr (model_instance , "sampling " , None )
8182 assert sampling_module is not None , "Sampling module not found in model for sampling on device."
8283 sampling_module .reset_sampling_params (sampling_params )
83- sampling_module .reset_seed (sampling_params .seed )
84+ sampling_module .seed_manager .reset_seed (sampling_params .seed , empty_slots )
85+ sampling_module .seed_manager .get_new_values (empty_slots , replicate_seeds = True )
8486 if prompt_tokens is not None :
8587 sampling_module .reset_prompt_tokens (prompt_tokens )
8688 sampling_module .reset_output_state ()
@@ -422,6 +424,7 @@ def prefill_forward_text(
422424 self .model [model_id ],
423425 sampling_params = per_request_params ,
424426 prompt_tokens = prefill_ids [:, :seq_len ].repeat (32 , 1 ),
427+ empty_slots = [user_id % 32 ],
425428 )
426429
427430 if enable_trace_current_prompt :
@@ -471,7 +474,7 @@ def prefill_forward_text(
471474 logits = self .model [model_id ].process_logits_after_prefill_trace (logits , last_token_idx )
472475
473476 if sampling_enabled :
474- tt_tokens , tt_log_probs = self .model [model_id ].sampling_prefill .sample (
477+ tt_tokens , tt_log_probs = self .model [model_id ].sampling .sample (
475478 logits ,
476479 enable_trace = False ,
477480 )
@@ -732,8 +735,8 @@ def decode_forward_text(
732735 sampling_module = getattr (self .model [i ], "sampling" , None )
733736 assert sampling_module is not None , "Sampling module not found in model for sampling on device."
734737 sampling_module .reset_sampling_params (formatted_params )
738+ sampling_module .seed_manager .get_new_values ()
735739 if reset_batch :
736- sampling_module .reset_seed (formatted_params .seed )
737740 sampling_module .reset_prompt_tokens (prompt_chunks [i ])
738741 sampling_module .reset_output_state (output_chunks [i ])
739742
0 commit comments