Skip to content
This repository was archived by the owner on Apr 29, 2024. It is now read-only.

Commit bde77ce

Browse files
authored
Added support for saving state (#12)
1 parent dcf517a commit bde77ce

File tree

9 files changed

+104
-39
lines changed

9 files changed

+104
-39
lines changed

kilroy_module_pytorch_py_sdk/poetry.lock

+10-10
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

kilroy_module_pytorch_py_sdk/pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "kilroy-module-pytorch-py-sdk"
3-
version = "0.4.0"
3+
version = "0.5.0"
44
description = "SDK for kilroy modules using PyTorch 🧰"
55
readme = "README.md"
66
authors = ["kilroy <[email protected]>"]

kilroy_module_pytorch_py_sdk/src/kilroy_module_pytorch_py_sdk/generator/generator.py

+42-8
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1+
import json
12
import random
23
from dataclasses import dataclass
4+
from pathlib import Path
35
from typing import Any, AsyncIterable, Dict, Iterable, List, Set
46

57
from kilroy_module_server_py_sdk import (
68
CategorizableBasedParameter,
79
Configurable,
810
Parameter,
11+
Savable,
912
SerializableModel,
1013
classproperty,
1114
)
@@ -70,21 +73,52 @@ def parameters(cls) -> Set[Parameter]:
7073

7174
async def build_default_state(self) -> State:
7275
params = Params(**self._kwargs)
73-
sampler_cls = Sampler.for_category(params.sampler_type)
74-
sampler_params = params.samplers_params.get(params.sampler_type, {})
75-
if issubclass(sampler_cls, Configurable):
76-
sampler = await sampler_cls.build(**sampler_params)
77-
await sampler.init()
78-
else:
79-
sampler = sampler_cls(**sampler_params)
8076
return State(
81-
sampler=sampler,
77+
sampler=await self.build_generic(
78+
Sampler,
79+
category=params.sampler_type,
80+
**params.samplers_params.get(params.sampler_type, {}),
81+
),
8282
samplers_params=params.samplers_params,
8383
contexts=params.contexts,
8484
max_length=params.max_length,
8585
batch_size=params.batch_size,
8686
)
8787

88+
async def save_state(self, state: State, directory: Path) -> None:
89+
state_dict = {
90+
"sampler_type": state.sampler.category,
91+
"samplers_params": state.samplers_params,
92+
"contexts": state.contexts,
93+
"max_length": state.max_length,
94+
"batch_size": state.batch_size,
95+
}
96+
if isinstance(state.sampler, Savable):
97+
await state.sampler.save(directory / "sampler")
98+
with open(directory / "state.json", "w") as f:
99+
json.dump(state_dict, f)
100+
101+
async def load_saved_state(self, directory: Path) -> State:
102+
with open(directory / "state.json", "r") as f:
103+
state_dict = json.load(f)
104+
sampler_type = state_dict["sampler_type"]
105+
sampler_kwargs = {
106+
**self._kwargs.get("samplers_params", {}).get(sampler_type, {}),
107+
**state_dict["samplers_params"].get(sampler_type, {}),
108+
}
109+
return State(
110+
sampler=await self.load_generic(
111+
directory / "sampler",
112+
Sampler,
113+
category=state_dict["sampler_type"],
114+
**sampler_kwargs,
115+
),
116+
samplers_params=state_dict["samplers_params"],
117+
contexts=state_dict["contexts"],
118+
max_length=state_dict["max_length"],
119+
batch_size=state_dict["batch_size"],
120+
)
121+
88122
async def cleanup(self) -> None:
89123
async with self.state.write_lock() as state:
90124
if isinstance(state.sampler, Configurable):

kilroy_module_pytorch_py_sdk/src/kilroy_module_pytorch_py_sdk/modules/reward.py

+2-10
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,7 @@
11
from abc import ABC
22
from asyncio import Queue, Task
33
from dataclasses import dataclass
4-
from typing import (
5-
Any,
6-
AsyncIterable,
7-
Coroutine,
8-
Dict,
9-
Generator,
10-
List,
11-
Set,
12-
Tuple,
13-
)
4+
from typing import Any, AsyncIterable, Coroutine, Dict, List, Set, Tuple
145
from uuid import UUID, uuid4
156

167
import numpy as np
@@ -32,6 +23,7 @@
3223
from torch.nn import MSELoss, NLLLoss
3324
from torch.nn.utils.rnn import PackedSequence
3425

26+
from kilroy_module_pytorch_py_sdk import Generator
3527
from kilroy_module_pytorch_py_sdk.codec import Codec
3628
from kilroy_module_pytorch_py_sdk.models import LanguageModel, RewardModel
3729
from kilroy_module_pytorch_py_sdk.optimizers import Optimizer

kilroy_module_pytorch_py_sdk/src/kilroy_module_pytorch_py_sdk/optimizers/adam.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from dataclasses import dataclass
2+
from pathlib import Path
23
from typing import Any, Dict, List
34

5+
import torch
46
from kilroy_module_server_py_sdk import (
57
Configurable,
68
SerializableModel,
@@ -66,8 +68,21 @@ def schema(cls) -> Dict[str, Any]:
6668
return {"type": "number", "minimum": 0}
6769

6870
async def build_default_state(self) -> State:
71+
model_params = self._kwargs.pop("parameters")
6972
user_params = Params(**self._kwargs)
70-
return State(optimizer=Adam(self._params, **user_params.dict()))
73+
return State(optimizer=Adam(model_params, **user_params.dict()))
74+
75+
@classmethod
76+
async def save_state(cls, state: State, directory: Path) -> None:
77+
with open(directory / "optimizer.pt", "wb") as f:
78+
await background(torch.save, state.optimizer.state_dict(), f)
79+
80+
async def load_saved_state(self, directory: Path) -> State:
81+
with open(directory / "optimizer.pt", "rb") as f:
82+
state_dict = await background(torch.load, f)
83+
optimizer = Adam(self._kwargs.pop("parameters"))
84+
optimizer.load_state_dict(state_dict)
85+
return State(optimizer=optimizer)
7186

7287
async def step(self) -> None:
7388
async with self.state.write_lock() as state:

kilroy_module_pytorch_py_sdk/src/kilroy_module_pytorch_py_sdk/optimizers/base.py

-6
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,6 @@ async def undo() -> None:
5454

5555

5656
class Optimizer(Categorizable, ABC):
57-
def __init__(
58-
self, params: Union[Iterable[Tensor], Iterable[Dict]], *args, **kwargs
59-
) -> None:
60-
super().__init__(*args, **kwargs)
61-
self._params = list(params)
62-
6357
@classproperty
6458
def category(cls) -> str:
6559
name: str = cls.__name__

kilroy_module_pytorch_py_sdk/src/kilroy_module_pytorch_py_sdk/optimizers/rmsprop.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from dataclasses import dataclass
2+
from pathlib import Path
23
from typing import Any, Dict
34

5+
import torch
46
from kilroy_module_server_py_sdk import (
57
Configurable,
68
SerializableModel,
@@ -55,8 +57,21 @@ def schema(cls) -> Dict[str, Any]:
5557
return {"type": "number", "minimum": 0}
5658

5759
async def build_default_state(self) -> State:
60+
model_params = self._kwargs.pop("parameters")
5861
user_params = Params(**self._kwargs)
59-
return State(optimizer=RMSprop(self._params, **user_params.dict()))
62+
return State(optimizer=RMSprop(model_params, **user_params.dict()))
63+
64+
@classmethod
65+
async def save_state(cls, state: State, directory: Path) -> None:
66+
with open(directory / "optimizer.pt", "wb") as f:
67+
await background(torch.save, state.optimizer.state_dict(), f)
68+
69+
async def load_saved_state(self, directory: Path) -> State:
70+
with open(directory / "optimizer.pt", "rb") as f:
71+
state_dict = await background(torch.load, f)
72+
optimizer = RMSprop(self._kwargs.pop("parameters"))
73+
optimizer.load_state_dict(state_dict)
74+
return State(optimizer=optimizer)
6075

6176
async def step(self) -> None:
6277
async with self.state.write_lock() as state:

kilroy_module_pytorch_py_sdk/src/kilroy_module_pytorch_py_sdk/optimizers/sgd.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from dataclasses import dataclass
2+
from pathlib import Path
23
from typing import Any, Dict
34

5+
import torch
46
from kilroy_module_server_py_sdk import (
57
Configurable,
68
SerializableModel,
@@ -49,8 +51,21 @@ def schema(cls) -> Dict[str, Any]:
4951
return {"type": "number", "minimum": 0}
5052

5153
async def build_default_state(self) -> State:
54+
model_params = self._kwargs.pop("parameters")
5255
user_params = Params(**self._kwargs)
53-
return State(optimizer=SGD(self._params, **user_params.dict()))
56+
return State(optimizer=SGD(model_params, **user_params.dict()))
57+
58+
@classmethod
59+
async def save_state(cls, state: State, directory: Path) -> None:
60+
with open(directory / "optimizer.pt", "wb") as f:
61+
await background(torch.save, state.optimizer.state_dict(), f)
62+
63+
async def load_saved_state(self, directory: Path) -> State:
64+
with open(directory / "optimizer.pt", "rb") as f:
65+
state_dict = await background(torch.load, f)
66+
optimizer = SGD(self._kwargs.pop("parameters"))
67+
optimizer.load_state_dict(state_dict)
68+
return State(optimizer=optimizer)
5469

5570
async def step(self) -> None:
5671
async with self.state.write_lock() as state:

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
[tool.poetry]
55
name = "kilroy-module-pytorch-py-sdk"
6-
version = "0.4.0"
6+
version = "0.5.0"
77
description = "SDK for kilroy modules using PyTorch 🧰"
88
readme = "kilroy_module_pytorch_py_sdk/README.md"
99
authors = ["kilroy <[email protected]>"]

0 commit comments

Comments (0)