
Commit 463d21a

Added handling of score in fit_posts (#10)
1 parent 9e43630 commit 463d21a

File tree

  • kilroy_module_pytorch_py_sdk/src/kilroy_module_pytorch_py_sdk/modules

2 files changed: +52 -31 lines

kilroy_module_pytorch_py_sdk/src/kilroy_module_pytorch_py_sdk/modules/basic.py

+4 -2

@@ -167,9 +167,11 @@ def fit(model, batch):
                     loss = await background(fit, state.model, batch)
                     state.epoch_supervised_losses.append(loss)
 
-    async def fit_posts(self, posts: AsyncIterable[Dict[str, Any]]) -> None:
+    async def fit_posts(
+        self, posts: AsyncIterable[Tuple[Dict[str, Any], float]]
+    ) -> None:
         async def decoded():
-            async for post in posts:
+            async for post, _ in posts:
                 # noinspection PyShadowingNames
                 async with self.state.read_lock() as state:
                     yield await state.codec.decode(state.tokenizer, post)
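With this change the basic module accepts a stream of (post, score) pairs and simply discards the score while decoding. A hypothetical caller-side sketch (the module instance, post payloads, and score values are illustrative; only the tuple shape of the stream comes from this diff):

from typing import Any, AsyncIterable, Dict, Tuple


async def scored_posts() -> AsyncIterable[Tuple[Dict[str, Any], float]]:
    # Hypothetical stream of (post payload, score) pairs.
    yield {"content": "first post"}, 1.0
    yield {"content": "second post"}, 0.25


async def train(module) -> None:
    # fit_posts now consumes the tuple stream; in the basic module the score
    # is ignored ("async for post, _ in posts") and only the post is decoded.
    await module.fit_posts(scored_posts())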

kilroy_module_pytorch_py_sdk/src/kilroy_module_pytorch_py_sdk/modules/reward.py

+48 -29

@@ -7,7 +7,6 @@
     Coroutine,
     Dict,
     Generator,
-    Iterable,
     List,
     Set,
     Tuple,
@@ -34,7 +33,6 @@
 from torch.nn.utils.rnn import PackedSequence
 
 from kilroy_module_pytorch_py_sdk.codec import Codec
-from kilroy_module_pytorch_py_sdk.generator import GenerationResult
 from kilroy_module_pytorch_py_sdk.models import LanguageModel, RewardModel
 from kilroy_module_pytorch_py_sdk.optimizers import Optimizer
 from kilroy_module_pytorch_py_sdk.tokenizer import Tokenizer
@@ -240,58 +238,76 @@ async def generate(
             yield post_id, post
 
     @staticmethod
-    def _fit_supervised_batch(
-        model: LanguageModel, batch: Iterable[Tensor]
+    def _fit_language_model_batch(
+        model: LanguageModel, sequences: PackedSequence
     ) -> float:
+        batch = unpack_to_list(sequences)
         input = pack_list(truncate_last_element(batch))
         target = pack_list(truncate_first_element(batch))
         logprobs = model(input)
         loss = NLLLoss()(logprobs.data, target.data.flatten())
         loss.backward()
         return loss.item()
 
-    async def _fit_supervised(self, data: AsyncIterable[Tensor]) -> None:
+    @staticmethod
+    def _fit_reward_model_batch(
+        model: RewardModel, sequences: PackedSequence, scores: Tensor
+    ) -> float:
+        predicted = model(sequences)
+        loss = MSELoss()(predicted, scores)
+        loss.backward()
+        return loss.item()
+
+    @staticmethod
+    def _fit_with_reward_model_batch(
+        model: RewardModel, sequences: PackedSequence, logprobs: Tensor
+    ) -> float:
+        with freeze(model) as frozen:
+            scores = frozen(sequences)
+        loss = -(logprobs * scores).mean()
+        loss.backward()
+        return scores.mean().item()
+
+    async def _fit_supervised(
+        self, data: AsyncIterable[Tuple[Tensor, Tensor]]
+    ) -> None:
         async with self.state.read_lock() as state:
             batches = stream.chunks(data, state.batch_size)
 
         async with batches.stream() as streamer:
             async for batch in streamer:
                 async with self.state.write_lock() as state:
+                    sequences = pack_list(sequence for sequence, _ in batch)
+                    scores = torch.vstack([score for _, score in batch])
                     loss = await background(
-                        self._fit_supervised_batch, state.language_model, batch
+                        self._fit_language_model_batch,
+                        state.language_model,
+                        sequences,
                     )
                     state.epoch_supervised_losses.append(loss)
+                    loss = await background(
+                        self._fit_reward_model_batch,
+                        state.reward_model,
+                        sequences,
+                        scores,
+                    )
+                    state.epoch_reward_model_losses.append(loss)
 
-    async def fit_posts(self, posts: AsyncIterable[Dict[str, Any]]) -> None:
+    async def fit_posts(
+        self, posts: AsyncIterable[Tuple[Dict[str, Any], float]]
+    ) -> None:
         async def decoded():
-            async for post in posts:
+            async for post, score in posts:
                 # noinspection PyShadowingNames
                 async with self.state.read_lock() as state:
-                    yield await state.codec.decode(
+                    post = await state.codec.decode(
                         state.language_model_tokenizer, post
                     )
+                score = torch.tensor(score, dtype=torch.float)
+                yield post, score
 
         await self._fit_supervised(decoded())
 
-    @staticmethod
-    def _fit_reward_model_batch(
-        model: RewardModel, sequences: PackedSequence, scores: Tensor
-    ) -> float:
-        predicted = model(sequences)
-        loss = MSELoss()(predicted, scores)
-        loss.backward()
-        return loss.item()
-
-    @staticmethod
-    def _fit_with_reward_model_batch(
-        model: RewardModel, batch: GenerationResult
-    ) -> float:
-        with freeze(model) as frozen:
-            scores = frozen(batch.sequences)
-        loss = -(batch.logprobs * scores).mean()
-        loss.backward()
-        return scores.mean().item()
-
     async def _fit_with_reward_model(self) -> None:
         async with self.state.read_lock() as state:
             generated = state.backend_generator.generate(
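The hunk above introduces two per-batch objectives: _fit_reward_model_batch regresses the reward model's predictions onto the observed post scores with MSE, while _fit_with_reward_model_batch freezes the reward model, scores the generated sequences, and minimises -(logprobs * scores).mean(), so the language model raises the log-probability of highly scored outputs. A minimal sketch of both objectives on plain tensors; the linear reward model, feature shapes, and requires_grad freezing below are stand-ins for the SDK's RewardModel, PackedSequence batches, and freeze() helper:

import torch
from torch import nn

# Stand-ins for one batch: a toy reward model, per-sequence features,
# the scores attached to the posts, and per-sequence log-probabilities
# from the language model (the only tensor carrying a gradient in the
# second objective).
reward_model = nn.Linear(8, 1)
features = torch.randn(4, 8)
observed_scores = torch.rand(4, 1)
logprobs = -torch.rand(4, 1, requires_grad=True)

# Objective 1 (_fit_reward_model_batch): fit the reward model to the
# observed scores with MSE and backpropagate into its parameters.
predicted = reward_model(features)
reward_loss = nn.MSELoss()(predicted, observed_scores)
reward_loss.backward()

# Objective 2 (_fit_with_reward_model_batch): freeze the reward model,
# score the sequences, and minimise -(logprobs * scores).mean(); the
# gradient flows only through logprobs, not the frozen reward model.
for parameter in reward_model.parameters():
    parameter.requires_grad_(False)
scores = reward_model(features)
policy_loss = -(logprobs * scores).mean()
policy_loss.backward()

print(reward_loss.item(), policy_loss.item(), scores.mean().item())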
@@ -309,10 +325,13 @@ async def _fit_with_reward_model(self) -> None:
             except StopAsyncIteration:
                 break
             # TODO: recode
+            sequences = batch.sequences
+            logprobs = batch.logprobs
             score = await background(
                 self._fit_with_reward_model_batch,
                 state.reward_model,
-                batch,
+                sequences,
+                logprobs,
             )
             state.epoch_reward_model_scores.append(score)
 
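In the reworked _fit_supervised shown above, each batch item is a (sequence, score) pair: the sequences are packed for the language and reward models, and the scores are stacked into a column tensor used as the MSE targets. A small illustration of that batching step, with torch.nn.utils.rnn.pack_sequence standing in for the SDK's pack_list helper and made-up token ids:

import torch
from torch.nn.utils.rnn import pack_sequence

# Illustrative batch of decoded posts: each item pairs a token-id sequence
# with the score that fit_posts wrapped in a float tensor. Token ids and
# sequence lengths here are invented for the example.
batch = [
    (torch.tensor([1, 5, 2]), torch.tensor(0.9)),
    (torch.tensor([1, 7, 7, 2]), torch.tensor(0.1)),
]

# pack_sequence stands in for pack_list (an assumption); the packed
# sequences feed the models, and the stacked scores become the reward
# model's regression targets.
sequences = pack_sequence(
    [sequence for sequence, _ in batch], enforce_sorted=False
)
scores = torch.vstack([score for _, score in batch])

print(sequences.data.shape)  # torch.Size([7])
print(scores.shape)          # torch.Size([2, 1])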