Skip to content
This repository was archived by the owner on Apr 29, 2024. It is now read-only.

Commit dcf517a

Browse files
authored
Added recoding of sequences for the reward model (#11)
1 parent 463d21a commit dcf517a

File tree

1 file changed

+16
-2
lines changed
  • kilroy_module_pytorch_py_sdk/src/kilroy_module_pytorch_py_sdk/modules

1 file changed

+16
-2
lines changed

kilroy_module_pytorch_py_sdk/src/kilroy_module_pytorch_py_sdk/modules/reward.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,17 @@ def _fit_with_reward_model_batch(
268268
loss.backward()
269269
return scores.mean().item()
270270

271+
@staticmethod
def _recode(
    sequences: PackedSequence, source: Tokenizer, target: Tokenizer
) -> PackedSequence:
    """Translate packed token sequences from one tokenizer's vocabulary to another's.

    Each packed sequence is decoded to text with ``source`` and re-encoded
    with ``target``, so downstream consumers (e.g. a reward model with a
    different vocabulary) receive ids they understand.

    Args:
        sequences: Packed batch of token-id tensors (each a column vector).
        source: Tokenizer whose vocabulary produced ``sequences``.
        target: Tokenizer whose vocabulary the result should use.

    Returns:
        The same batch re-encoded in ``target``'s vocabulary, packed again.
    """
    recoded = []
    for sequence in unpack_to_list(sequences):
        # Tensors are (length, 1) column vectors; flatten to a plain id list.
        token_ids = sequence.flatten().tolist()
        text = source.decode(token_ids)
        new_ids = target.encode(text)
        # Restore the (length, 1) column-vector shape before repacking.
        recoded.append(torch.tensor(new_ids).view(-1, 1))
    return pack_list(recoded)
281+
271282
async def _fit_supervised(
272283
self, data: AsyncIterable[Tuple[Tensor, Tensor]]
273284
) -> None:
@@ -324,8 +335,11 @@ async def _fit_with_reward_model(self) -> None:
324335
batch = await anext(generated)
325336
except StopAsyncIteration:
326337
break
327-
# TODO: recode
328-
sequences = batch.sequences
338+
sequences = self._recode(
339+
batch.sequences,
340+
state.language_model_tokenizer,
341+
state.reward_model_tokenizer,
342+
)
329343
logprobs = batch.logprobs
330344
score = await background(
331345
self._fit_with_reward_model_batch,

0 commit comments

Comments
 (0)