3
3
from abc import ABC
4
4
from dataclasses import dataclass
5
5
from pathlib import Path
6
- from typing import Iterable , List , Set , Type , Pattern , Tuple
6
+ from typing import Iterable , List , Set , Type , Pattern , Tuple , Optional
7
7
8
8
import torch
9
+ from kilroy_module_server_py_sdk import Configurable , Parameter , classproperty
9
10
from torch import Tensor
10
11
from torch .distributions import Categorical
11
12
25
26
pack_list ,
26
27
batched_forward ,
27
28
)
28
- from kilroy_module_server_py_sdk import Configurable , Parameter , classproperty
29
29
30
30
31
31
@dataclass
@@ -197,57 +197,50 @@ def _update_generation_state(
197
197
return state
198
198
199
199
@staticmethod
200
- def _is_complete (sequence : SequenceState , end_value : int ) -> bool :
201
- return (sequence .context + sequence .response )[- 1 ] == end_value
202
-
203
- @staticmethod
204
- def _trim_incomplete (
200
+ def _trim_until_valid (
205
201
sequence : SequenceState ,
206
202
tokenizer : Tokenizer ,
207
203
regex : Pattern [str ],
208
204
) -> SequenceState :
209
205
for i in range (len (sequence .response ) - 1 , - 1 , - 1 ):
210
206
index = slice (0 , i + 1 )
211
- sentence = tokenizer .decode (sequence .response [index ])
207
+ sentence = tokenizer .decode (
208
+ sequence .context + sequence .response [index ]
209
+ )
212
210
if regex .fullmatch (sentence ):
213
211
return SequenceState (
214
212
context = sequence .context ,
215
213
response = sequence .response [index ],
216
214
)
217
- for i in range (len (sequence .context ) - 1 , - 1 , - 1 ):
218
- index = slice (0 , i + 1 )
219
- sentence = tokenizer .decode (sequence .context [index ])
220
- if regex .fullmatch (sentence ):
221
- return SequenceState (
222
- context = sequence .context [index ], response = []
223
- )
224
- return sequence
215
+
216
+ raise ValueError ("No valid sentence found" )
225
217
226
218
def _complete (
227
219
self ,
228
220
state : GenerationState ,
229
221
tokenizer : Tokenizer ,
230
222
regex : Pattern [str ],
231
- ) -> List [SequenceState ]:
223
+ ) -> List [Optional [ SequenceState ] ]:
232
224
in_sequences = state .finished_sequences + state .current_sequences
233
225
out_sequences = []
234
226
235
227
for sequence in in_sequences :
236
- if self ._is_complete (sequence , tokenizer .end_token ):
237
- out_sequences .append (sequence )
238
- else :
239
- new_sequence = self ._trim_incomplete (
240
- sequence , tokenizer , regex
241
- )
242
- out_sequences .append (new_sequence )
228
+ try :
229
+ sequence = self ._trim_until_valid (sequence , tokenizer , regex )
230
+ except ValueError :
231
+ sequence = None
232
+ out_sequences .append (sequence )
243
233
return out_sequences
244
234
245
235
@staticmethod
246
236
def _prepare_output (
247
- sequences : List [SequenceState ],
248
- ) -> List [Tuple [List [int ], List [int ]]]:
237
+ sequences : List [Optional [ SequenceState ] ],
238
+ ) -> List [Optional [ Tuple [List [int ], List [int ] ]]]:
249
239
return [
250
- (sequence .context , sequence .response ) for sequence in sequences
240
+ (sequence .context , sequence .response )
241
+ if sequence is not None
242
+ else None
243
+ for sequence in sequences
251
244
]
252
245
253
246
async def _generate (
@@ -256,7 +249,7 @@ async def _generate(
256
249
contexts : Iterable [Iterable [int ]],
257
250
max_length : int ,
258
251
regex : Pattern [str ],
259
- ) -> List [Tuple [List [int ], List [int ]]]:
252
+ ) -> List [Optional [ Tuple [List [int ], List [int ] ]]]:
260
253
state = self ._build_initial_generation_state (contexts )
261
254
while not self ._should_stop (state , max_length ):
262
255
logprobs = await self ._predict (model , state .current_sequences )
@@ -272,14 +265,21 @@ async def generate(
272
265
model : ModelInfo [SequentialModel ],
273
266
n : int ,
274
267
) -> List [Tuple [List [int ], List [int ]]]:
275
- async with self .state .read_lock () as state :
276
- contexts = self ._sample_contexts (
277
- state .contexts , model .tokenizer , n
278
- )
268
+ out = []
279
269
280
- return await self ._generate (
270
+ while len (out ) < n :
271
+ async with self .state .read_lock () as state :
272
+ contexts = self ._sample_contexts (
273
+ state .contexts , model .tokenizer , n - len (out )
274
+ )
275
+ sequences = await self ._generate (
281
276
model ,
282
277
contexts ,
283
278
state .max_length ,
284
279
state .regex ,
285
280
)
281
+ out .extend (
282
+ sequence for sequence in sequences if sequence is not None
283
+ )
284
+
285
+ return out
0 commit comments