Skip to content
This repository was archived by the owner on Apr 29, 2024. It is now read-only.

Commit dbf66a0

Browse files
authored
Update so huge that I won't even bother and will commit all at once (#30)
1 parent 1e6bdf1 commit dbf66a0

File tree

172 files changed

+11485
-2069
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

172 files changed

+11485
-2069
lines changed

kilroy_module_pytorch_py_sdk/poetry.lock

+401-308
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

kilroy_module_pytorch_py_sdk/pyproject.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "kilroy-module-pytorch-py-sdk"
3-
version = "0.7.2"
3+
version = "0.8.0"
44
description = "SDK for kilroy modules using PyTorch 🧰"
55
readme = "README.md"
66
authors = ["kilroy <[email protected]>"]
@@ -13,7 +13,7 @@ documentation = "https://kilroybot.github.io/kilroy-module-pytorch-py-sdk"
1313
python = "^3.10"
1414
torch = "~1"
1515
numpy = "~1"
16-
kilroy-module-server-py-sdk = "~0.9"
16+
kilroy-module-server-py-sdk = "~0.10"
1717

1818
[tool.poetry.group.poe.dependencies]
1919
poethepoet = "^0.16"
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,14 @@
11
from kilroy_module_server_py_sdk import *
2-
from kilroy_module_pytorch_py_sdk.codec import Codec
3-
from kilroy_module_pytorch_py_sdk.generator import GenerationResult, Generator
4-
from kilroy_module_pytorch_py_sdk.models import LanguageModel, RewardModel
5-
from kilroy_module_pytorch_py_sdk.modules.basic import (
6-
BasicModule,
7-
MetricsState as BasicModuleMetricsState,
8-
ReportsState as BasicModuleReportsState,
9-
State as BasicModuleState,
10-
)
11-
from kilroy_module_pytorch_py_sdk.modules.reward import (
12-
LanguageModelState as RewardModelModuleLanguageModelState,
13-
MetricsState as RewardModelModuleMetricsState,
14-
ReportsState as RewardModelModuleReportsState,
15-
RewardModelModule,
16-
RewardModelState as RewardModelModuleRewardModelState,
17-
State as RewardModelModuleState,
2+
from kilroy_module_pytorch_py_sdk.generator import Generator
3+
from kilroy_module_pytorch_py_sdk.models.abc import SequentialModel
4+
from kilroy_module_pytorch_py_sdk.models.loader import ModelLoader
5+
from kilroy_module_pytorch_py_sdk.models.registry import ModelsRegistry
6+
from kilroy_module_pytorch_py_sdk.module.module import PytorchModule
7+
from kilroy_module_pytorch_py_sdk.trainers import (
8+
Trainer,
9+
VanillaTrainer,
1810
)
11+
from kilroy_module_pytorch_py_sdk.metrics import LineMetric
1912
from kilroy_module_pytorch_py_sdk.optimizers import (
2013
AdamOptimizer,
2114
Optimizer,
@@ -27,15 +20,6 @@
2720
resource_bytes,
2821
resource_text,
2922
)
30-
from kilroy_module_pytorch_py_sdk.samplers import (
31-
EpsilonNucleusSampler,
32-
EpsilonProportionalSampler,
33-
EpsilonTopKSampler,
34-
NucleusSampler,
35-
ProportionalSampler,
36-
Sampler,
37-
TopKSampler,
38-
)
3923
from kilroy_module_pytorch_py_sdk.schedulers import (
4024
ConstantScheduler,
4125
CosineAnnealingScheduler,
@@ -51,7 +35,6 @@
5135
)
5236
from kilroy_module_pytorch_py_sdk.tokenizer import Tokenizer
5337
from kilroy_module_pytorch_py_sdk.utils import (
54-
freeze,
5538
pack_list,
5639
pack_padded,
5740
pad,
@@ -62,4 +45,5 @@
6245
unpack_to_list,
6346
unpack_to_padded,
6447
unpad,
48+
freeze,
6549
)

kilroy_module_pytorch_py_sdk/src/kilroy_module_pytorch_py_sdk/codec.py

-61
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
from typing import Dict, Any, Callable, Awaitable
2+
3+
import torch
4+
from pydantic import Field
5+
from torch.nn.utils.rnn import PackedSequence
6+
7+
from kilroy_module_pytorch_py_sdk.utils import unpack_to_list, pack_list
8+
from kilroy_module_server_py_sdk import SerializableState
9+
from kilroy_server_py_utils import Configurable, Parameter, classproperty
10+
11+
12+
class State(SerializableState):
13+
gamma: float = 0.99
14+
lambda_: float = Field(0.95, alias="lambda")
15+
16+
17+
class GeneralizedAdvantageEstimator(Configurable[State]):
18+
class GammaParameter(Parameter[State, float]):
19+
# noinspection PyMethodParameters
20+
@classproperty
21+
def schema(cls) -> Dict[str, Any]:
22+
return {
23+
"type": "number",
24+
"minimum": 0,
25+
"maximum": 1,
26+
"default": 0.99,
27+
"title": cls.pretty_name,
28+
}
29+
30+
class LambdaParameter(Parameter[State, float]):
31+
# noinspection PyMethodParameters
32+
@classproperty
33+
def schema(cls) -> Dict[str, Any]:
34+
return {
35+
"type": "number",
36+
"minimum": 0,
37+
"maximum": 1,
38+
"default": 0.95,
39+
"title": cls.pretty_name,
40+
}
41+
42+
@classmethod
43+
async def _get(cls, state: State) -> float:
44+
return state.lambda_
45+
46+
@classmethod
47+
async def _set(
48+
cls, state: State, value: float
49+
) -> Callable[[], Awaitable]:
50+
original_value = state.lambda_
51+
52+
async def undo():
53+
state.lambda_ = original_value
54+
55+
state.lambda_ = value
56+
return undo
57+
58+
async def calculate(
59+
self, rewards: PackedSequence, values: PackedSequence
60+
) -> PackedSequence:
61+
async with self.state.read_lock() as state:
62+
gamma = state.gamma
63+
lambda_ = state.lambda_
64+
65+
batch_rewards = unpack_to_list(rewards)
66+
batch_values = unpack_to_list(values)
67+
68+
batch_advantages = []
69+
70+
for rewards, values in zip(batch_rewards, batch_values):
71+
advantages = []
72+
advantage = 0
73+
74+
for i in reversed(range(len(rewards))):
75+
reward = rewards[i]
76+
value = values[i]
77+
next_value = (
78+
torch.zeros(1) if i == len(rewards) - 1 else values[i + 1]
79+
)
80+
81+
delta = reward + gamma * next_value - value
82+
advantage = delta + gamma * lambda_ * advantage
83+
advantages.append(advantage)
84+
85+
advantages = torch.stack(advantages[::-1])
86+
batch_advantages.append(advantages)
87+
88+
return pack_list(batch_advantages)
Original file line numberDiff line numberDiff line change
@@ -1,2 +1 @@
11
from kilroy_module_pytorch_py_sdk.generator.generator import Generator
2-
from kilroy_module_pytorch_py_sdk.generator.utils import GenerationResult

0 commit comments

Comments
 (0)