Skip to content
This repository was archived by the owner on Apr 29, 2024. It is now read-only.

Commit f490608

Browse files
authored
Changed counting epochs to counting steps (#17)
1 parent aadcec0 commit f490608

File tree

4 files changed

+55
-55
lines changed

4 files changed

+55
-55
lines changed

kilroy_module_pytorch_py_sdk/pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "kilroy-module-pytorch-py-sdk"
3-
version = "0.6.0"
3+
version = "0.6.1"
44
description = "SDK for kilroy modules using PyTorch 🧰"
55
readme = "README.md"
66
authors = ["kilroy <[email protected]>"]

kilroy_module_pytorch_py_sdk/src/kilroy_module_pytorch_py_sdk/modules/basic.py

+19-19
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def config(cls) -> Dict[str, Any]:
5858
return {
5959
"type": "line",
6060
"data": {"datasets": [{"data": []}]},
61-
"options": {"parsing": {"xAxisKey": "epoch", "yAxisKey": "loss"}},
61+
"options": {"parsing": {"xAxisKey": "step", "yAxisKey": "loss"}},
6262
}
6363

6464

@@ -76,7 +76,7 @@ def config(cls) -> Dict[str, Any]:
7676
return {
7777
"type": "line",
7878
"data": {"datasets": [{"data": []}]},
79-
"options": {"parsing": {"xAxisKey": "epoch", "yAxisKey": "score"}},
79+
"options": {"parsing": {"xAxisKey": "step", "yAxisKey": "score"}},
8080
}
8181

8282

@@ -88,8 +88,8 @@ class MetricsState:
8888

8989
@dataclass
9090
class ReportsState:
91-
epoch_supervised_losses: List[float]
92-
epoch_reinforced_scores: List[float]
91+
step_supervised_losses: List[float]
92+
step_reinforced_scores: List[float]
9393

9494

9595
@dataclass
@@ -104,7 +104,7 @@ class State:
104104
codec: Codec
105105
results_cache: Dict[UUID, Tuple[Tensor, Tensor]]
106106
batch_size: int
107-
epoch: int
107+
step: int
108108
metrics: MetricsState
109109
reports: ReportsState
110110

@@ -205,7 +205,7 @@ def fit(model, batch):
205205
async for batch in streamer:
206206
async with self.state.write_lock() as state:
207207
loss = await background(fit, state.model, batch)
208-
state.reports.epoch_supervised_losses.append(loss)
208+
state.reports.step_supervised_losses.append(loss)
209209

210210
async def fit_posts(
211211
self, posts: AsyncIterable[Tuple[Dict[str, Any], float]]
@@ -233,7 +233,7 @@ def fit():
233233

234234
async with self.state.write_lock() as state:
235235
score = await background(fit)
236-
state.reports.epoch_reinforced_scores.append(score)
236+
state.reports.step_reinforced_scores.append(score)
237237

238238
async def fit_scores(self, scores: List[Tuple[UUID, float]]) -> None:
239239
async def get_results():
@@ -246,34 +246,34 @@ async def get_results():
246246
await self._fit_reinforced(get_results())
247247

248248
@staticmethod
249-
async def _report_mean_from_epoch(
250-
metric: Metric, epoch: int, label: str, values: Iterable[float]
249+
async def _report_mean_from_step(
250+
metric: Metric, step: int, label: str, values: Iterable[float]
251251
) -> None:
252252
values = list(values)
253253
if values:
254-
await metric.report({"epoch": epoch, label: np.mean(values)})
254+
await metric.report({"step": step, label: np.mean(values)})
255255

256256
@staticmethod
257257
async def _reset_reports(state: State) -> None:
258-
state.reports.epoch_supervised_losses = []
259-
state.reports.epoch_reinforced_scores = []
258+
state.reports.step_supervised_losses = []
259+
state.reports.step_reinforced_scores = []
260260

261261
async def step(self) -> None:
262262
async with self.state.write_lock() as state:
263263
await state.optimizer.step()
264264
if state.scheduler is not None:
265265
await state.scheduler.step()
266-
await self._report_mean_from_epoch(
266+
await self._report_mean_from_step(
267267
state.metrics.supervised_loss_metric,
268-
state.epoch,
268+
state.step,
269269
"loss",
270-
state.reports.epoch_supervised_losses,
270+
state.reports.step_supervised_losses,
271271
)
272-
await self._report_mean_from_epoch(
272+
await self._report_mean_from_step(
273273
state.metrics.reinforced_score_metric,
274-
state.epoch,
274+
state.step,
275275
"score",
276-
state.reports.epoch_reinforced_scores,
276+
state.reports.step_reinforced_scores,
277277
)
278278
await self._reset_reports(state)
279-
state.epoch += 1
279+
state.step += 1

kilroy_module_pytorch_py_sdk/src/kilroy_module_pytorch_py_sdk/modules/reward.py

+34-34
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def config(cls) -> Dict[str, Any]:
6363
return {
6464
"type": "line",
6565
"data": {"datasets": [{"data": []}]},
66-
"options": {"parsing": {"xAxisKey": "epoch", "yAxisKey": "loss"}},
66+
"options": {"parsing": {"xAxisKey": "step", "yAxisKey": "loss"}},
6767
}
6868

6969

@@ -81,7 +81,7 @@ def config(cls) -> Dict[str, Any]:
8181
return {
8282
"type": "line",
8383
"data": {"datasets": [{"data": []}]},
84-
"options": {"parsing": {"xAxisKey": "epoch", "yAxisKey": "score"}},
84+
"options": {"parsing": {"xAxisKey": "step", "yAxisKey": "score"}},
8585
}
8686

8787

@@ -99,7 +99,7 @@ def config(cls) -> Dict[str, Any]:
9999
return {
100100
"type": "line",
101101
"data": {"datasets": [{"data": []}]},
102-
"options": {"parsing": {"xAxisKey": "epoch", "yAxisKey": "loss"}},
102+
"options": {"parsing": {"xAxisKey": "step", "yAxisKey": "loss"}},
103103
}
104104

105105

@@ -117,7 +117,7 @@ def config(cls) -> Dict[str, Any]:
117117
return {
118118
"type": "line",
119119
"data": {"datasets": [{"data": []}]},
120-
"options": {"parsing": {"xAxisKey": "epoch", "yAxisKey": "score"}},
120+
"options": {"parsing": {"xAxisKey": "step", "yAxisKey": "score"}},
121121
}
122122

123123

@@ -151,10 +151,10 @@ class MetricsState:
151151

152152
@dataclass
153153
class ReportsState:
154-
epoch_supervised_losses: List[float]
155-
epoch_reinforced_scores: List[float]
156-
epoch_reward_model_losses: List[float]
157-
epoch_reward_model_scores: List[float]
154+
step_supervised_losses: List[float]
155+
step_reinforced_scores: List[float]
156+
step_reward_model_losses: List[float]
157+
step_reward_model_scores: List[float]
158158

159159

160160
@dataclass
@@ -167,7 +167,7 @@ class State:
167167
results_cache: Dict[UUID, Tuple[Tensor, Tensor]]
168168
batch_size: int
169169
sample_size: int
170-
epoch: int
170+
step: int
171171
metrics: MetricsState
172172
reports: ReportsState
173173
coroutine_queue: Queue[Coroutine]
@@ -357,14 +357,14 @@ async def _fit_supervised(
357357
state.language_model.model,
358358
sequences,
359359
)
360-
state.reports.epoch_supervised_losses.append(loss)
360+
state.reports.step_supervised_losses.append(loss)
361361
loss = await background(
362362
self._fit_reward_model_batch,
363363
state.reward_model.model,
364364
sequences,
365365
scores,
366366
)
367-
state.reports.epoch_reward_model_losses.append(loss)
367+
state.reports.step_reward_model_losses.append(loss)
368368

369369
async def fit_posts(
370370
self, posts: AsyncIterable[Tuple[Dict[str, Any], float]]
@@ -409,7 +409,7 @@ async def _fit_with_reward_model(self) -> None:
409409
sequences,
410410
logprobs,
411411
)
412-
state.reports.epoch_reward_model_scores.append(score)
412+
state.reports.step_reward_model_scores.append(score)
413413

414414
async def _fit_reinforced(
415415
self,
@@ -429,8 +429,8 @@ async def _fit_reinforced(
429429
sequences,
430430
scores,
431431
)
432-
state.reports.epoch_reward_model_losses.append(loss)
433-
state.reports.epoch_reinforced_scores.append(
432+
state.reports.step_reward_model_losses.append(loss)
433+
state.reports.step_reinforced_scores.append(
434434
scores.mean().item()
435435
)
436436

@@ -448,19 +448,19 @@ async def get_results():
448448
await self._fit_reinforced(get_results())
449449

450450
@staticmethod
451-
async def _report_mean_from_epoch(
452-
metric: Metric, epoch: int, label: str, values: Iterable[float]
451+
async def _report_mean_from_step(
452+
metric: Metric, step: int, label: str, values: Iterable[float]
453453
) -> None:
454454
values = list(values)
455455
if values:
456-
await metric.report({"epoch": epoch, label: np.mean(values)})
456+
await metric.report({"step": step, label: np.mean(values)})
457457

458458
@staticmethod
459459
async def _reset_reports(state: State) -> None:
460-
state.reports.epoch_supervised_losses = []
461-
state.reports.epoch_reinforced_scores = []
462-
state.reports.epoch_reward_model_losses = []
463-
state.reports.epoch_reward_model_scores = []
460+
state.reports.step_supervised_losses = []
461+
state.reports.step_reinforced_scores = []
462+
state.reports.step_reward_model_losses = []
463+
state.reports.step_reward_model_scores = []
464464

465465
async def step(self) -> None:
466466
async with self.state.write_lock() as state:
@@ -470,29 +470,29 @@ async def step(self) -> None:
470470
await state.reward_model.optimizer.step()
471471
if state.reward_model.scheduler is not None:
472472
await state.reward_model.scheduler.step()
473-
await self._report_mean_from_epoch(
473+
await self._report_mean_from_step(
474474
state.metrics.supervised_loss_metric,
475-
state.epoch,
475+
state.step,
476476
"loss",
477-
state.reports.epoch_supervised_losses,
477+
state.reports.step_supervised_losses,
478478
)
479-
await self._report_mean_from_epoch(
479+
await self._report_mean_from_step(
480480
state.metrics.reinforced_score_metric,
481-
state.epoch,
481+
state.step,
482482
"score",
483-
state.reports.epoch_reinforced_scores,
483+
state.reports.step_reinforced_scores,
484484
)
485-
await self._report_mean_from_epoch(
485+
await self._report_mean_from_step(
486486
state.metrics.reward_model_loss_metric,
487-
state.epoch,
487+
state.step,
488488
"loss",
489-
state.reports.epoch_reward_model_losses,
489+
state.reports.step_reward_model_losses,
490490
)
491-
await self._report_mean_from_epoch(
491+
await self._report_mean_from_step(
492492
state.metrics.reward_model_score_metric,
493-
state.epoch,
493+
state.step,
494494
"score",
495-
state.reports.epoch_reward_model_scores,
495+
state.reports.step_reward_model_scores,
496496
)
497497
await self._reset_reports(state)
498-
state.epoch += 1
498+
state.step += 1

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
[tool.poetry]
55
name = "kilroy-module-pytorch-py-sdk"
6-
version = "0.6.0"
6+
version = "0.6.1"
77
description = "SDK for kilroy modules using PyTorch 🧰"
88
readme = "kilroy_module_pytorch_py_sdk/README.md"
99
authors = ["kilroy <[email protected]>"]

0 commit comments

Comments
 (0)