Skip to content
This repository was archived by the owner on Apr 29, 2024. It is now read-only.

Commit 9bcff5c

Browse files
authored
Fixed generator (#34)
1 parent e30a820 commit 9bcff5c

File tree

1 file changed

+33
-33
lines changed
  • kilroy_module_pytorch_py_sdk/src/kilroy_module_pytorch_py_sdk/generator

1 file changed

+33
-33
lines changed

kilroy_module_pytorch_py_sdk/src/kilroy_module_pytorch_py_sdk/generator/generator.py

+33-33
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
from abc import ABC
44
from dataclasses import dataclass
55
from pathlib import Path
6-
from typing import Iterable, List, Set, Type, Pattern, Tuple
6+
from typing import Iterable, List, Set, Type, Pattern, Tuple, Optional
77

88
import torch
9+
from kilroy_module_server_py_sdk import Configurable, Parameter, classproperty
910
from torch import Tensor
1011
from torch.distributions import Categorical
1112

@@ -25,7 +26,6 @@
2526
pack_list,
2627
batched_forward,
2728
)
28-
from kilroy_module_server_py_sdk import Configurable, Parameter, classproperty
2929

3030

3131
@dataclass
@@ -197,57 +197,50 @@ def _update_generation_state(
197197
return state
198198

199199
@staticmethod
200-
def _is_complete(sequence: SequenceState, end_value: int) -> bool:
201-
return (sequence.context + sequence.response)[-1] == end_value
202-
203-
@staticmethod
204-
def _trim_incomplete(
200+
def _trim_until_valid(
205201
sequence: SequenceState,
206202
tokenizer: Tokenizer,
207203
regex: Pattern[str],
208204
) -> SequenceState:
209205
for i in range(len(sequence.response) - 1, -1, -1):
210206
index = slice(0, i + 1)
211-
sentence = tokenizer.decode(sequence.response[index])
207+
sentence = tokenizer.decode(
208+
sequence.context + sequence.response[index]
209+
)
212210
if regex.fullmatch(sentence):
213211
return SequenceState(
214212
context=sequence.context,
215213
response=sequence.response[index],
216214
)
217-
for i in range(len(sequence.context) - 1, -1, -1):
218-
index = slice(0, i + 1)
219-
sentence = tokenizer.decode(sequence.context[index])
220-
if regex.fullmatch(sentence):
221-
return SequenceState(
222-
context=sequence.context[index], response=[]
223-
)
224-
return sequence
215+
216+
raise ValueError("No valid sentence found")
225217

226218
def _complete(
227219
self,
228220
state: GenerationState,
229221
tokenizer: Tokenizer,
230222
regex: Pattern[str],
231-
) -> List[SequenceState]:
223+
) -> List[Optional[SequenceState]]:
232224
in_sequences = state.finished_sequences + state.current_sequences
233225
out_sequences = []
234226

235227
for sequence in in_sequences:
236-
if self._is_complete(sequence, tokenizer.end_token):
237-
out_sequences.append(sequence)
238-
else:
239-
new_sequence = self._trim_incomplete(
240-
sequence, tokenizer, regex
241-
)
242-
out_sequences.append(new_sequence)
228+
try:
229+
sequence = self._trim_until_valid(sequence, tokenizer, regex)
230+
except ValueError:
231+
sequence = None
232+
out_sequences.append(sequence)
243233
return out_sequences
244234

245235
@staticmethod
246236
def _prepare_output(
247-
sequences: List[SequenceState],
248-
) -> List[Tuple[List[int], List[int]]]:
237+
sequences: List[Optional[SequenceState]],
238+
) -> List[Optional[Tuple[List[int], List[int]]]]:
249239
return [
250-
(sequence.context, sequence.response) for sequence in sequences
240+
(sequence.context, sequence.response)
241+
if sequence is not None
242+
else None
243+
for sequence in sequences
251244
]
252245

253246
async def _generate(
@@ -256,7 +249,7 @@ async def _generate(
256249
contexts: Iterable[Iterable[int]],
257250
max_length: int,
258251
regex: Pattern[str],
259-
) -> List[Tuple[List[int], List[int]]]:
252+
) -> List[Optional[Tuple[List[int], List[int]]]]:
260253
state = self._build_initial_generation_state(contexts)
261254
while not self._should_stop(state, max_length):
262255
logprobs = await self._predict(model, state.current_sequences)
@@ -272,14 +265,21 @@ async def generate(
272265
model: ModelInfo[SequentialModel],
273266
n: int,
274267
) -> List[Tuple[List[int], List[int]]]:
275-
async with self.state.read_lock() as state:
276-
contexts = self._sample_contexts(
277-
state.contexts, model.tokenizer, n
278-
)
268+
out = []
279269

280-
return await self._generate(
270+
while len(out) < n:
271+
async with self.state.read_lock() as state:
272+
contexts = self._sample_contexts(
273+
state.contexts, model.tokenizer, n - len(out)
274+
)
275+
sequences = await self._generate(
281276
model,
282277
contexts,
283278
state.max_length,
284279
state.regex,
285280
)
281+
out.extend(
282+
sequence for sequence in sequences if sequence is not None
283+
)
284+
285+
return out

0 commit comments

Comments
 (0)