Skip to content

Commit 8a1fcca

Browse files
committed
add saving
1 parent d9b9169 commit 8a1fcca

File tree

2 files changed

+32
-4
lines changed

2 files changed

+32
-4
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@ Features:
1010
- array-typed to-the-teeth for dependable execution with `jaxtyping` and `beartype`.
1111

1212
To implement:
13-
- [ ] Guidance
13+
- [x] Guidance
1414
- [x] Denoising
1515
- [x] Mixed precision
1616
- [x] EMA
1717
- [x] AdaLayerNorm
1818
- [x] Class embedding
19-
- [ ] Hyperparameter/model saving
19+
- [x] Hyperparameter/model saving
2020
- [x] Uniform and Gaussian noise for dequantisation
2121

2222
#### Usage

transformer_flow.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import math
22
import dataclasses
33
from pathlib import Path
4+
import json
45
from shutil import rmtree
56
from copy import deepcopy
67
from typing import Tuple, List, Optional, Callable, Literal, Generator, Union
@@ -1443,6 +1444,22 @@ def _get_batch(perm, x, y):
14431444
end = start + batch_size
14441445

14451446

1447+
def save(filename: Optional[str] = None, *, hyperparams: ConfigDict, model: TransformerFlow) -> None:
    """Write hyperparameters and model weights to a single file.

    The file layout is one JSON line holding `hyperparams`, followed by the
    Equinox-serialised model leaves. `load` reads the same layout back.

    Args:
        filename: Destination path; defaults to `transformer_flow.eqx` in the
            current working directory.
        hyperparams: Constructor arguments for the model, stored as the header.
        model: The model whose array leaves are serialised after the header.
    """
    target = default(filename, Path.cwd() / "transformer_flow.eqx")
    # NOTE(review): assumes `hyperparams` is directly JSON-serialisable — if it
    # is an ml_collections.ConfigDict it may need dict(...) first; confirm.
    header = json.dumps(hyperparams) + "\n"
    with open(target, "wb") as f:
        f.write(header.encode())
        eqx.tree_serialise_leaves(f, model)
1453+
1454+
1455+
def load(filename: Optional[str] = None, *, hyperparams: ConfigDict) -> TransformerFlow:
    """Rebuild a model from a file written by `save`.

    Reads the JSON hyperparameter header, reconstructs a skeleton model from
    it, then fills in the serialised weights.

    Args:
        filename: Path to the saved file; defaults to `transformer_flow.eqx`
            in the current working directory.
        hyperparams: Fallback constructor arguments, used only if the file's
            header line is empty. The stored header takes precedence so the
            skeleton always matches the serialised weights.

    Returns:
        The deserialised TransformerFlow model.
    """
    filename = default(filename, Path.cwd() / "transformer_flow.eqx")
    with open(filename, "rb") as f:
        header = f.readline().decode()
        stored = json.loads(header) if header.strip() else dict(hyperparams)
        # Fix: make_with_state wraps the model *class*; the returned wrapper is
        # then called with the constructor arguments. The original passed the
        # kwargs straight to make_with_state, which raises a TypeError.
        model = eqx.nn.make_with_state(TransformerFlow)(key=jr.key(0), **stored)[0]
        return eqx.tree_deserialise_leaves(f, model)
1461+
1462+
14461463
@typecheck
14471464
def get_data(
14481465
key: PRNGKeyArray,
@@ -1634,7 +1651,8 @@ def train(
16341651
cmap: Optional[str] = None,
16351652
# Sharding: data and model
16361653
sharding: Optional[NamedSharding] = None,
1637-
replicated_sharding: Optional[NamedSharding] = None
1654+
replicated_sharding: Optional[NamedSharding] = None,
1655+
save_fn: Callable[[Optional[str], TransformerFlow], None]
16381656
) -> TransformerFlow:
16391657

16401658
print("n_params={:.3E}".format(count_parameters(model)))
@@ -1876,6 +1894,8 @@ def filter_spikes(l: list, loss_max: float = 10.0) -> list[float]:
18761894
plt.savefig(imgs_dir / "losses.png", bbox_inches="tight")
18771895
plt.close()
18781896

1897+
save_fn(model=ema_model if use_ema else model)
1898+
18791899
return model
18801900

18811901

@@ -1963,6 +1983,8 @@ def get_config(dataset_name: str) -> ConfigDict:
19631983

19641984
dataset_name = "MNIST"
19651985

1986+
reload_model = False
1987+
19661988
config = get_config(dataset_name)
19671989

19681990
imgs_dir = clear_and_get_results_dir(dataset_name)
@@ -1982,6 +2004,11 @@ def get_config(dataset_name: str) -> ConfigDict:
19822004
else:
19832005
policy = None
19842006

2007+
if reload_model:
2008+
model = load(hyperparams=config.model)
2009+
2010+
save_fn = partial(save, hyperparams=config.model)
2011+
19852012
model = train(
19862013
key_train,
19872014
# Model
@@ -2017,5 +2044,6 @@ def get_config(dataset_name: str) -> ConfigDict:
20172044
policy=policy,
20182045
get_state_fn=get_state_fn,
20192046
sharding=sharding,
2020-
replicated_sharding=replicated_sharding
2047+
replicated_sharding=replicated_sharding,
2048+
save_fn=save_fn
20212049
)

0 commit comments

Comments
 (0)