Skip to content
This repository was archived by the owner on Apr 29, 2024. It is now read-only.

Commit 17ae277

Browse files
authored
Added basic error handling (#19)
1 parent 07b3f89 commit 17ae277

File tree

4 files changed

+75
-33
lines changed

4 files changed

+75
-33
lines changed

kilroy_module_pytorch_py_sdk/pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "kilroy-module-pytorch-py-sdk"
3-
version = "0.6.1"
3+
version = "0.6.2"
44
description = "SDK for kilroy modules using PyTorch 🧰"
55
readme = "README.md"
66
authors = ["kilroy <[email protected]>"]

kilroy_module_pytorch_py_sdk/src/kilroy_module_pytorch_py_sdk/modules/basic.py

+38-15
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import json
2+
import logging
13
from abc import ABC
24
from dataclasses import dataclass
35
from typing import (
@@ -43,6 +45,8 @@
4345
unpack_to_list,
4446
)
4547

48+
logger = logging.getLogger(__name__)
49+
4650

4751
class SupervisedLossMetric(Metric[Dict]):
4852
@classproperty
@@ -103,7 +107,6 @@ class State:
103107
generator: Generator
104108
codec: Codec
105109
results_cache: Dict[UUID, Tuple[Tensor, Tensor]]
106-
used_results: Set[UUID]
107110
batch_size: int
108111
step: int
109112
metrics: MetricsState
@@ -181,12 +184,19 @@ async def generate(
181184
async for result in generated:
182185
sequences = unpack_to_list(result.sequences)
183186
for sequence, logprob in zip(sequences, result.logprobs):
187+
184188
post_id = uuid4()
189+
185190
async with self.state.read_lock() as state:
186-
post = await state.codec.encode(state.tokenizer, sequence)
191+
codec = state.codec
192+
tokenizer = state.tokenizer
193+
194+
post = await codec.encode(tokenizer, sequence)
195+
187196
if not dry:
188197
async with self.state.write_lock() as state:
189198
state.results_cache[post_id] = (sequence, logprob[0])
199+
190200
yield post_id, post
191201

192202
async def _fit_supervised(self, data: AsyncIterable[Tensor]) -> None:
@@ -204,26 +214,38 @@ def fit(model, batch):
204214

205215
async with batches.stream() as streamer:
206216
async for batch in streamer:
207-
async with self.state.write_lock() as state:
208-
loss = await background(fit, state.model, batch)
209-
state.reports.step_supervised_losses.append(loss)
217+
if batch:
218+
async with self.state.write_lock() as state:
219+
loss = await background(fit, state.model, batch)
220+
state.reports.step_supervised_losses.append(loss)
210221

211222
async def fit_posts(
212223
self, posts: AsyncIterable[Tuple[Dict[str, Any], float]]
213224
) -> None:
214225
async def decoded():
215226
async for post, _ in posts:
216-
# noinspection PyShadowingNames
217227
async with self.state.read_lock() as state:
218-
yield await state.codec.decode(state.tokenizer, post)
228+
codec = state.codec
229+
tokenizer = state.tokenizer
230+
try:
231+
yield await codec.decode(tokenizer, post)
232+
except Exception as e:
233+
logger.warning(
234+
f"Failed to decode post: {json.dumps(post)}. Skipping...",
235+
exc_info=e,
236+
)
237+
continue
219238

220239
await self._fit_supervised(decoded())
221240

222241
async def _fit_reinforced(
223242
self,
224243
results: AsyncIterable[Tuple[Tensor, Tensor, Tensor]],
225244
) -> None:
226-
results = list([result async for result in results])
245+
results = [result async for result in results]
246+
if not results:
247+
return
248+
227249
logprobs = torch.stack([logprob for _, logprob, _ in results])
228250
scores = torch.stack([score for _, _, score in results])
229251

@@ -239,10 +261,13 @@ def fit():
239261
async def fit_scores(self, scores: List[Tuple[UUID, float]]) -> None:
240262
async def get_results():
241263
for post_id, score in scores:
242-
# noinspection PyShadowingNames
243264
async with self.state.write_lock() as state:
265+
if post_id not in state.results_cache:
266+
logger.warning(
267+
f"Post {str(post_id)} has not been generated. Skipping..."
268+
)
269+
continue
244270
sequence, logprob = state.results_cache.get(post_id)
245-
state.used_results.add(post_id)
246271
yield sequence, logprob, torch.tensor(score)
247272

248273
await self._fit_reinforced(get_results())
@@ -261,10 +286,8 @@ async def _reset_reports(state: State) -> None:
261286
state.reports.step_reinforced_scores = []
262287

263288
@staticmethod
264-
async def _delete_used_results(state: State) -> None:
265-
for post_id in state.used_results:
266-
state.results_cache.pop(post_id, None)
267-
state.used_results.clear()
289+
async def _delete_results(state: State) -> None:
290+
state.results_cache.clear()
268291

269292
async def step(self) -> None:
270293
async with self.state.write_lock() as state:
@@ -284,5 +307,5 @@ async def step(self) -> None:
284307
state.reports.step_reinforced_scores,
285308
)
286309
await self._reset_reports(state)
287-
await self._delete_used_results(state)
310+
await self._delete_results(state)
288311
state.step += 1

kilroy_module_pytorch_py_sdk/src/kilroy_module_pytorch_py_sdk/modules/reward.py

+35-16
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import json
2+
import logging
13
from abc import ABC
24
from asyncio import Queue, Task
35
from dataclasses import dataclass
@@ -48,6 +50,8 @@
4850
unpack_to_list,
4951
)
5052

