Skip to content
This repository was archived by the owner on Apr 29, 2024. It is now read-only.

Commit aadcec0

Browse files
authored
Added LR schedulers (#16)
1 parent f3d31c9 commit aadcec0

File tree

15 files changed

+853
-21
lines changed

15 files changed

+853
-21
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,65 @@
1-
from kilroy_module_pytorch_py_sdk.resources import (
2-
resource,
3-
resource_bytes,
4-
resource_text,
5-
)
61
from kilroy_module_server_py_sdk import *
7-
from kilroy_module_pytorch_py_sdk.generator import Generator, GenerationResult
2+
from kilroy_module_pytorch_py_sdk.codec import Codec
3+
from kilroy_module_pytorch_py_sdk.generator import GenerationResult, Generator
4+
from kilroy_module_pytorch_py_sdk.models import LanguageModel, RewardModel
85
from kilroy_module_pytorch_py_sdk.modules.basic import (
96
BasicModule,
7+
MetricsState as BasicModuleMetricsState,
8+
ReportsState as BasicModuleReportsState,
109
State as BasicModuleState,
1110
)
1211
from kilroy_module_pytorch_py_sdk.modules.reward import (
12+
LanguageModelState as RewardModelModuleLanguageModelState,
13+
MetricsState as RewardModelModuleMetricsState,
14+
ReportsState as RewardModelModuleReportsState,
1315
RewardModelModule,
16+
RewardModelState as RewardModelModuleRewardModelState,
1417
State as RewardModelModuleState,
1518
)
1619
from kilroy_module_pytorch_py_sdk.optimizers import (
17-
Optimizer,
1820
AdamOptimizer,
19-
SGDOptimizer,
21+
Optimizer,
2022
RMSPropOptimizer,
23+
SGDOptimizer,
24+
)
25+
from kilroy_module_pytorch_py_sdk.resources import (
26+
resource,
27+
resource_bytes,
28+
resource_text,
2129
)
2230
from kilroy_module_pytorch_py_sdk.samplers import (
23-
Sampler,
24-
ProportionalSampler,
31+
EpsilonNucleusSampler,
2532
EpsilonProportionalSampler,
26-
TopKSampler,
2733
EpsilonTopKSampler,
2834
NucleusSampler,
29-
EpsilonNucleusSampler,
35+
ProportionalSampler,
36+
Sampler,
37+
TopKSampler,
38+
)
39+
from kilroy_module_pytorch_py_sdk.schedulers import (
40+
ConstantScheduler,
41+
CosineAnnealingScheduler,
42+
CyclicScheduler,
43+
ExponentialScheduler,
44+
LinearScheduler,
45+
MultiStepScheduler,
46+
OneCycleScheduler,
47+
ReduceOnPlateauScheduler,
48+
Scheduler,
49+
StepScheduler,
50+
WarmRestartsScheduler,
3051
)
31-
from kilroy_module_pytorch_py_sdk.codec import Codec
32-
from kilroy_module_pytorch_py_sdk.models import LanguageModel, RewardModel
3352
from kilroy_module_pytorch_py_sdk.tokenizer import Tokenizer
3453
from kilroy_module_pytorch_py_sdk.utils import (
54+
freeze,
55+
pack_list,
56+
pack_padded,
57+
pad,
3558
slice_sequences,
59+
squash_packed,
3660
truncate_first_element,
3761
truncate_last_element,
38-
pad,
39-
unpad,
40-
pack_padded,
41-
pack_list,
42-
unpack_to_padded,
4362
unpack_to_list,
44-
squash_packed,
45-
freeze,
63+
unpack_to_padded,
64+
unpad,
4665
)

kilroy_module_pytorch_py_sdk/src/kilroy_module_pytorch_py_sdk/modules/basic.py

+23
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import torch
1717
from aiostream import stream
1818
from kilroy_module_server_py_sdk import (
19+
CategorizableBasedOptionalParameter,
1920
CategorizableBasedParameter,
2021
JSONSchema,
2122
Metric,
@@ -33,6 +34,7 @@
3334
from kilroy_module_pytorch_py_sdk.generator import Generator
3435
from kilroy_module_pytorch_py_sdk.models import LanguageModel
3536
from kilroy_module_pytorch_py_sdk.optimizers import Optimizer
37+
from kilroy_module_pytorch_py_sdk.schedulers.base import Scheduler
3638
from kilroy_module_pytorch_py_sdk.tokenizer import Tokenizer
3739
from kilroy_module_pytorch_py_sdk.utils import (
3840
pack_list,
@@ -96,6 +98,8 @@ class State:
9698
tokenizer: Tokenizer
9799
optimizer: Optimizer
98100
optimizers_params: Dict[str, Dict[str, Any]]
101+
scheduler: Optional[Scheduler]
102+
schedulers_params: Dict[str, Dict[str, Any]]
99103
generator: Generator
100104
codec: Codec
101105
results_cache: Dict[UUID, Tuple[Tensor, Tensor]]
@@ -112,6 +116,22 @@ async def _get_params(self, state: State, category: str) -> Dict[str, Any]:
112116
**state.optimizers_params.get(category, {}),
113117
}
114118

119+
async def _set_categorizable(self, state: State, value: Optimizer) -> None:
120+
await super()._set_categorizable(state, value)
121+
if state.scheduler is not None:
122+
optimizer = await value.get()
123+
await state.scheduler.change_optimizer(optimizer)
124+
125+
126+
class SchedulerParameter(
127+
CategorizableBasedOptionalParameter[State, Scheduler]
128+
):
129+
async def _get_params(self, state: State, category: str) -> Dict[str, Any]:
130+
return {
131+
"optimizer": await state.optimizer.get(),
132+
**state.schedulers_params.get(category, {}),
133+
}
134+
115135

116136
class GeneratorParameter(NestedParameter[State, Generator]):
117137
pass
@@ -136,6 +156,7 @@ def post_schema(cls) -> JSONSchema:
136156
def parameters(cls) -> Set[Parameter]:
137157
return {
138158
OptimizerParameter(),
159+
SchedulerParameter(),
139160
GeneratorParameter(),
140161
CodecParameter(),
141162
BatchSizeParameter(),
@@ -240,6 +261,8 @@ async def _reset_reports(state: State) -> None:
240261
async def step(self) -> None:
241262
async with self.state.write_lock() as state:
242263
await state.optimizer.step()
264+
if state.scheduler is not None:
265+
await state.scheduler.step()
243266
await self._report_mean_from_epoch(
244267
state.metrics.supervised_loss_metric,
245268
state.epoch,

kilroy_module_pytorch_py_sdk/src/kilroy_module_pytorch_py_sdk/modules/reward.py

+44
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from aiostream import stream
2020
from aiostream.aiter_utils import aiter, anext
2121
from kilroy_module_server_py_sdk import (
22+
CategorizableBasedOptionalParameter,
2223
CategorizableBasedParameter,
2324
JSONSchema,
2425
Metric,
@@ -37,6 +38,7 @@
3738
from kilroy_module_pytorch_py_sdk.generator import Generator
3839
from kilroy_module_pytorch_py_sdk.models import LanguageModel, RewardModel
3940
from kilroy_module_pytorch_py_sdk.optimizers import Optimizer
41+
from kilroy_module_pytorch_py_sdk.schedulers.base import Scheduler
4042
from kilroy_module_pytorch_py_sdk.tokenizer import Tokenizer
4143
from kilroy_module_pytorch_py_sdk.utils import (
4244
freeze,
@@ -125,6 +127,8 @@ class LanguageModelState:
125127
tokenizer: Tokenizer
126128
optimizer: Optimizer
127129
optimizers_params: Dict[str, Dict[str, Any]]
130+
scheduler: Optional[Scheduler]
131+
schedulers_params: Dict[str, Dict[str, Any]]
128132

129133

130134
@dataclass
@@ -133,6 +137,8 @@ class RewardModelState:
133137
tokenizer: Tokenizer
134138
optimizer: Optimizer
135139
optimizers_params: Dict[str, Dict[str, Any]]
140+
scheduler: Optional[Scheduler]
141+
schedulers_params: Dict[str, Dict[str, Any]]
136142

137143

138144
@dataclass
@@ -177,6 +183,22 @@ async def _get_params(self, state: State, category: str) -> Dict[str, Any]:
177183
**state.language_model.optimizers_params.get(category, {}),
178184
}
179185

186+
async def _set_categorizable(self, state: State, value: Optimizer) -> None:
187+
await super()._set_categorizable(state, value)
188+
if state.language_model.scheduler is not None:
189+
optimizer = await value.get()
190+
await state.language_model.scheduler.change_optimizer(optimizer)
191+
192+
193+
class LanguageModelSchedulerParameter(
194+
CategorizableBasedOptionalParameter[State, Scheduler]
195+
):
196+
async def _get_params(self, state: State, category: str) -> Dict[str, Any]:
197+
return {
198+
"optimizer": await state.language_model.optimizer.get(),
199+
**state.language_model.schedulers_params.get(category, {}),
200+
}
201+
180202

181203
class RewardModelOptimizerParameter(
182204
CategorizableBasedParameter[State, Optimizer]
@@ -187,6 +209,22 @@ async def _get_params(self, state: State, category: str) -> Dict[str, Any]:
187209
**state.reward_model.optimizers_params.get(category, {}),
188210
}
189211

212+
async def _set_categorizable(self, state: State, value: Optimizer) -> None:
213+
await super()._set_categorizable(state, value)
214+
if state.reward_model.scheduler is not None:
215+
optimizer = await value.get()
216+
await state.reward_model.scheduler.change_optimizer(optimizer)
217+
218+
219+
class RewardModelSchedulerParameter(
220+
CategorizableBasedOptionalParameter[State, Scheduler]
221+
):
222+
async def _get_params(self, state: State, category: str) -> Dict[str, Any]:
223+
return {
224+
"optimizer": await state.reward_model.optimizer.get(),
225+
**state.reward_model.schedulers_params.get(category, {}),
226+
}
227+
190228

191229
class FrontendGeneratorParameter(NestedParameter[State, Generator]):
192230
pass
@@ -221,7 +259,9 @@ def post_schema(cls) -> JSONSchema:
221259
def parameters(cls) -> Set[Parameter]:
222260
return {
223261
LanguageModelOptimizerParameter(),
262+
LanguageModelSchedulerParameter(),
224263
RewardModelOptimizerParameter(),
264+
RewardModelSchedulerParameter(),
225265
FrontendGeneratorParameter(),
226266
BackendGeneratorParameter(),
227267
CodecParameter(),
@@ -425,7 +465,11 @@ async def _reset_reports(state: State) -> None:
425465
async def step(self) -> None:
426466
async with self.state.write_lock() as state:
427467
await state.language_model.optimizer.step()
468+
if state.language_model.scheduler is not None:
469+
await state.language_model.scheduler.step()
428470
await state.reward_model.optimizer.step()
471+
if state.reward_model.scheduler is not None:
472+
await state.reward_model.scheduler.step()
429473
await self._report_mean_from_epoch(
430474
state.metrics.supervised_loss_metric,
431475
state.epoch,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from kilroy_module_pytorch_py_sdk.schedulers.base import Scheduler
2+
from kilroy_module_pytorch_py_sdk.schedulers.constant import ConstantScheduler
3+
from kilroy_module_pytorch_py_sdk.schedulers.cosine import (
4+
CosineAnnealingScheduler,
5+
)
6+
from kilroy_module_pytorch_py_sdk.schedulers.cyclic import CyclicScheduler
7+
from kilroy_module_pytorch_py_sdk.schedulers.exponential import (
8+
ExponentialScheduler,
9+
)
10+
from kilroy_module_pytorch_py_sdk.schedulers.linear import LinearScheduler
11+
from kilroy_module_pytorch_py_sdk.schedulers.multistep import (
12+
MultiStepScheduler,
13+
)
14+
from kilroy_module_pytorch_py_sdk.schedulers.onecycle import OneCycleScheduler
15+
from kilroy_module_pytorch_py_sdk.schedulers.plateau import (
16+
ReduceOnPlateauScheduler,
17+
)
18+
from kilroy_module_pytorch_py_sdk.schedulers.step import StepScheduler
19+
from kilroy_module_pytorch_py_sdk.schedulers.warmrestarts import (
20+
WarmRestartsScheduler,
21+
)

0 commit comments

Comments (0)