Skip to content
This repository was archived by the owner on Apr 29, 2024. It is now read-only.

Commit e4bf0c8

Browse files
authored
Added freeze and fixed imports (#4)
Added freeze and fixed imports
1 parent 4813ba7 commit e4bf0c8

File tree

4 files changed

+21
-3
lines changed

4 files changed

+21
-3
lines changed

kilroy_module_pytorch_py_sdk/pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "kilroy-module-pytorch-py-sdk"
3-
version = "0.2.0"
3+
version = "0.2.1"
44
description = "SDK for kilroy modules using PyTorch 🧰"
55
readme = "README.md"
66
authors = ["kilroy <[email protected]>"]

kilroy_module_pytorch_py_sdk/src/kilroy_module_pytorch_py_sdk/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,14 @@
3030
EpsilonNucleusCategoricalSampler,
3131
)
3232
from kilroy_module_pytorch_py_sdk.utils import (
33+
truncate_first_element,
34+
truncate_last_element,
3335
pad,
3436
unpad,
3537
pack_padded,
3638
pack_list,
3739
unpack_to_padded,
3840
unpack_to_list,
3941
squash_packed,
42+
freeze,
4043
)

kilroy_module_pytorch_py_sdk/src/kilroy_module_pytorch_py_sdk/utils.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import weakref
2+
from contextlib import contextmanager
23
from typing import (
34
Any,
45
Generic,
@@ -11,7 +12,7 @@
1112
)
1213

1314
import torch
14-
from torch import Tensor
15+
from torch import Tensor, nn
1516
from torch.nn.utils.rnn import (
1617
PackedSequence,
1718
pack_padded_sequence,
@@ -88,3 +89,17 @@ def squash_packed(x, fn):
8889
return PackedSequence(
8990
fn(x.data), x.batch_sizes, x.sorted_indices, x.unsorted_indices
9091
)
92+
93+
94+
@contextmanager
95+
def freeze(model: nn.Module) -> nn.Module:
96+
original_state = {}
97+
98+
for name, param in model.named_parameters():
99+
original_state[name] = param.requires_grad
100+
param.requires_grad = False
101+
102+
yield model
103+
104+
for name, param in model.named_parameters():
105+
param.requires_grad = original_state[name]

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
[tool.poetry]
55
name = "kilroy-module-pytorch-py-sdk"
6-
version = "0.2.0"
6+
version = "0.2.1"
77
description = "SDK for kilroy modules using PyTorch 🧰"
88
readme = "kilroy_module_pytorch_py_sdk/README.md"
99
authors = ["kilroy <[email protected]>"]

0 commit comments

Comments
 (0)