@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Iterable, List, Tuple
+from typing import Iterable, List, Tuple, Optional, Pattern
 
 import torch
 from kilroy_module_server_py_sdk import background
@@ -8,6 +8,7 @@
 
 from kilroy_module_pytorch_py_sdk.models import LanguageModel
 from kilroy_module_pytorch_py_sdk.samplers.base import Sampler
+from kilroy_module_pytorch_py_sdk.tokenizer import Tokenizer
 from kilroy_module_pytorch_py_sdk.utils import pack_list, unpack_to_padded
 
 
@@ -41,7 +42,7 @@ def _build_initial_state(contexts: Iterable[Iterable[int]]) -> GenerationState:
     return GenerationState(
         waiting_sequences=waiting,
         current_sequences=current,
-        current_logprobs=[torch.tensor(0) for _ in range(len(current))],
+        current_logprobs=[torch.tensor([[0]]) for _ in range(len(current))],
         current_max_length=min_length,
     )
 
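Note: the seed logprobs change shape here from a 0-dim scalar to a 1x1 column so that later steps can append one row per generated token, matching the (length, 1) layout of the token sequences. A minimal sketch of that convention (values made up):

    import torch

    # Each sequence's logprobs are a (length, 1) column, one row per token.
    logprobs = torch.tensor([[0.0]])   # seed row for the context
    step = torch.tensor(-0.7)          # log-probability of one sampled token
    logprobs = torch.cat((logprobs, step.view(1, 1)))
    assert logprobs.shape == (2, 1)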
@@ -77,18 +78,18 @@ def _update_state(
     state: GenerationState,
     next_values: Iterable[Tensor],
     next_logprobs: Iterable[Tensor],
-    end_value: int,
+    tokenizer: Tokenizer,
 ) -> GenerationState:
     sequences = [
         torch.cat((current, next.view(1, 1)))
         for current, next in zip(state.current_sequences, next_values)
     ]
     logprobs = [
-        torch.add(current, next)
+        torch.cat((current, next.view(1, 1)))
        for current, next in zip(state.current_logprobs, next_logprobs)
     ]
 
-    finished_mask = _get_finished_mask(next_values, end_value)
+    finished_mask = _get_finished_mask(next_values, tokenizer.end_token)
 
     state.finished_sequences.extend(
         [
@@ -121,7 +122,7 @@ def _update_state(
     for sequence in state.waiting_sequences:
         if len(sequence) == new_current_max_length:
             new_current_sequences.append(sequence)
-            new_current_logprobs.append(torch.tensor(0))
+            new_current_logprobs.append(torch.tensor([[0]]))
         else:
             new_waiting_sequences.append(sequence)
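Note: _update_state now concatenates each step's log-probability instead of folding it into a running scalar, so the per-token history survives until _complete decides how much of each sequence to keep. Roughly, the difference is:

    import torch

    steps = [torch.tensor(-0.5), torch.tensor(-1.2)]

    # Before: torch.add keeps only the running total; history is lost.
    total = torch.tensor(0.0)
    for s in steps:
        total = torch.add(total, s)                  # tensor(-1.7000)

    # After: torch.cat keeps one row per token; trimming the sequence can
    # trim the matching logprob rows, and the total is still recoverable.
    history = torch.tensor([[0.0]])
    for s in steps:
        history = torch.cat((history, s.view(1, 1)))
    assert torch.isclose(history.sum(), total)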
@@ -133,18 +134,57 @@ def _update_state(
     return state
 
 
+def _is_complete(sequence: Tensor, end_value: int) -> bool:
+    return sequence[-1].item() == end_value
+
+
+def _trim_incomplete(
+    sequence: Tensor,
+    logprobs: Tensor,
+    tokenizer: Tokenizer,
+    regex: Pattern[str],
+) -> Tuple[Tensor, Tensor]:
+    for i in range(len(sequence) - 1, -1, -1):
+        index = slice(0, i + 1)
+        sentence = tokenizer.decode(sequence[index].flatten().tolist())
+        if regex.fullmatch(sentence):
+            return sequence[index], logprobs[index]
+    return sequence, logprobs
+
+
+def _cleanup_incomplete(
+    sequence: Tensor,
+    logprobs: Tensor,
+    tokenizer: Tokenizer,
+    regex: Pattern[str],
+) -> Tuple[Tensor, Tensor]:
+    new_sequence, new_logprobs = _trim_incomplete(
+        sequence[:-1], logprobs[:-1], tokenizer, regex
+    )
+    new_sequence = torch.cat(
+        (new_sequence, torch.tensor([[tokenizer.end_token]]))
+    )
+    return new_sequence, new_logprobs
+
+
 def _complete(
-    state: GenerationState, end_value: int
+    state: GenerationState, tokenizer: Tokenizer, regex: Pattern[str]
 ) -> Tuple[List[Tensor], List[Tensor]]:
-    sequences = state.finished_sequences + state.current_sequences
-    sequences = [
-        torch.cat((sequence[:-1], torch.tensor([[end_value]])))
-        if sequence[-1].item() != end_value
-        else sequence
-        for sequence in sequences
-    ]
-    logprobs = state.finished_logprobs + state.current_logprobs
-    return sequences, logprobs
+    in_sequences = state.finished_sequences + state.current_sequences
+    in_logprobs = state.finished_logprobs + state.current_logprobs
+    out_sequences, out_logprobs = [], []
+
+    for sequence, logprobs in zip(in_sequences, in_logprobs):
+        if _is_complete(sequence, tokenizer.end_token):
+            out_sequences.append(sequence)
+            out_logprobs.append(logprobs)
+        else:
+            new_sequence, new_logprobs = _cleanup_incomplete(
+                sequence, logprobs, tokenizer, regex
+            )
+            out_sequences.append(new_sequence)
+            out_logprobs.append(new_logprobs)
+    return out_sequences, out_logprobs
 
 
 def _prepare_output(
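Note: _trim_incomplete scans backwards, decoding ever-shorter prefixes until one fully matches the sentence regex, and _cleanup_incomplete first drops the dangling last token and then re-appends the end token. A self-contained sketch of the same scan (the toy vocabulary, regex, and values are made up for illustration):

    import re
    import torch

    vocab = {1: "the", 2: "cat", 3: "sat", 0: "<end>"}   # hypothetical ids
    def decode(ids):
        return " ".join(vocab[i] for i in ids)
    sentence = re.compile(r"the cat( sat)?")             # toy "full sentence" rule

    seq = torch.tensor([[1], [2], [3], [1]])             # "the cat sat the"
    lps = torch.tensor([[0.0], [-0.5], [-1.0], [-4.0]])

    # Walk back from the full length; keep the longest prefix the regex
    # accepts in full, then terminate it with the end token.
    for i in range(len(seq) - 1, -1, -1):
        if sentence.fullmatch(decode(seq[: i + 1].flatten().tolist())):
            seq, lps = seq[: i + 1], lps[: i + 1]
            break
    seq = torch.cat((seq, torch.tensor([[0]])))          # "the cat sat <end>"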
@@ -156,7 +196,7 @@ def _prepare_output(
         reverse=True,
     )
     sequences = pack_list([sequence for sequence, _ in ordered])
-    logprobs = torch.vstack([logprob for _, logprob in ordered])
+    logprobs = torch.vstack([logprob.sum() for _, logprob in ordered])
     return GenerationResult(sequences=sequences, logprobs=logprobs)
 
 
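Note: because each sequence now carries a column of per-token logprobs, _prepare_output sums the column per sequence before stacking, so the result still exposes one total log-probability per sequence. For example:

    import torch

    per_token = [
        torch.tensor([[0.0], [-0.5], [-1.2]]),   # (length, 1) column per sequence
        torch.tensor([[0.0], [-0.3]]),
    ]
    totals = torch.vstack([lp.sum() for lp in per_token])
    print(totals)   # tensor([[-1.7000], [-0.3000]]), one row per sequence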
@@ -165,12 +205,13 @@ async def generate(
     sampler: Sampler,
     contexts: Iterable[Iterable[int]],
     max_length: int,
-    end_value: int,
+    tokenizer: Tokenizer,
+    regex: Pattern[str],
 ) -> GenerationResult:
     state = _build_initial_state(contexts)
     while not _should_stop(state, max_length):
         logprobs = await background(_predict, model, state.current_sequences)
         next_values, next_logprobs = await _pick(sampler, logprobs)
-        state = _update_state(state, next_values, next_logprobs, end_value)
-    sequences, logprobs = _complete(state, end_value)
+        state = _update_state(state, next_values, next_logprobs, tokenizer)
+    sequences, logprobs = _complete(state, tokenizer, regex)
     return _prepare_output(sequences, logprobs)
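Note: callers now pass the whole Tokenizer plus a compiled sentence pattern instead of a bare end-token id. A hedged usage sketch (model, sampler, tokenizer, and the start_token attribute are placeholders, not taken from this diff):

    import re

    async def run(model, sampler, tokenizer):
        return await generate(
            model,
            sampler,
            contexts=[[tokenizer.start_token]],       # assumed attribute, illustrative
            max_length=64,
            tokenizer=tokenizer,
            regex=re.compile(r".*[.!?]", re.DOTALL),  # toy "ends like a sentence" rule
        )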