2222import matplotlib .pyplot as plt
2323from tqdm .auto import trange
2424
25- from .attention import MultiheadAttention , self_attention
25+ from .attention import MultiheadAttention , self_attention , KQVCacheType
2626
2727
2828if os .getenv ("TYPECHECK" , "" ).lower () in ["1" , "true" ]:
4242
4343NoiseType = Union [Literal ["gaussian" , "uniform" ], None ]
4444
45- CacheType = Literal ["conditional" , "unconditional" ] # Guidance caches
46-
4745MaskArray = Union [
4846 Float [Array , "s s" ], Int [Array , "s s" ], Bool [Array , "s s" ]
4947]
@@ -123,23 +121,27 @@ def clear_and_get_results_dir(
123121 run_dir : Optional [Path ] = None ,
124122 clear_old : bool = False
125123) -> Path :
124+
126125 if not exists (run_dir ):
127126 run_dir = Path .cwd ()
128127
129128 # Image save directories
130129 imgs_dir = run_dir / "imgs" / dataset_name .lower ()
131130
131+ # Clear old ones
132132 if clear_old :
133- rmtree (str (imgs_dir ), ignore_errors = True ) # Clear old ones
133+ rmtree (str (imgs_dir ), ignore_errors = True )
134134
135135 if not imgs_dir .exists ():
136-
137136 imgs_dir .mkdir (exist_ok = True , parents = True )
138137
139- for _dir in ["samples" , "warps" , "latents" ]:
140- (imgs_dir / _dir ).mkdir (exist_ok = True )
138+ # Image type directories
139+ for _dir in ["samples" , "warps" , "latents" ]:
140+ (imgs_dir / _dir ).mkdir (exist_ok = True , parents = True )
141141
142- return imgs_dir
142+ print ("Saving samples in:\n \t " , imgs_dir )
143+
144+ return imgs_dir
143145
144146
145147def count_parameters (model : eqx .Module ) -> int :
@@ -289,6 +291,7 @@ def __init__(
289291 key : PRNGKeyArray
290292 ):
291293 key_weight , key_bias = jr .split (key )
294+
292295 l = math .sqrt (1. / in_size )
293296 dtype = default (dtype , jnp .float32 )
294297
@@ -432,7 +435,7 @@ def __call__(
432435 mask : Optional [Union [MaskArray , Literal ["causal" ]]],
433436 state : Optional [eqx .nn .State ],
434437 * ,
435- which_cache : CacheType ,
438+ which_cache : KQVCacheType ,
436439 attention_temperature : Optional [float ] = 1.
437440 ) -> Tuple [
438441 Float [Array , "#s q" ], Optional [eqx .nn .State ] # Autoregression
@@ -480,6 +483,7 @@ def __init__(
480483 key : PRNGKeyArray
481484 ):
482485 keys = jr .split (key , 3 )
486+
483487 self .y_dim = y_dim
484488 self .conditioning_type = conditioning_type
485489
@@ -575,7 +579,7 @@ def __call__(
575579 ] = None ,
576580 state : Optional [eqx .nn .State ] = None , # No state during forward pass
577581 * ,
578- which_cache : CacheType = "conditional" ,
582+ which_cache : KQVCacheType = "conditional" ,
579583 attention_temperature : Optional [float ] = 1.
580584 ) -> Union [
581585 Float [Array , "#{self.n_patches} {self.sequence_dim}" ],
@@ -613,9 +617,10 @@ def __init__(
613617 sequence_length : int
614618 ):
615619 self .permute = permute # Flip if true else pass
616- assert jnp .isscalar (self .permute )
617620 self .sequence_length = sequence_length
618621
622+ assert jnp .isscalar (self .permute )
623+
619624 @property
620625 def permute_idx (self ):
621626 permute = maybe_stop_grad (self .permute , stop = True )
@@ -786,7 +791,7 @@ def reverse_step(
786791 s : Int [Array , "" ],
787792 state : eqx .nn .State ,
788793 * ,
789- which_cache : CacheType = "conditional" ,
794+ which_cache : KQVCacheType = "conditional" ,
790795 attention_temperature : Optional [float ] = 1.
791796 ) -> Tuple [
792797 Float [Array , "1 {self.sequence_dim}" ],
@@ -847,7 +852,7 @@ def reverse(
847852 ],
848853 state : eqx .nn .State ,
849854 * ,
850- which_cache : CacheType = "conditional" ,
855+ which_cache : KQVCacheType = "conditional" ,
851856 guidance : float = 0. ,
852857 attention_temperature : Optional [float ] = 1.0 ,
853858 guide_what : Optional [Literal ["ab" , "a" , "b" ]] = "ab" ,
@@ -942,7 +947,7 @@ class TransformerFlow(eqx.Module):
942947 @typecheck
943948 def __init__ (
944949 self ,
945- in_channels : int ,
950+ n_channels : int ,
946951 img_size : int ,
947952 patch_size : int ,
948953 channels : int ,
@@ -958,11 +963,11 @@ def __init__(
958963 key : PRNGKeyArray
959964 ):
960965 self .img_size = img_size
961- self .n_channels = in_channels
966+ self .n_channels = n_channels
962967
963968 self .patch_size = patch_size
964969 self .n_patches = int (img_size / patch_size ) ** 2
965- self .sequence_dim = in_channels * patch_size ** 2
970+ self .sequence_dim = n_channels * patch_size ** 2
966971 self .n_blocks = n_blocks
967972
968973 self .y_dim = y_dim
@@ -1785,6 +1790,7 @@ def filter_spikes(l: list, loss_max: float = 10.0) -> list[float]:
17851790 plt .savefig (imgs_dir / "losses.png" , bbox_inches = "tight" )
17861791 plt .close ()
17871792
1788- save_fn (model = ema_model if use_ema else model )
1793+ if exists (save_fn ):
1794+ save_fn (model = ema_model if use_ema else model )
17891795
17901796 return model
0 commit comments