@@ -1,6 +1,7 @@
 import math
 import dataclasses
 from pathlib import Path
+import json
 from shutil import rmtree
 from copy import deepcopy
 from typing import Tuple, List, Optional, Callable, Literal, Generator, Union
@@ -1443,6 +1444,22 @@ def _get_batch(perm, x, y):
     end = start + batch_size
 
 
+def save(filename: Optional[str] = None, *, hyperparams: ConfigDict, model: TransformerFlow) -> None:
+    filename = default(filename, Path.cwd() / "transformer_flow.eqx")
+    with open(filename, "wb") as f:
+        hyperparam_str = json.dumps(hyperparams)
+        f.write((hyperparam_str + "\n").encode())
+        eqx.tree_serialise_leaves(f, model)
+
+
+def load(filename: Optional[str] = None, *, hyperparams: ConfigDict) -> TransformerFlow:
+    filename = default(filename, Path.cwd() / "transformer_flow.eqx")
+    with open(filename, "rb") as f:
+        hyperparams = json.loads(f.readline().decode())
+        model = eqx.nn.make_with_state(TransformerFlow)(key=jr.key(0), **hyperparams)[0]
+        return eqx.tree_deserialise_leaves(f, model)
+
+
 @typecheck
 def get_data(
     key: PRNGKeyArray,
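A round-trip sketch for the `save`/`load` pair added above (an annotation, not part of the patch). It assumes the hyperparameters are a plain, JSON-serialisable dict of TransformerFlow constructor kwargs, so an ml_collections-style ConfigDict would need converting first; note also that whatever `hyperparams` is passed to `load` is overridden by the JSON header read back from the file:

    hyperparams = dict(config.model)   # assumes the config converts cleanly to a flat dict
    save("transformer_flow.eqx", hyperparams=hyperparams, model=model)
    restored = load("transformer_flow.eqx", hyperparams=hyperparams)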
@@ -1634,7 +1651,8 @@ def train(
     cmap: Optional[str] = None,
     # Sharding: data and model
     sharding: Optional[NamedSharding] = None,
-    replicated_sharding: Optional[NamedSharding] = None
+    replicated_sharding: Optional[NamedSharding] = None,
+    save_fn: Optional[Callable[..., None]] = None
 ) -> TransformerFlow:
 
     print("n_params={:.3E}".format(count_parameters(model)))
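Any callable that accepts the model by keyword satisfies the new `save_fn` parameter, and leaving it at `None` skips checkpointing; a hypothetical stand-in for dry runs could look like:

    def skip_checkpoint(filename=None, *, model):
        # hypothetical no-op save_fn: report instead of writing a checkpoint
        print(f"not saving {type(model).__name__}")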
@@ -1876,6 +1894,9 @@ def filter_spikes(l: list, loss_max: float = 10.0) -> list[float]:
     plt.savefig(imgs_dir / "losses.png", bbox_inches="tight")
     plt.close()
 
+    if save_fn is not None:
+        save_fn(model=ema_model if use_ema else model)
+
     return model
 
 
@@ -1963,6 +1983,8 @@ def get_config(dataset_name: str) -> ConfigDict:
 
 dataset_name = "MNIST"
 
+reload_model = False
+
 config = get_config(dataset_name)
 
 imgs_dir = clear_and_get_results_dir(dataset_name)
@@ -1982,6 +2004,11 @@ def get_config(dataset_name: str) -> ConfigDict:
 else:
     policy = None
 
+if reload_model:
+    model = load(hyperparams=config.model)
+
+save_fn = partial(save, hyperparams=config.model)
+
 model = train(
     key_train,
     # Model
@@ -2017,5 +2044,6 @@ def get_config(dataset_name: str) -> ConfigDict:
     policy=policy,
     get_state_fn=get_state_fn,
     sharding=sharding,
-    replicated_sharding=replicated_sharding
+    replicated_sharding=replicated_sharding,
+    save_fn=save_fn
 )
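Once the run finishes, the checkpoint written by `save_fn` can be reloaded in a later run by flipping `reload_model` above, or checked immediately as a quick sanity test (a sketch; `count_parameters` is the same helper `train` already uses):

    restored = load(hyperparams=config.model)
    assert count_parameters(restored) == count_parameters(model)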