This repository was archived by the owner on Apr 29, 2024. It is now read-only.

Commit 4813ba7: Added utilities (#3)
1 parent d6ffc87

File tree: 9 files changed, +895 −3 lines


kilroy_module_pytorch_py_sdk/poetry.lock: +215 lines (generated lock file; diff not rendered)

kilroy_module_pytorch_py_sdk/pyproject.toml (+3 −1)

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "kilroy-module-pytorch-py-sdk"
-version = "0.1.0"
+version = "0.2.0"
 description = "SDK for kilroy modules using PyTorch 🧰"
 readme = "README.md"
 authors = ["kilroy <[email protected]>"]
@@ -11,6 +11,8 @@ documentation = "https://kilroybot.github.io/kilroy-module-pytorch-py-sdk"
 
 [tool.poetry.dependencies]
 python = "^3.9"
+torch = "~1"
+diskcache = "~5"
 
 # dev
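
The dependency additions above bring in torch (pinned to the 1.x line) and diskcache, a disk-backed, dict-like cache that presumably backs the Disk/TensorDisk cache classes exported by the package __init__ below. As a rough, illustrative sketch only (the directory path and key are made up, and the SDK's own wrappers are not shown in this diff), diskcache is used like a persistent mapping:

import torch
from diskcache import Cache

cache = Cache("/tmp/kilroy-example-cache")   # hypothetical directory
cache["weights"] = torch.zeros(3)            # values are pickled to disk
assert torch.equal(cache["weights"], torch.zeros(3))
del cache["weights"]                         # entries can be removed like dict items
cache.close()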

__init__.py (package exports, +40 −1)

@@ -1 +1,40 @@
-from kilroy_module_pytorch_py_sdk.resources import *
+from kilroy_module_pytorch_py_sdk.resources import (
+    resource,
+    resource_bytes,
+    resource_text,
+)
+from kilroy_module_pytorch_py_sdk.buffer import TensorBuffer
+from kilroy_module_pytorch_py_sdk.cache import (
+    Cache,
+    Disk,
+    CacheLike,
+    TensorDisk,
+    TensorCache,
+    ProgressiveCache,
+    ProgressiveTensorCache,
+    SelfCleaningCache,
+)
+from kilroy_module_pytorch_py_sdk.generator import (
+    GenerationResult,
+    SequenceGenerator,
+)
+from kilroy_module_pytorch_py_sdk.samplers import (
+    SampleResult,
+    Sampler,
+    ProportionalCategoricalSampler,
+    TopKCategoricalSampler,
+    NucleusCategoricalSampler,
+    WithEpsilon,
+    EpsilonProportionalCategoricalSampler,
+    EpsilonTopKCategoricalSampler,
+    EpsilonNucleusCategoricalSampler,
+)
+from kilroy_module_pytorch_py_sdk.utils import (
+    pad,
+    unpad,
+    pack_padded,
+    pack_list,
+    unpack_to_padded,
+    unpack_to_list,
+    squash_packed,
+)
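
The new utils exports (pad, unpad, pack_padded, pack_list, unpack_to_padded, unpack_to_list, squash_packed) deal with batches of variable-length sequences. Their signatures are not part of this diff, so the sketch below only illustrates the underlying round trip between a list of sequences, a PackedSequence, and a padded batch using plain PyTorch calls, not the SDK helpers themselves:

import torch
from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence

# Three sequences of different lengths.
sequences = [torch.randn(5), torch.randn(3), torch.randn(2)]

# Pack the list into a PackedSequence (roughly the job of a pack_list-style helper).
packed = pack_sequence(sequences, enforce_sorted=False)

# Recover a padded batch plus the original lengths
# (roughly the job of an unpack_to_padded-style helper).
padded, lengths = pad_packed_sequence(packed, batch_first=True)
print(padded.shape, lengths)  # torch.Size([3, 5]) tensor([5, 3, 2])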

buffer.py (new file, +41)

@@ -0,0 +1,41 @@
+from typing import Generic, MutableMapping, TypeVar
+from uuid import uuid4
+
+from torch import Tensor, nn
+from torch.autograd.graph import saved_tensors_hooks
+
+from kilroy_module_pytorch_py_sdk.utils import SelfCleaningKey
+
+StoreType = TypeVar("StoreType", bound=MutableMapping)
+ModelType = TypeVar("ModelType", bound=nn.Module)
+
+
+class TensorBuffer(Generic[StoreType]):
+    _store: StoreType
+
+    def __init__(self, store: StoreType) -> None:
+        self._store = store
+        self._hooks_ctx = saved_tensors_hooks(self.pack_hook, self.unpack_hook)
+
+    @property
+    def store(self) -> StoreType:
+        return self._store
+
+    def pack_hook(self, x: Tensor) -> SelfCleaningKey:
+        # TODO: remove SelfCleaningKey and just return UUID int
+        # on unpack_hook add this int to some set
+        # and make flush() method to delete all keys from this set
+        # this method should be called after backwards() is called
+        key = SelfCleaningKey(uuid4().hex, self._store)
+        self._store[key.key] = x.clone()
+        return key
+
+    def unpack_hook(self, x: SelfCleaningKey) -> Tensor:
+        return self._store[x.key]
+
+    def __enter__(self) -> "TensorBuffer":
+        self._hooks_ctx.__enter__()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback) -> None:
+        return self._hooks_ctx.__exit__(exc_type, exc_value, traceback)
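
TensorBuffer wraps torch.autograd.graph.saved_tensors_hooks so that tensors saved for the backward pass are cloned into an arbitrary MutableMapping (a plain dict, or a disk-backed store) and fetched back from it on unpacking. A minimal usage sketch with an in-memory dict; the shapes and names are illustrative only:

import torch
from kilroy_module_pytorch_py_sdk import TensorBuffer

store = {}  # any MutableMapping works as the backing store
x = torch.randn(4, 8, requires_grad=True)

# While the context is active, autograd routes saved tensors through pack_hook,
# which clones them into `store` and keeps only a key on the graph.
with TensorBuffer(store) as buffer:
    y = torch.sin(x).sum()  # sin saves its input for backward
    print(len(buffer.store))  # at least one saved tensor now lives in the store

# The unpack hook registered at save time pulls the tensor back out of the store.
y.backward()
print(x.grad.shape)  # torch.Size([4, 8])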

0 commit comments
