@@ -1,3 +1,4 @@
+import os
 import math
 import dataclasses
 from pathlib import Path
@@ -24,8 +25,10 @@
 from .attention import MultiheadAttention, self_attention


-typecheck = jaxtyped(typechecker=typechecker)
-
+if os.getenv("TYPECHECK", "").lower() in ["1", "true"]:
+    typecheck = jaxtyped(typechecker=typechecker)
+else:
+    typecheck = lambda _: _

 MetricsDict = dict[
     str, Union[Scalar, Float[Array, "..."]]
@@ -39,6 +42,8 @@

 NoiseType = Union[Literal["gaussian", "uniform"], None]

+CacheType = Literal["conditional", "unconditional"]  # Guidance caches
+
 MaskArray = Union[
     Float[Array, "s s"], Int[Array, "s s"], Bool[Array, "s s"]
 ]
@@ -47,7 +52,7 @@
     Union[Float[Array, "..."], Int[Array, "..."]]  # Flattened regardless
 ]

-OptState = PyTree | optax.OptState
+OptState = Union[PyTree, optax.OptState]


 def exists(v):
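Note on the new TYPECHECK toggle above: a minimal usage sketch, assuming beartype as the typechecker backend (the repository may use a different one) and a hypothetical `add` function. When the environment variable is unset, `typecheck` degrades to a no-op decorator.

    import os
    os.environ["TYPECHECK"] = "1"  # enable runtime shape/dtype checking

    import jax.numpy as jnp
    from jaxtyping import Array, Float, jaxtyped
    from beartype import beartype as typechecker  # assumption: beartype backend

    if os.getenv("TYPECHECK", "").lower() in ["1", "true"]:
        typecheck = jaxtyped(typechecker=typechecker)
    else:
        typecheck = lambda _: _  # no-op when type checking is disabled

    @typecheck
    def add(x: Float[Array, "n"], y: Float[Array, "n"]) -> Float[Array, "n"]:
        # Hypothetical function; mismatched shapes or dtypes fail the runtime check.
        return x + y

    add(jnp.ones(3), jnp.ones(3))          # passes
    # add(jnp.ones(3), jnp.ones((3, 3)))   # would fail the runtime check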
@@ -193,8 +198,10 @@ def shard_batch(
     Tuple[Float[Array, "n ..."], Float[Array, "n ..."]],
     Float[Array, "n ..."]
 ]:
+
     if sharding:
         batch = eqx.filter_shard(batch, sharding)
+
     return batch


@@ -205,13 +212,19 @@ def shard_model(
 ) -> Union[eqx.Module, Tuple[eqx.Module, OptState]]:
     if sharding:
         model = eqx.filter_shard(model, sharding)
+
         if opt_state:
+
             opt_state = eqx.filter_shard(opt_state, sharding)
+
             return model, opt_state
+
         return model
     else:
         if opt_state:
+
             return model, opt_state
+
         return model


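For context, a sketch of how the `sharding` argument consumed by `shard_batch`/`shard_model` above can be built. The single "data" mesh axis and the choice to replicate parameters are assumptions, not something this diff prescribes.

    import numpy as np
    import jax
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
    data_sharding = NamedSharding(mesh, PartitionSpec("data"))  # split the batch dim across devices
    model_sharding = NamedSharding(mesh, PartitionSpec())       # replicate model/optimiser state

    # batch = shard_batch(batch, sharding=data_sharding)
    # model, opt_state = shard_model(model, opt_state, sharding=model_sharding)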
@@ -419,7 +432,7 @@ def __call__(
     mask: Optional[Union[MaskArray, Literal["causal"]]],
     state: Optional[eqx.nn.State],
     *,
-    which_cache: Literal["cond", "uncond"],
+    which_cache: CacheType,
     attention_temperature: Optional[float] = 1.
 ) -> Tuple[
     Float[Array, "#s q"], Optional[eqx.nn.State]  # Autoregression
@@ -562,7 +575,7 @@ def __call__(
     ] = None,
     state: Optional[eqx.nn.State] = None,  # No state during forward pass
     *,
-    which_cache: Literal["cond", "uncond"] = "cond",
+    which_cache: CacheType = "conditional",
     attention_temperature: Optional[float] = 1.
 ) -> Union[
     Float[Array, "#{self.n_patches} {self.sequence_dim}"],
@@ -773,7 +786,7 @@ def reverse_step(
     s: Int[Array, ""],
     state: eqx.nn.State,
     *,
-    which_cache: Literal["cond", "uncond"] = "cond",
+    which_cache: CacheType = "conditional",
     attention_temperature: Optional[float] = 1.
 ) -> Tuple[
     Float[Array, "1 {self.sequence_dim}"],
@@ -834,7 +847,7 @@ def reverse(
     ],
     state: eqx.nn.State,
     *,
-    which_cache: Literal["cond", "uncond"] = "cond",
+    which_cache: CacheType = "conditional",
     guidance: float = 0.,
     attention_temperature: Optional[float] = 1.0,
     guide_what: Optional[Literal["ab", "a", "b"]] = "ab",
@@ -875,7 +888,7 @@ def _autoregression_step(
             pos_embed,
             s,
             state=state,
-            which_cache="uncond",
+            which_cache="unconditional",
             attention_temperature=attention_temperature,
         )

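The "cond"/"uncond" to "conditional"/"unconditional" rename above touches the two guidance caches. The exact combination rule lives outside this hunk; the sketch below is only an assumed classifier-free-guidance blend in which `guidance = 0.` reduces to the conditional prediction, matching the default in `reverse`.

    import jax.numpy as jnp

    def guided_prediction(v_conditional, v_unconditional, guidance=0.0):
        # Assumed blend: guidance = 0. returns the conditional prediction unchanged.
        return v_unconditional + (1.0 + guidance) * (v_conditional - v_unconditional)

    v = guided_prediction(jnp.ones(4), jnp.zeros(4), guidance=0.5)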
@@ -1155,6 +1168,7 @@ def reverse(

     def _block_step(z_s_sequence, params__state):
         z, s, sequence = z_s_sequence
+
         params, state = params__state
         block = eqx.combine(params, struct)

@@ -1173,6 +1187,7 @@ def _block_step(z_s_sequence, params__state):
             return (z, s + 1, sequence), None

         sequence = jnp.zeros((self.n_blocks + 1, self.n_channels, self.img_size, self.img_size), dtype=z.dtype)
+
         sequence = sequence.at[0].set(self.unpatchify(z))

         (z, _, sequence), _ = jax.lax.scan(
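The hunk above ends inside the `jax.lax.scan` call. For context, a self-contained sketch of the partition/combine-and-scan pattern `_block_step` relies on, with `eqx.nn.Linear` layers standing in for the model's flow blocks (an assumption for illustration only).

    import equinox as eqx
    import jax
    import jax.numpy as jnp
    import jax.random as jr

    keys = jr.split(jr.PRNGKey(0), 4)
    blocks = [eqx.nn.Linear(8, 8, key=k) for k in keys]

    # Stack each block's arrays along a new leading axis so scan can slice one block per step.
    params_list = [eqx.filter(b, eqx.is_array) for b in blocks]
    stacked_params = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *params_list)
    _, struct = eqx.partition(blocks[0], eqx.is_array)  # static structure shared by all blocks

    def step(z, params):
        block = eqx.combine(params, struct)  # rebuild one block from its scanned parameters
        return block(z), None

    z_final, _ = jax.lax.scan(step, jnp.ones((8,)), stacked_params)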
@@ -1512,6 +1527,7 @@ def train(
     dataset: Dataset,
     # Model
     model: TransformerFlow,
+    state: eqx.nn.State,
     eps_sigma: Optional[float],
     noise_type: NoiseType,
     # Data
@@ -1539,7 +1555,6 @@ def train(
     n_sample: Optional[int] = 4,
     n_warps: Optional[int] = 1,
     denoise_samples: bool = False,
-    get_state_fn: Callable[[None], eqx.nn.State] = None,
     cmap: Optional[str] = None,
     # Sharding: data and model
     sharding: Optional[NamedSharding] = None,
@@ -1550,7 +1565,7 @@ def train(

     print("n_params={:.3E}".format(count_parameters(model)))

-    key_data, valid_key, sample_key, *loader_keys = jr.split(key, 5)
+    valid_key, sample_key, *loader_keys = jr.split(key, 4)

     # Optimiser & scheduler
     n_steps_per_epoch = int((dataset.x_train.shape[0] + dataset.x_valid.shape[0]) / batch_size)
@@ -1691,7 +1706,7 @@ def train(
             ema_model if use_ema else model,
             z,
             y,
-            state=get_state_fn(),
+            state=state,
             guidance=guidance,
             denoise_samples=denoise_samples,
             sharding=sharding,
@@ -1722,7 +1737,7 @@ def train(
             ema_model if use_ema else model,
             z,
             y,
-            state=get_state_fn(),
+            state=state,
             return_sequence=True,
             denoise_samples=denoise_samples,
             sharding=sharding,
@@ -1744,7 +1759,6 @@ def train(
             plt.savefig(imgs_dir / "warps/warps_{:05d}.png".format(i), bbox_inches="tight")
             plt.close()

-
         # Losses and metrics
         if i > 0:

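With `get_state_fn` removed, `train` now takes the `eqx.nn.State` explicitly and threads it through the sampling calls. A minimal sketch of producing that (model, state) pair on the caller side via `eqx.nn.make_with_state`; the `Counter` module is a stand-in, since `TransformerFlow`'s constructor is not shown in this diff.

    import equinox as eqx
    import jax.numpy as jnp

    class Counter(eqx.Module):
        index: eqx.nn.StateIndex

        def __init__(self):
            self.index = eqx.nn.StateIndex(jnp.array(0))  # stateful leaf, e.g. a cache slot

    model, state = eqx.nn.make_with_state(Counter)()
    # train(dataset, model, state, ...)  # hypothetical call with the new signature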