1- from typing import Sequence , Optional , Union
1+ from typing import Sequence , Optional , Callable
22import jax
33import jax .numpy as jnp
44import jax .random as jr
55import einops
66import equinox as eqx
7- from jaxtyping import Key , Array
7+ from jaxtyping import Key , Array , Float , jaxtyped
8+ from beartype import beartype as typechecker
89
910
1011class AdaLayerNorm (eqx .Module ):
@@ -134,7 +135,11 @@ class Mixer2d(eqx.Module):
134135 t1 : float
135136 embedding_dim : int
136137 final_activation : callable
138+ img_size : Sequence [int ]
139+ q_dim : int
140+ a_dim : int
137141
142+ @jaxtyped (typechecker = typechecker )
138143 def __init__ (
139144 self ,
140145 img_size : Sequence [int ],
@@ -145,11 +150,11 @@ def __init__(
145150 num_blocks : int ,
146151 t1 : float ,
147152 embedding_dim : int = 8 ,
148- final_activation : Optional [Union [ callable , str ] ] = None ,
153+ final_activation : Optional [Callable | str ] = None ,
149154 q_dim : Optional [int ] = None ,
150155 a_dim : Optional [int ] = None ,
151156 * ,
152- key : Key
157+ key : Key [ jnp . ndarray , "..." ]
153158 ):
154159 """
155160 A 2D MLP Mixer model.
@@ -207,6 +212,10 @@ def __init__(
207212 _input_size = input_size + q_dim if q_dim is not None else input_size
208213 _context_dim = embedding_dim + a_dim if a_dim is not None else embedding_dim
209214
215+ self .img_size = img_size
216+ self .q_dim = q_dim
217+ self .a_dim = a_dim
218+
210219 self .conv_in = eqx .nn .Conv2d (
211220 _input_size ,
212221 hidden_size ,
@@ -237,15 +246,16 @@ def __init__(
237246 self .embedding_dim = embedding_dim
238247 self .final_activation = get_activation_fn (final_activation )
239248
249+ @jaxtyped (typechecker = typechecker )
240250 def __call__ (
241251 self ,
242- t : Union [ float , Array ],
243- y : Array ,
244- q : Optional [Array ] = None ,
245- a : Optional [Array ] = None ,
252+ t : float | Float [ Array , "" ],
253+ y : Float [ Array , "..." ] ,
254+ q : Optional [Float [ Array , "{self.q_dim} ..." ] ] = None ,
255+ a : Optional [Float [ Array , "{self.a_dim}" ] ] = None ,
246256 * ,
247- key : Optional [Key ] = None
248- ) -> Array :
257+ key : Optional [Key [ jnp . ndarray , "..." ] ] = None
258+ ) -> Float [ Array , "..." ] :
249259 _ , height , width = y .shape
250260 t = jnp .atleast_1d (t / self .t1 )
251261 t = get_timestep_embedding (t , embedding_dim = self .embedding_dim )
0 commit comments