|
10 | 10 | from sglang_omni_v1.model_runner.base import ModelRunner |
11 | 11 |
|
12 | 12 |
|
def collect_s2pro_step_outputs(
    result: Any,
    requests: list,
    *,
    output_codes: torch.Tensor,
    output_semantic_ids: torch.Tensor,
    im_end_token_id: int,
) -> None:
    """Copy one step's sampled ids and codes onto ``result`` and each request.

    ``result.next_token_ids`` is set to a clone of the first ``len(requests)``
    entries of ``output_semantic_ids``. Each request is then matched by its
    position to a row of the output tensors and updated in place, unless it
    is skipped (see below).

    Args:
        result: Object that receives the ``next_token_ids`` attribute.
        requests: Scheduled requests; row ``i`` of the output tensors is
            read for ``requests[i]``.
        output_codes: Tensor indexed by row; each row is cloned with a
            trailing singleton dimension appended.
            NOTE(review): presumably each row is 1-D with the semantic code
            first and the remaining codebook values after it -- confirm.
        output_semantic_ids: Tensor of sampled semantic ids, one per row.
            NOTE(review): assumed 1-D so ``tolist()`` yields plain ints --
            confirm.
        im_end_token_id: Semantic id for which nothing is recorded on the
            request.

    A request is left untouched when ``data.req.is_chunked > 0`` or when its
    sampled id equals ``im_end_token_id``. An empty ``requests`` list returns
    immediately without touching ``result``.
    """
    num_rows = len(requests)
    if not num_rows:
        return

    step_ids = output_semantic_ids[:num_rows]
    result.next_token_ids = step_ids.clone()
    # One bulk conversion instead of a per-row ``.item()`` call.
    sampled_ids = step_ids.tolist()

    for position, scheduled in enumerate(requests):
        state = scheduled.data
        if state.req.is_chunked > 0:
            continue

        sampled_id = sampled_ids[position]
        if sampled_id == im_end_token_id:
            continue

        # Clone so the stored frame does not alias the caller's tensor.
        row_codes = output_codes[position]
        frame = row_codes.unsqueeze(-1).clone()
        # Everything after the first entry, cloned so it is independent of
        # the frame appended to ``output_codes`` below.
        state.last_codebook_values = frame[1:, 0].clone()
        state.previous_semantic_tokens.append(sampled_id)
        state.output_codes.append(frame)
| 41 | + |
| 42 | + |
13 | 43 | class FishS2ProModelRunner(ModelRunner): |
14 | 44 | """Fish TTS runner with unified forward-owned decode and persistent buffers.""" |
15 | 45 |
|
def __init__(self, tp_worker: Any, output_processor: Any):
    """Initialise the base runner and copy the model's special ids as ints.

    Args:
        tp_worker: Forwarded unchanged to the base-class initialiser.
        output_processor: Forwarded unchanged to the base-class initialiser.
    """
    super().__init__(tp_worker, output_processor)
    # NOTE(review): ``self.model`` is presumably set up by the base
    # initialiser above -- confirm against ModelRunner.
    # Each id is read from the model and stored on the runner under the
    # same private name, coerced to a plain ``int``.
    for attr_name in (
        "_semantic_begin_id",
        "_semantic_end_id",
        "_im_end_token_id",
    ):
        setattr(self, attr_name, int(getattr(self.model, attr_name)))
20 | 51 |
|
21 | 52 | def prepare_prefill(self, forward_batch, schedule_batch, requests): |
22 | 53 | del schedule_batch |
@@ -117,19 +148,10 @@ def _build_prefill_input_embeds( |
117 | 148 | return text_embeds |
118 | 149 |
|
def _collect_step_outputs(self, result: Any, requests: list) -> None:
    """Record this step's outputs via ``collect_s2pro_step_outputs``.

    Passes the model's ``_output_codes`` and ``_output_semantic_ids``
    buffers, together with this runner's ``_im_end_token_id``, to the
    module-level helper, which updates ``result`` and the requests.

    Args:
        result: Step result object handed through to the helper.
        requests: Scheduled requests handed through to the helper.
    """
    codes_buffer = self.model._output_codes
    semantic_ids_buffer = self.model._output_semantic_ids
    collect_s2pro_step_outputs(
        result,
        requests,
        output_codes=codes_buffer,
        output_semantic_ids=semantic_ids_buffer,
        im_end_token_id=self._im_end_token_id,
    )
0 commit comments