|
40 | 40 |
|
41 | 41 | DecodeState = Any |
42 | 42 | Params = Any |
| 43 | +PRNGKeyType = Any |
43 | 44 |
|
44 | 45 | log = logging.getLogger(__name__) |
45 | 46 |
|
@@ -130,19 +131,20 @@ def process( |
130 | 131 | input_true_length: int, |
131 | 132 | max_length: int, |
132 | 133 | prefill_done: Callable[[List[Tuple[engine_api.ResultTokens, int]], List[int], DecodeState], None], |
| 134 | + rng: PRNGKeyType, |
133 | 135 | ) -> None: |
134 | 136 | """Prefill helper process runner""" |
135 | 137 | padded_length = len(input_tokens_padded) |
136 | 138 | if self._type == "default": |
137 | 139 | first_token, decode_state = self._processor.process( |
138 | | - model_params, decode_state, decode_slot, input_tokens_padded, input_true_length |
| 140 | + model_params, decode_state, decode_slot, input_tokens_padded, input_true_length, rng |
139 | 141 | ) |
140 | 142 | prefill_done([(first_token, decode_slot)], [input_id], decode_state) |
141 | 143 | elif self._type == "batch": |
142 | 144 | if padded_length == max_length: |
143 | 145 | # fallback to default mode |
144 | 146 | first_token, decode_state = self._processor.process( |
145 | | - model_params, decode_state, decode_slot, input_tokens_padded, input_true_length |
| 147 | + model_params, decode_state, decode_slot, input_tokens_padded, input_true_length, rng |
146 | 148 | ) |
147 | 149 | prefill_done([(first_token, decode_slot)], [input_id], decode_state) |
148 | 150 | else: |
@@ -249,6 +251,9 @@ def batch_inference_with_callback( |
249 | 251 | counter = EventCounter(input=0, prefill=0, decode=0, detokenize=0) |
250 | 252 | dummy_length = 1 |
251 | 253 |
|
| 254 | + rng = jax.random.PRNGKey(1234) |
| 255 | + rng, _ = jax.random.split(rng) |
| 256 | + |
252 | 257 | def prefill_done(prefill_result, ids, decode_state): |
253 | 258 | nonlocal self |
254 | 259 | nonlocal counter |
@@ -345,7 +350,15 @@ def detokenize(): |
345 | 350 |
|
346 | 351 | # Do prefill when there are free slots |
347 | 352 | self.prefill.process( |
348 | | - self.params, self.decode_state, slot, row.id, row.tokens, row.true_length, self.max_prefill_length, prefill_done |
| 353 | + self.params, |
| 354 | + self.decode_state, |
| 355 | + slot, |
| 356 | + row.id, |
| 357 | + row.tokens, |
| 358 | + row.true_length, |
| 359 | + self.max_prefill_length, |
| 360 | + prefill_done, |
| 361 | + rng, |
349 | 362 | ) |
350 | 363 | self.prefill.finalize(self.params, self.decode_state, prefill_done) |
351 | 364 |
|
|
0 commit comments