@@ -63,7 +63,7 @@ def config(cls) -> Dict[str, Any]:
         return {
             "type": "line",
             "data": {"datasets": [{"data": []}]},
-            "options": {"parsing": {"xAxisKey": "epoch", "yAxisKey": "loss"}},
+            "options": {"parsing": {"xAxisKey": "step", "yAxisKey": "loss"}},
         }
 
 
@@ -81,7 +81,7 @@ def config(cls) -> Dict[str, Any]:
         return {
             "type": "line",
             "data": {"datasets": [{"data": []}]},
-            "options": {"parsing": {"xAxisKey": "epoch", "yAxisKey": "score"}},
+            "options": {"parsing": {"xAxisKey": "step", "yAxisKey": "score"}},
         }
 
 
@@ -99,7 +99,7 @@ def config(cls) -> Dict[str, Any]:
         return {
             "type": "line",
             "data": {"datasets": [{"data": []}]},
-            "options": {"parsing": {"xAxisKey": "epoch", "yAxisKey": "loss"}},
+            "options": {"parsing": {"xAxisKey": "step", "yAxisKey": "loss"}},
         }
 
 
@@ -117,7 +117,7 @@ def config(cls) -> Dict[str, Any]:
         return {
             "type": "line",
             "data": {"datasets": [{"data": []}]},
-            "options": {"parsing": {"xAxisKey": "epoch", "yAxisKey": "score"}},
+            "options": {"parsing": {"xAxisKey": "step", "yAxisKey": "score"}},
         }
 
 
@@ -151,10 +151,10 @@ class MetricsState:
 
 @dataclass
 class ReportsState:
-    epoch_supervised_losses: List[float]
-    epoch_reinforced_scores: List[float]
-    epoch_reward_model_losses: List[float]
-    epoch_reward_model_scores: List[float]
+    step_supervised_losses: List[float]
+    step_reinforced_scores: List[float]
+    step_reward_model_losses: List[float]
+    step_reward_model_scores: List[float]
 
 
 @dataclass
@@ -167,7 +167,7 @@ class State:
     results_cache: Dict[UUID, Tuple[Tensor, Tensor]]
     batch_size: int
     sample_size: int
-    epoch: int
+    step: int
     metrics: MetricsState
     reports: ReportsState
    coroutine_queue: Queue[Coroutine]
@@ -357,14 +357,14 @@ async def _fit_supervised(
             state.language_model.model,
             sequences,
         )
-        state.reports.epoch_supervised_losses.append(loss)
+        state.reports.step_supervised_losses.append(loss)
         loss = await background(
             self._fit_reward_model_batch,
             state.reward_model.model,
             sequences,
             scores,
         )
-        state.reports.epoch_reward_model_losses.append(loss)
+        state.reports.step_reward_model_losses.append(loss)
 
     async def fit_posts(
         self, posts: AsyncIterable[Tuple[Dict[str, Any], float]]
@@ -409,7 +409,7 @@ async def _fit_with_reward_model(self) -> None:
             sequences,
             logprobs,
         )
-        state.reports.epoch_reward_model_scores.append(score)
+        state.reports.step_reward_model_scores.append(score)
 
     async def _fit_reinforced(
         self,
@@ -429,8 +429,8 @@ async def _fit_reinforced(
             sequences,
             scores,
         )
-        state.reports.epoch_reward_model_losses.append(loss)
-        state.reports.epoch_reinforced_scores.append(
+        state.reports.step_reward_model_losses.append(loss)
+        state.reports.step_reinforced_scores.append(
             scores.mean().item()
         )
 
@@ -448,19 +448,19 @@ async def get_results():
         await self._fit_reinforced(get_results())
 
     @staticmethod
-    async def _report_mean_from_epoch(
-        metric: Metric, epoch: int, label: str, values: Iterable[float]
+    async def _report_mean_from_step(
+        metric: Metric, step: int, label: str, values: Iterable[float]
     ) -> None:
         values = list(values)
         if values:
-            await metric.report({"epoch": epoch, label: np.mean(values)})
+            await metric.report({"step": step, label: np.mean(values)})
 
     @staticmethod
     async def _reset_reports(state: State) -> None:
-        state.reports.epoch_supervised_losses = []
-        state.reports.epoch_reinforced_scores = []
-        state.reports.epoch_reward_model_losses = []
-        state.reports.epoch_reward_model_scores = []
+        state.reports.step_supervised_losses = []
+        state.reports.step_reinforced_scores = []
+        state.reports.step_reward_model_losses = []
+        state.reports.step_reward_model_scores = []
 
     async def step(self) -> None:
         async with self.state.write_lock() as state:
@@ -470,29 +470,29 @@ async def step(self) -> None:
             await state.reward_model.optimizer.step()
             if state.reward_model.scheduler is not None:
                 await state.reward_model.scheduler.step()
-            await self._report_mean_from_epoch(
+            await self._report_mean_from_step(
                 state.metrics.supervised_loss_metric,
-                state.epoch,
+                state.step,
                 "loss",
-                state.reports.epoch_supervised_losses,
+                state.reports.step_supervised_losses,
             )
-            await self._report_mean_from_epoch(
+            await self._report_mean_from_step(
                 state.metrics.reinforced_score_metric,
-                state.epoch,
+                state.step,
                 "score",
-                state.reports.epoch_reinforced_scores,
+                state.reports.step_reinforced_scores,
             )
-            await self._report_mean_from_epoch(
+            await self._report_mean_from_step(
                 state.metrics.reward_model_loss_metric,
-                state.epoch,
+                state.step,
                 "loss",
-                state.reports.epoch_reward_model_losses,
+                state.reports.step_reward_model_losses,
             )
-            await self._report_mean_from_epoch(
+            await self._report_mean_from_step(
                 state.metrics.reward_model_score_metric,
-                state.epoch,
+                state.step,
                 "score",
-                state.reports.epoch_reward_model_scores,
+                state.reports.step_reward_model_scores,
             )
             await self._reset_reports(state)
-            state.epoch += 1
+            state.step += 1
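
For illustration, a minimal self-contained sketch of the per-step mean-reporting pattern this diff renames. The Metric stub below is a hypothetical stand-in for the project's real Metric type, and the sample values are invented; only the skip-if-empty / report-the-mean logic mirrors _report_mean_from_step above.

import asyncio
from typing import Any, Dict, Iterable, List

import numpy as np


class Metric:
    # Hypothetical stand-in: collects reported payloads in memory.
    def __init__(self) -> None:
        self.reports: List[Dict[str, Any]] = []

    async def report(self, payload: Dict[str, Any]) -> None:
        self.reports.append(payload)


async def report_mean_from_step(
    metric: Metric, step: int, label: str, values: Iterable[float]
) -> None:
    # Mirrors _report_mean_from_step: report nothing for an empty batch,
    # otherwise report the mean of the values accumulated during this step.
    values = list(values)
    if values:
        await metric.report({"step": step, label: np.mean(values)})


async def main() -> None:
    metric = Metric()
    await report_mean_from_step(metric, 0, "loss", [1.0, 2.0, 3.0])
    await report_mean_from_step(metric, 1, "loss", [])  # empty: no report
    print(metric.reports)  # one entry: step 0 with loss 2.0


asyncio.run(main())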