Skip to content
This repository was archived by the owner on Apr 29, 2024. It is now read-only.

Commit 07b3f89

Browse files
authored
Delayed results cleaning to step (#18)
1 parent f490608 commit 07b3f89

File tree

2 files changed

+20
-2
lines changed
  • kilroy_module_pytorch_py_sdk/src/kilroy_module_pytorch_py_sdk/modules

2 files changed

+20
-2
lines changed

kilroy_module_pytorch_py_sdk/src/kilroy_module_pytorch_py_sdk/modules/basic.py

+10 −1
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ class State:
103103
generator: Generator
104104
codec: Codec
105105
results_cache: Dict[UUID, Tuple[Tensor, Tensor]]
106+
used_results: Set[UUID]
106107
batch_size: int
107108
step: int
108109
metrics: MetricsState
@@ -240,7 +241,8 @@ async def get_results():
240241
for post_id, score in scores:
241242
# noinspection PyShadowingNames
242243
async with self.state.write_lock() as state:
243-
sequence, logprob = state.results_cache.pop(post_id)
244+
sequence, logprob = state.results_cache.get(post_id)
245+
state.used_results.add(post_id)
244246
yield sequence, logprob, torch.tensor(score)
245247

246248
await self._fit_reinforced(get_results())
@@ -258,6 +260,12 @@ async def _reset_reports(state: State) -> None:
258260
state.reports.step_supervised_losses = []
259261
state.reports.step_reinforced_scores = []
260262

263+
@staticmethod
264+
async def _delete_used_results(state: State) -> None:
265+
for post_id in state.used_results:
266+
state.results_cache.pop(post_id, None)
267+
state.used_results.clear()
268+
261269
async def step(self) -> None:
262270
async with self.state.write_lock() as state:
263271
await state.optimizer.step()
@@ -276,4 +284,5 @@ async def step(self) -> None:
276284
state.reports.step_reinforced_scores,
277285
)
278286
await self._reset_reports(state)
287+
await self._delete_used_results(state)
279288
state.step += 1

kilroy_module_pytorch_py_sdk/src/kilroy_module_pytorch_py_sdk/modules/reward.py

+10 −1
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ class State:
165165
backend_generator: Generator
166166
codec: Codec
167167
results_cache: Dict[UUID, Tuple[Tensor, Tensor]]
168+
used_results: Set[UUID]
168169
batch_size: int
169170
sample_size: int
170171
step: int
@@ -442,7 +443,8 @@ async def get_results():
442443
for post_id, score in scores:
443444
# noinspection PyShadowingNames
444445
async with self.state.write_lock() as state:
445-
sequence, logprob = state.results_cache.pop(post_id)
446+
sequence, logprob = state.results_cache.get(post_id)
447+
state.used_results.add(post_id)
446448
yield sequence, logprob, torch.tensor(score)
447449

448450
await self._fit_reinforced(get_results())
@@ -462,6 +464,12 @@ async def _reset_reports(state: State) -> None:
462464
state.reports.step_reward_model_losses = []
463465
state.reports.step_reward_model_scores = []
464466

467+
@staticmethod
468+
async def _delete_used_results(state: State) -> None:
469+
for post_id in state.used_results:
470+
state.results_cache.pop(post_id, None)
471+
state.used_results.clear()
472+
465473
async def step(self) -> None:
466474
async with self.state.write_lock() as state:
467475
await state.language_model.optimizer.step()
@@ -495,4 +503,5 @@ async def step(self) -> None:
495503
state.reports.step_reward_model_scores,
496504
)
497505
await self._reset_reports(state)
506+
await self._delete_used_results(state)
498507
state.step += 1

Comments (0) — this commit has no comments.