This repository was archived by the owner on Apr 29, 2024. It is now read-only.

Commit 8e86cb8: Added sentence cleanup (#27)

1 parent 3428037 · commit 8e86cb8

3 files changed: +109 -29 lines


kilroy_module_pytorch_py_sdk/pyproject.toml (+1 -1)

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "kilroy-module-pytorch-py-sdk"
-version = "0.7.1"
+version = "0.7.2"
 description = "SDK for kilroy modules using PyTorch 🧰"
 readme = "README.md"
 authors = ["kilroy <[email protected]>"]

kilroy_module_pytorch_py_sdk/src/kilroy_module_pytorch_py_sdk/generator/generator.py (+47 -8)
@@ -1,9 +1,20 @@
 import json
 import random
+import re
 from dataclasses import dataclass
 from functools import partial
 from pathlib import Path
-from typing import Any, AsyncIterable, Dict, Iterable, List, Set
+from typing import (
+    Any,
+    AsyncIterable,
+    Dict,
+    Iterable,
+    List,
+    Set,
+    Pattern,
+    Callable,
+    Awaitable,
+)
 
 from kilroy_module_server_py_sdk import (
     CategorizableBasedParameter,
@@ -26,16 +37,18 @@
 class Params(SerializableModel):
     sampler_type: str = "epsilonNucleus"
     samplers_params: Dict[str, Dict[str, Any]] = {}
-    contexts: List[str] = [""]
-    max_length: int
-    batch_size: int
+    contexts: List[str] = []
+    regex: str = r"^(^(?!.*\s+[a-zA-Z0-9_']*$).+$)|(^(?!.*[\.\?!]+).+$)$"
+    max_length: int = 16
+    batch_size: int = 1
 
 
 @dataclass
 class State:
     sampler: Sampler
     samplers_params: Dict[str, Dict[str, Any]]
     contexts: List[str]
+    regex: Pattern[str]
     max_length: int
     batch_size: int
 
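The new default regex drives the sentence cleanup: a string is kept if it does not end in a dangling word after whitespace (first branch), or, as a fallback, if it contains no sentence punctuation at all, so punctuation-free text is never trimmed away (second branch). A quick illustrative check of the default pattern (not part of the commit):

import re

# Default pattern from Params above.
pattern = re.compile(r"^(^(?!.*\s+[a-zA-Z0-9_']*$).+$)|(^(?!.*[\.\?!]+).+$)$")

assert pattern.fullmatch("Hello there.")             # ends at a sentence boundary
assert not pattern.fullmatch("Hello there. And so")  # dangling word after the sentence
assert pattern.fullmatch("no punctuation at all")    # fallback branch: keep everything
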
@@ -50,12 +63,33 @@ def schema(cls) -> Dict[str, Any]:
         return {
             "type": "array",
             "items": {"type": "string"},
-            "minItems": 1,
             "title": cls.pretty_name,
-            "default": [" "],
+            "default": [],
         }
 
 
+class RegexParameter(Parameter[State, str]):
+    async def _get(self, state: State) -> str:
+        return state.regex.pattern
+
+    async def _set(self, state: State, value: str) -> Callable[[], Awaitable]:
+        original_value = state.regex
+
+        async def undo():
+            state.regex = original_value
+
+        state.regex = re.compile(value)
+        return undo
+
+    @classproperty
+    def schema(cls) -> Dict[str, Any]:
+        return {"type": "string", "title": cls.pretty_name}
+
+    @classproperty
+    def pretty_name(cls) -> str:
+        return "Regex"
+
+
 class MaxLengthParameter(Parameter[State, int]):
     @classproperty
     def schema(cls) -> Dict[str, Any]:
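RegexParameter's _set applies the new value and returns an awaitable undo closure that restores the previous compiled pattern. A minimal standalone sketch of that rollback pattern, with a hypothetical Holder class standing in for the real State:

import asyncio
import re


class Holder:
    # Stand-in for State; only the regex field matters here.
    def __init__(self, pattern: str):
        self.regex = re.compile(pattern)


async def set_regex(holder: Holder, value: str):
    original = holder.regex  # capture the old compiled pattern

    async def undo():
        holder.regex = original  # restore it on rollback

    holder.regex = re.compile(value)
    return undo


async def main():
    holder = Holder(r"old")
    undo = await set_regex(holder, r"new")
    assert holder.regex.pattern == "new"
    await undo()
    assert holder.regex.pattern == "old"


asyncio.run(main())
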
@@ -78,6 +112,7 @@ def parameters(cls) -> Set[Parameter]:
         return {
             SamplerParameter(),
             ContextsParameter(),
+            RegexParameter(),
             MaxLengthParameter(),
             BatchSizeParameter(),
         }
@@ -95,6 +130,7 @@ async def _build_default_state(self) -> State:
             sampler=await self._build_sampler(params),
             samplers_params=params.samplers_params,
             contexts=params.contexts,
+            regex=re.compile(params.regex),
             max_length=params.max_length,
             batch_size=params.batch_size,
         )
@@ -104,6 +140,7 @@ async def _save_state(self, state: State, directory: Path) -> None:
             "sampler_type": state.sampler.category,
             "samplers_params": state.samplers_params,
             "contexts": state.contexts,
+            "regex": state.regex.pattern,
             "max_length": state.max_length,
             "batch_size": state.batch_size,
         }
@@ -126,6 +163,7 @@ async def _load_saved_state(self, directory: Path) -> State:
             ),
             samplers_params=state_dict["samplers_params"],
             contexts=state_dict["contexts"],
+            regex=re.compile(state_dict["regex"]),
             max_length=state_dict["max_length"],
             batch_size=state_dict["batch_size"],
         )
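A compiled re.Pattern is not serializable, so _save_state stores regex.pattern and _load_saved_state recompiles it; the two hunks above mirror each other. A sketch of that round trip with an arbitrary pattern (JSON used here for illustration):

import json
import re

regex = re.compile(r"^(?!.*[\.\?!]+).+$")

saved = json.dumps({"regex": regex.pattern})     # the serializable form, as stored by _save_state
loaded = re.compile(json.loads(saved)["regex"])  # recompiled on load, as in _load_saved_state

assert loaded.pattern == regex.pattern
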
@@ -139,7 +177,7 @@ async def cleanup(self) -> None:
     def _get_contexts(
         state: State, tokenizer: Tokenizer, n: int
     ) -> Iterable[List[int]]:
-        contexts = random.choices(state.contexts, k=n)
+        contexts = random.choices(state.contexts or [""], k=n)
 
         for context in contexts:
            encoded = tokenizer.encode(context)
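Because contexts now defaults to an empty list, _get_contexts falls back to a single empty context; without the fallback, random.choices would raise IndexError on an empty population. For example:

import random

contexts = []  # the new Params default
samples = random.choices(contexts or [""], k=3)
assert samples == ["", "", ""]
# random.choices(contexts, k=3) would raise IndexError here
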
@@ -162,5 +200,6 @@ async def generate(
             state.sampler,
             contexts,
             state.max_length,
-            tokenizer.end_token,
+            tokenizer,
+            state.regex,
         )

kilroy_module_pytorch_py_sdk/src/kilroy_module_pytorch_py_sdk/generator/utils.py (+61 -20)
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Iterable, List, Tuple
+from typing import Iterable, List, Tuple, Optional, Pattern
 
 import torch
 from kilroy_module_server_py_sdk import background
@@ -8,6 +8,7 @@
 
 from kilroy_module_pytorch_py_sdk.models import LanguageModel
 from kilroy_module_pytorch_py_sdk.samplers.base import Sampler
+from kilroy_module_pytorch_py_sdk.tokenizer import Tokenizer
 from kilroy_module_pytorch_py_sdk.utils import pack_list, unpack_to_padded
 
 
@@ -41,7 +42,7 @@ def _build_initial_state(contexts: Iterable[Iterable[int]]) -> GenerationState:
     return GenerationState(
         waiting_sequences=waiting,
         current_sequences=current,
-        current_logprobs=[torch.tensor(0) for _ in range(len(current))],
+        current_logprobs=[torch.tensor([[0]]) for _ in range(len(current))],
         current_max_length=min_length,
     )
 
@@ -77,18 +78,18 @@ def _update_state(
     state: GenerationState,
     next_values: Iterable[Tensor],
     next_logprobs: Iterable[Tensor],
-    end_value: int,
+    tokenizer: Tokenizer,
 ) -> GenerationState:
     sequences = [
         torch.cat((current, next.view(1, 1)))
         for current, next in zip(state.current_sequences, next_values)
     ]
     logprobs = [
-        torch.add(current, next)
+        torch.cat((current, next.view(1, 1)))
         for current, next in zip(state.current_logprobs, next_logprobs)
     ]
 
-    finished_mask = _get_finished_mask(next_values, end_value)
+    finished_mask = _get_finished_mask(next_values, tokenizer.end_token)
 
     state.finished_sequences.extend(
         [
@@ -121,7 +122,7 @@ def _update_state(
     for sequence in state.waiting_sequences:
         if len(sequence) == new_current_max_length:
             new_current_sequences.append(sequence)
-            new_current_logprobs.append(torch.tensor(0))
+            new_current_logprobs.append(torch.tensor([[0]]))
         else:
             new_waiting_sequences.append(sequence)
 
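Both hunks above switch the log-probability bookkeeping from a running scalar sum (torch.tensor(0) plus torch.add) to a column of per-token values (torch.tensor([[0]]) plus torch.cat). The final score is unchanged, because _prepare_output below sums the column at the end, but keeping one entry per token lets the cleanup code slice logprobs with the same index it uses to trim tokens. A toy comparison with made-up values:

import torch

steps = [torch.tensor(-0.5), torch.tensor(-1.2), torch.tensor(-0.1)]

# Old scheme: running scalar sum.
total = torch.tensor(0.0)
for step in steps:
    total = torch.add(total, step)

# New scheme: one row per token, summed only when preparing the output.
per_token = torch.tensor([[0.0]])
for step in steps:
    per_token = torch.cat((per_token, step.view(1, 1)))

assert torch.isclose(per_token.sum(), total)
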
@@ -133,18 +134,57 @@
     return state
 
 
+def _is_complete(sequence: Tensor, end_value: int) -> bool:
+    return sequence[-1].item() == end_value
+
+
+def _trim_incomplete(
+    sequence: Tensor,
+    logprobs: Tensor,
+    tokenizer: Tokenizer,
+    regex: Pattern[str],
+) -> Tuple[Tensor, Tensor]:
+    for i in range(len(sequence) - 1, -1, -1):
+        index = slice(0, i + 1)
+        sentence = tokenizer.decode(sequence[index].flatten().tolist())
+        if regex.fullmatch(sentence):
+            return sequence[index], logprobs[index]
+    return sequence, logprobs
+
+
+def _cleanup_incomplete(
+    sequence: Tensor,
+    logprobs: Tensor,
+    tokenizer: Tokenizer,
+    regex: Pattern[str],
+) -> Tuple[Tensor, Tensor]:
+    new_sequence, new_logprobs = _trim_incomplete(
+        sequence[:-1], logprobs[:-1], tokenizer, regex
+    )
+    new_sequence = torch.cat(
+        (new_sequence, torch.tensor([[tokenizer.end_token]]))
+    )
+    return new_sequence, new_logprobs
+
+
 def _complete(
-    state: GenerationState, end_value: int
+    state: GenerationState, tokenizer: Tokenizer, regex: Pattern[str]
 ) -> Tuple[List[Tensor], List[Tensor]]:
-    sequences = state.finished_sequences + state.current_sequences
-    sequences = [
-        torch.cat((sequence[:-1], torch.tensor([[end_value]])))
-        if sequence[-1].item() != end_value
-        else sequence
-        for sequence in sequences
-    ]
-    logprobs = state.finished_logprobs + state.current_logprobs
-    return sequences, logprobs
+    in_sequences = state.finished_sequences + state.current_sequences
+    in_logprobs = state.finished_logprobs + state.current_logprobs
+    out_sequences, out_logprobs = [], []
+
+    for sequence, logprobs in zip(in_sequences, in_logprobs):
+        if _is_complete(sequence, tokenizer.end_token):
+            out_sequences.append(sequence)
+            out_logprobs.append(logprobs)
+        else:
+            new_sequence, new_logprobs = _cleanup_incomplete(
+                sequence, logprobs, tokenizer, regex
+            )
+            out_sequences.append(new_sequence)
+            out_logprobs.append(new_logprobs)
+    return out_sequences, out_logprobs
 
 
 def _prepare_output(
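_trim_incomplete walks backwards from the full sequence, decoding ever-shorter prefixes until one fully matches the regex, so an unfinished trailing fragment is dropped at the last acceptable boundary; _cleanup_incomplete then re-appends the end token. A self-contained sketch of the same trimming loop, using a toy character-level tokenizer in place of the SDK Tokenizer (assumed here to expose only decode and end_token):

import re

import torch


class CharTokenizer:
    # Toy stand-in for the SDK Tokenizer: one token per character.
    end_token = 0

    def decode(self, ids):
        return "".join(chr(i) for i in ids)


def trim_incomplete(sequence, logprobs, tokenizer, regex):
    # Same loop as _trim_incomplete above.
    for i in range(len(sequence) - 1, -1, -1):
        index = slice(0, i + 1)
        sentence = tokenizer.decode(sequence[index].flatten().tolist())
        if regex.fullmatch(sentence):
            return sequence[index], logprobs[index]
    return sequence, logprobs


regex = re.compile(r"^(^(?!.*\s+[a-zA-Z0-9_']*$).+$)|(^(?!.*[\.\?!]+).+$)$")
text = "Hi there. And then so"
sequence = torch.tensor([[ord(c)] for c in text])
logprobs = torch.zeros(len(text), 1)

trimmed, _ = trim_incomplete(sequence, logprobs, CharTokenizer(), regex)
print(CharTokenizer().decode(trimmed.flatten().tolist()))  # -> "Hi there."
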
@@ -156,7 +196,7 @@ def _prepare_output(
         reverse=True,
     )
     sequences = pack_list([sequence for sequence, _ in ordered])
-    logprobs = torch.vstack([logprob for _, logprob in ordered])
+    logprobs = torch.vstack([logprob.sum() for _, logprob in ordered])
     return GenerationResult(sequences=sequences, logprobs=logprobs)
 
 
@@ -165,12 +205,13 @@ async def generate(
     sampler: Sampler,
     contexts: Iterable[Iterable[int]],
     max_length: int,
-    end_value: int,
+    tokenizer: Tokenizer,
+    regex: Pattern[str],
 ) -> GenerationResult:
     state = _build_initial_state(contexts)
     while not _should_stop(state, max_length):
         logprobs = await background(_predict, model, state.current_sequences)
         next_values, next_logprobs = await _pick(sampler, logprobs)
-        state = _update_state(state, next_values, next_logprobs, end_value)
-        sequences, logprobs = _complete(state, end_value)
+        state = _update_state(state, next_values, next_logprobs, tokenizer)
+        sequences, logprobs = _complete(state, tokenizer, regex)
     return _prepare_output(sequences, logprobs)