53+
logger = logging.getLogger(__name__)
54+
5155

5256
class SupervisedLossMetric(Metric[Dict]):
5357
@classproperty
@@ -165,7 +169,6 @@ class State:
165169
backend_generator: Generator
166170
codec: Codec
167171
results_cache: Dict[UUID, Tuple[Tensor, Tensor]]
168-
used_results: Set[UUID]
169172
batch_size: int
170173
sample_size: int
171174
step: int
@@ -290,14 +293,19 @@ async def generate(
290293
async for result in generated:
291294
sequences = unpack_to_list(result.sequences)
292295
for sequence, logprob in zip(sequences, result.logprobs):
296+
293297
post_id = uuid4()
298+
294299
async with self.state.read_lock() as state:
295-
post = await state.codec.encode(
296-
state.language_model.tokenizer, sequence
297-
)
300+
codec = state.codec
301+
tokenizer = state.language_model.tokenizer
302+
303+
post = await codec.encode(tokenizer, sequence)
304+
298305
if not dry:
299306
async with self.state.write_lock() as state:
300307
state.results_cache[post_id] = (sequence, logprob[0])
308+
301309
yield post_id, post
302310

303311
@staticmethod
@@ -350,6 +358,8 @@ async def _fit_supervised(
350358

351359
async with batches.stream() as streamer:
352360
async for batch in streamer:
361+
if not batch:
362+
continue
353363
async with self.state.write_lock() as state:
354364
sequences = pack_list(sequence for sequence, _ in batch)
355365
scores = torch.vstack([score for _, score in batch])
@@ -372,13 +382,19 @@ async def fit_posts(
372382
) -> None:
373383
async def decoded():
374384
async for post, score in posts:
375-
# noinspection PyShadowingNames
376385
async with self.state.read_lock() as state:
377-
post = await state.codec.decode(
378-
state.language_model.tokenizer, post
386+
codec = state.codec
387+
tokenizer = state.language_model.tokenizer
388+
try:
389+
post = await codec.decode(tokenizer, post)
390+
except Exception as e:
391+
logger.warning(
392+
f"Failed to decode post: {json.dumps(post)}. Skipping...",
393+
exc_info=e,
379394
)
380-
score = torch.tensor(score, dtype=torch.float)
381-
yield post, score
395+
continue
396+
score = torch.tensor(score, dtype=torch.float)
397+
yield post, score
382398

383399
await self._fit_supervised(decoded())
384400

@@ -421,6 +437,8 @@ async def _fit_reinforced(
421437

422438
async with batches.stream() as streamer:
423439
async for batch in streamer:
440+
if not batch:
441+
continue
424442
sequences = pack_list([sequence for sequence, _, _ in batch])
425443
scores = torch.vstack([score for _, _, score in batch])
426444
async with self.state.write_lock() as state:
@@ -441,10 +459,13 @@ async def _fit_reinforced(
441459
async def fit_scores(self, scores: List[Tuple[UUID, float]]) -> None:
442460
async def get_results():
443461
for post_id, score in scores:
444-
# noinspection PyShadowingNames
445462
async with self.state.write_lock() as state:
463+
if post_id not in state.results_cache:
464+
logger.warning(
465+
f"Post {str(post_id)} has not been generated. Skipping..."
466+
)
467+
continue
446468
sequence, logprob = state.results_cache.get(post_id)
447-
state.used_results.add(post_id)
448469
yield sequence, logprob, torch.tensor(score)
449470

450471
await self._fit_reinforced(get_results())
@@ -465,10 +486,8 @@ async def _reset_reports(state: State) -> None:
465486
state.reports.step_reward_model_scores = []
466487

467488
@staticmethod
468-
async def _delete_used_results(state: State) -> None:
469-
for post_id in state.used_results:
470-
state.results_cache.pop(post_id, None)
471-
state.used_results.clear()
489+
async def _delete_results(state: State) -> None:
490+
state.results_cache.clear()
472491

473492
async def step(self) -> None:
474493
async with self.state.write_lock() as state:
@@ -503,5 +522,5 @@ async def step(self) -> None:
503522
state.reports.step_reward_model_scores,
504523
)
505524
await self._reset_reports(state)
506-
await self._delete_used_results(state)
525+
await self._delete_results(state)
507526
state.step += 1

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
[tool.poetry]
55
name = "kilroy-module-pytorch-py-sdk"
6-
version = "0.6.1"
6+
version = "0.6.2"
77
description = "SDK for kilroy modules using PyTorch 🧰"
88
readme = "kilroy_module_pytorch_py_sdk/README.md"
99
authors = ["kilroy <[email protected]>"]

0 commit comments

Comments (0)