@@ -7,7 +7,6 @@
     Coroutine,
     Dict,
     Generator,
-    Iterable,
     List,
     Set,
     Tuple,
@@ -34,7 +33,6 @@
 from torch.nn.utils.rnn import PackedSequence
 
 from kilroy_module_pytorch_py_sdk.codec import Codec
-from kilroy_module_pytorch_py_sdk.generator import GenerationResult
 from kilroy_module_pytorch_py_sdk.models import LanguageModel, RewardModel
 from kilroy_module_pytorch_py_sdk.optimizers import Optimizer
 from kilroy_module_pytorch_py_sdk.tokenizer import Tokenizer
@@ -240,58 +238,76 @@ async def generate(
             yield post_id, post
 
     @staticmethod
-    def _fit_supervised_batch(
-        model: LanguageModel, batch: Iterable[Tensor]
+    def _fit_language_model_batch(
+        model: LanguageModel, sequences: PackedSequence
     ) -> float:
+        batch = unpack_to_list(sequences)
         input = pack_list(truncate_last_element(batch))
         target = pack_list(truncate_first_element(batch))
         logprobs = model(input)
         loss = NLLLoss()(logprobs.data, target.data.flatten())
         loss.backward()
         return loss.item()
 
-    async def _fit_supervised(self, data: AsyncIterable[Tensor]) -> None:
+    @staticmethod
+    def _fit_reward_model_batch(
+        model: RewardModel, sequences: PackedSequence, scores: Tensor
+    ) -> float:
+        predicted = model(sequences)
+        loss = MSELoss()(predicted, scores)
+        loss.backward()
+        return loss.item()
+
+    @staticmethod
+    def _fit_with_reward_model_batch(
+        model: RewardModel, sequences: PackedSequence, logprobs: Tensor
+    ) -> float:
+        with freeze(model) as frozen:
+            scores = frozen(sequences)
+        loss = -(logprobs * scores).mean()
+        loss.backward()
+        return scores.mean().item()
+
+    async def _fit_supervised(
+        self, data: AsyncIterable[Tuple[Tensor, Tensor]]
+    ) -> None:
         async with self.state.read_lock() as state:
             batches = stream.chunks(data, state.batch_size)
 
         async with batches.stream() as streamer:
             async for batch in streamer:
                 async with self.state.write_lock() as state:
+                    sequences = pack_list(sequence for sequence, _ in batch)
+                    scores = torch.vstack([score for _, score in batch])
                     loss = await background(
-                        self._fit_supervised_batch, state.language_model, batch
+                        self._fit_language_model_batch,
+                        state.language_model,
+                        sequences,
                     )
                     state.epoch_supervised_losses.append(loss)
+                    loss = await background(
+                        self._fit_reward_model_batch,
+                        state.reward_model,
+                        sequences,
+                        scores,
+                    )
+                    state.epoch_reward_model_losses.append(loss)
 
-    async def fit_posts(self, posts: AsyncIterable[Dict[str, Any]]) -> None:
+    async def fit_posts(
+        self, posts: AsyncIterable[Tuple[Dict[str, Any], float]]
+    ) -> None:
         async def decoded():
-            async for post in posts:
+            async for post, score in posts:
                 # noinspection PyShadowingNames
                 async with self.state.read_lock() as state:
-                    yield await state.codec.decode(
+                    post = await state.codec.decode(
                         state.language_model_tokenizer, post
                     )
+                score = torch.tensor(score, dtype=torch.float)
+                yield post, score
 
         await self._fit_supervised(decoded())
 
-    @staticmethod
-    def _fit_reward_model_batch(
-        model: RewardModel, sequences: PackedSequence, scores: Tensor
-    ) -> float:
-        predicted = model(sequences)
-        loss = MSELoss()(predicted, scores)
-        loss.backward()
-        return loss.item()
-
-    @staticmethod
-    def _fit_with_reward_model_batch(
-        model: RewardModel, batch: GenerationResult
-    ) -> float:
-        with freeze(model) as frozen:
-            scores = frozen(batch.sequences)
-        loss = -(batch.logprobs * scores).mean()
-        loss.backward()
-        return scores.mean().item()
-
     async def _fit_with_reward_model(self) -> None:
         async with self.state.read_lock() as state:
             generated = state.backend_generator.generate(
@@ -309,10 +325,13 @@ async def _fit_with_reward_model(self) -> None:
                 except StopAsyncIteration:
                     break
                 # TODO: recode
+                sequences = batch.sequences
+                logprobs = batch.logprobs
                 score = await background(
                     self._fit_with_reward_model_batch,
                     state.reward_model,
-                    batch,
+                    sequences,
+                    logprobs,
                 )
                 state.epoch_reward_model_scores.append(score)
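For illustration only (not part of this commit): after this change, fit_posts consumes (post, score) pairs rather than bare post dictionaries. A minimal caller sketch, where `module` and the post payloads are hypothetical stand-ins:

    async def feed_scored_posts(module) -> None:
        # Hypothetical usage sketch: each item pairs a post payload with a float
        # score, matching the new AsyncIterable[Tuple[Dict[str, Any], float]] type.
        async def scored_posts():
            yield {"text": "example post"}, 0.7
            yield {"text": "another post"}, -0.2

        await module.fit_posts(scored_posts())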