@@ -1,5 +1,5 @@
 import math
-from typing import Literal
+from typing import Any, Literal

 import chex
 from einops import einops
@@ -20,7 +20,7 @@ class FsqCodebook(nn.Module):
     _bins_per_dim: tuple[int] | None = None

     @property
-    def bins_per_dim(self):
+    def bins_per_dim(self) -> tuple[int]:
         if self._bins_per_dim is not None:
             return self._bins_per_dim

@@ -34,14 +34,14 @@ def bins_per_dim(self):
         raise ValueError(f"Codebook type {self.codebook_type} not supported.")

     @property
-    def place_values(self):
+    def place_values(self) -> jnp.ndarray:
         place_values = [1]
         for b in self.bins_per_dim[:-1]:
             place_values.append(place_values[-1] * b)
         return jnp.array(place_values)

     @staticmethod
-    def _get_bins_fsq(target_codebook_size):
+    def _get_bins_fsq(target_codebook_size: int) -> tuple[int]:
         """
         Get bins per dimension based on codebook size, from the original FSQ paper.
         """
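As a quick reference for the mixed-radix scheme that `place_values` sets up, here is a short worked example. The bins tuple below is hypothetical; the actual values are chosen by `_get_bins_fsq` for the target codebook size.

```python
# Hypothetical bins for illustration only; _get_bins_fsq picks the real values.
bins_per_dim = (8, 5, 5, 5)

place_values = [1]
for b in bins_per_dim[:-1]:
    place_values.append(place_values[-1] * b)

# place_values == [1, 8, 40, 200]: digits (d0, d1, d2, d3) map to the flat token id
# d0 + 8*d1 + 40*d2 + 200*d3, giving 8 * 5 * 5 * 5 = 1000 distinct tokens.
```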
@@ -59,7 +59,7 @@ def _get_bins_fsq(target_codebook_size):
         raise ValueError(f"Codebook size {target_codebook_size} not supported.")

     @staticmethod
-    def _get_bins_custom(target_codebook_size):
+    def _get_bins_custom(target_codebook_size: int) -> tuple[int]:
         if target_codebook_size == 2**8:
             return (16, 16)
         elif target_codebook_size == 2**10:  # noqa: RET505
@@ -73,7 +73,7 @@ def _get_bins_custom(target_codebook_size):
         return None

     @staticmethod
-    def _get_bins_lfq(target_codebook_size):
+    def _get_bins_lfq(target_codebook_size: int) -> tuple[int]:
         """
         Get bins per dimension according to the Lookup-Free Quantization paper (2 bins per dimension)
         """
@@ -85,12 +85,12 @@ def setup(self):
         self.proj_down = nn.Dense(len(self.bins_per_dim))
         self.proj_up = nn.Dense(self.input_dim)

-    def __call__(self, inputs):
+    def __call__(self, inputs: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
         tokens, z = self.encode(inputs)
         output = self.decode(tokens, z_grad=z)
         return tokens, output

-    def encode(self, inputs):
+    def encode(self, inputs: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
         bases = jnp.array(self.bins_per_dim)

         x = self.proj_down(inputs)
@@ -102,7 +102,7 @@ def encode(self, inputs):

         return tokens, z

-    def decode(self, tokens, z_grad: jax.Array | None = None):
+    def decode(self, tokens: jnp.ndarray, z_grad: jax.Array | None = None) -> jnp.ndarray:
         bases = jnp.array(self.bins_per_dim)
         digits = self.digitize(tokens)

@@ -114,14 +114,14 @@ def decode(self, tokens, z_grad: jax.Array | None = None):

         return self.proj_up(z_q)

-    def undigitize(self, digits):
+    def undigitize(self, digits: jnp.ndarray) -> jnp.ndarray:
         return jnp.sum(digits * jnp.array(self.place_values), axis=-1)

-    def digitize(self, tokens):
+    def digitize(self, tokens: jnp.ndarray) -> jnp.ndarray:
         return (tokens[..., None] // jnp.array(self.place_values)) % jnp.array(self.bins_per_dim)

     @property
-    def vocab_size(self):
+    def vocab_size(self) -> int:
         return math.prod(self.bins_per_dim)


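A minimal round-trip sketch of `digitize`/`undigitize`, using the same hypothetical bins as in the earlier example:

```python
import jax.numpy as jnp

# Hypothetical bins and their place values; the module derives both from the codebook size.
bins_per_dim = jnp.array([8, 5, 5, 5])
place_values = jnp.array([1, 8, 40, 200])

tokens = jnp.array([0, 123, 999])
digits = (tokens[..., None] // place_values) % bins_per_dim  # e.g. 123 -> [3, 0, 3, 0]
recovered = jnp.sum(digits * place_values, axis=-1)          # back to flat token ids
assert jnp.array_equal(recovered, tokens)
```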
@@ -132,7 +132,7 @@ class ResNetDownBlock(nn.Module):
     group_size: int = 32

     @nn.compact
-    def __call__(self, x, *, train=True):
+    def __call__(self, x: jnp.ndarray, *, train: bool = True) -> jnp.ndarray:
         skip = x

         if self.stride > 1 or x.shape[-1] != self.n_filters:
@@ -154,7 +154,7 @@ class ResNetUpBlock(nn.Module):
     group_size: int = 32

     @nn.compact
-    def __call__(self, x, *, train=True):
+    def __call__(self, x: jnp.ndarray, *, train: bool = True) -> jnp.ndarray:
         skip = x

         if self.stride > 1:
@@ -184,30 +184,29 @@ class LookupFreeQuantization(nn.Module):

     def setup(self):
         self.codebook = jnp.array([-1, 1])
-        # self.activation = lambda x: x
         self.activation = nn.tanh

         self.project_down = nn.Dense(self.num_dims)
         self.project_up = nn.Dense(self.latent_dim)

-    def encode(self, z):
+    def encode(self, z: jnp.ndarray) -> jnp.ndarray:
         z = self.project_down(z)
         token_squared_distances = jnp.square(z[..., None] - self.codebook)
         token_bits = jnp.argmin(token_squared_distances, axis=-1)
         return jnp.sum(token_bits * (2 ** jnp.arange(self.num_dims)), axis=-1)

-    def decode(self, tokens):
+    def decode(self, tokens: jnp.ndarray) -> jnp.ndarray:
         token_bits = (tokens[..., None] & (2 ** jnp.arange(self.num_dims))).astype(jnp.int32)
         return self.project_up(self.codebook[token_bits])

-    def loss(self, x):
+    def loss(self, x: jnp.ndarray) -> LfqCodebookOutput:
         z = self.project_down(x)
         z = self.activation(z)

         token_squared_distances = jnp.square(z[..., None] - self.codebook)
         tokens = jnp.argmin(token_squared_distances, axis=-1)

-        token_bit_log_probs = -token_squared_distances  # jax.nn.log_softmax(-token_squared_distances, axis=-1)
+        token_bit_log_probs = -token_squared_distances
         # Compute token log probs for tokens 0..2^num_dims-1 by summing corresponding log-probs
         token_bit_expansions = jnp.bitwise_and(
             jnp.arange(2**self.num_dims)[None, :], 2 ** jnp.arange(self.num_dims)[:, None]
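For reference, a self-contained sketch of the bit packing performed by `LookupFreeQuantization.encode`/`decode` in the hunk above, assuming a hypothetical `num_dims = 3`; the unpacking step here uses an explicit shift-and-mask so the bit arithmetic is easy to follow.

```python
import jax.numpy as jnp

num_dims = 3
codebook = jnp.array([-1, 1])

z = jnp.array([0.7, -0.3, 0.9])  # a projected latent vector
bits = jnp.argmin(jnp.square(z[..., None] - codebook), axis=-1)  # nearest entry per dim: [1, 0, 1]
token = jnp.sum(bits * (2 ** jnp.arange(num_dims)), axis=-1)     # 1*1 + 0*2 + 1*4 = 5

# Unpack the integer token back into per-dimension bits and codebook values.
bits_back = (token[..., None] >> jnp.arange(num_dims)) & 1       # [1, 0, 1]
values = codebook[bits_back]                                     # [ 1, -1,  1]
```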
@@ -236,7 +235,7 @@ def loss(self, x):
         )


-def make_block_causal_attention_matrix(q, k, bs_q, bs_k):
+def make_block_causal_attention_matrix(q: jnp.ndarray, k: jnp.ndarray, bs_q: int, bs_k: int) -> jnp.ndarray:
     return nn.make_attention_mask(q, k, pairwise_fn=lambda x, y: jnp.greater_equal(x // bs_k, y // bs_q))

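A small illustration of the mask this helper produces, assuming four query and four key positions with a block size of 2 on both sides (values chosen for the example only):

```python
import jax.numpy as jnp
import flax.linen as nn

q_pos = jnp.arange(4)
k_pos = jnp.arange(4)
mask = nn.make_attention_mask(
    q_pos, k_pos, pairwise_fn=lambda x, y: jnp.greater_equal(x // 2, y // 2)
)
# mask has shape (1, 4, 4): each 2-step query block attends to its own block
# and to earlier blocks, but not to later ones:
# [[1 1 0 0]
#  [1 1 0 0]
#  [1 1 1 1]
#  [1 1 1 1]]
```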
242241
@@ -245,14 +244,7 @@ class GeGLU(Module):
245244 GeGLU is a Flax layer that combines a linear transformation with a GELU
246245 activation function in a gating mechanism. It is often used in Transformer models
247246 to provide non-linear capabilities while preserving a strong linear component.
248- Example usage::
249- >>> import flax.linen as nn
250- >>> class TransformerBlock(nn.Module):
251- ... @nn.compact
252- ... def __call__(self, x):
253- ... x = nn.Dense(2)(x)
254- ... x = nn.GeGLU()(x) # initialized
255- ... return x
247+
256248 Attributes:
257249 features: the number of output features (default: None).
258250 """
@@ -281,7 +273,15 @@ class CrossAttentionLayer(nn.Module):
     mlp_ratio: float = 4.0

     @nn.compact
-    def __call__(self, x, y, *, mask_self=None, mask_cross=None, train=True):
+    def __call__(
+        self,
+        x: jnp.ndarray,
+        y: jnp.ndarray,
+        *,
+        mask_self: jnp.ndarray | None = None,
+        mask_cross: jnp.ndarray | None = None,
+        train: bool = True,
+    ) -> jnp.ndarray:
         d_embed = x.shape[-1]
         seq_len_q = x.shape[-2]
         seq_len_k = y.shape[-2]
@@ -307,12 +307,10 @@ def __call__(self, x, y, *, mask_self=None, mask_cross=None, train=True):
         # Cross-attention block
         skip = x
         x = nn.LayerNorm()(x)
-        # bias = -jnp.abs(jnp.linspace(0, 1, seq_len_q)[:, None] - jnp.linspace(0, 1, seq_len_k)) * 5
         x = nn.MultiHeadDotProductAttention(
             num_heads=self.num_heads or d_embed // 64,
             dropout_rate=self.dropout_rate,
             deterministic=not train,
-            # attention_fn=partial(nn.dot_product_attention, bias=bias),
         )(x, y, y, mask=mask_cross)
         x = skip + x

@@ -326,7 +324,7 @@ def __call__(self, x, y, *, mask_self=None, mask_cross=None, train=True):
         return skip + x


-def sinusoidal_pe_init(_, shape):
+def sinusoidal_pe_init(_, shape: tuple[int, int]) -> jnp.ndarray:
     seq_len, d_embed = shape

     position = jnp.arange(0, seq_len, 1)
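Most of `sinusoidal_pe_init` lies outside this hunk. For context, here is the standard sinusoidal table such an initializer typically builds; this is a sketch only, assuming an even `d_embed`, and the leading `_` in the real signature is the unused RNG key that `self.param` passes to initializers.

```python
import jax.numpy as jnp

def sinusoidal_table(seq_len: int, d_embed: int) -> jnp.ndarray:
    # Classic transformer positional encoding: sines on even columns, cosines on odd.
    position = jnp.arange(seq_len)[:, None]
    div_term = jnp.exp(jnp.arange(0, d_embed, 2) * (-jnp.log(10000.0) / d_embed))
    pe = jnp.zeros((seq_len, d_embed))
    pe = pe.at[:, 0::2].set(jnp.sin(position * div_term))
    pe = pe.at[:, 1::2].set(jnp.cos(position * div_term))
    return pe
```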
@@ -350,7 +348,14 @@ class TokenizerEncoderDecoder(nn.Module):
     use_state_conditioning: bool = False

     @nn.compact
-    def __call__(self, y, *, train=True, state_conditioning=None, mask=None):
+    def __call__(
+        self,
+        y: jnp.ndarray,
+        *,
+        train: bool = True,
+        state_conditioning: jnp.ndarray | None = None,
+        mask: jnp.ndarray | None = None,
+    ) -> jnp.ndarray:
         x = self.param("q_embed", sinusoidal_pe_init, (self.num_tokens, y.shape[-1]))
         x = jax.numpy.broadcast_to(x, y.shape[:-2] + x.shape[-2:])

@@ -392,7 +397,7 @@ class FsqAttentionTokenizer(nn.Module):
     use_state_conditioning: bool = False

     @property
-    def vocab_size(self):
+    def vocab_size(self) -> int:
         return math.prod(FsqCodebook._get_bins_fsq(self.target_codebook_size))  # noqa: SLF001

     def setup(self):
@@ -422,7 +427,9 @@ def setup(self):
         self.proj_mean = nn.Dense(self.data_dim)
         self.out_scale = self.param("out_scale", lambda _: jnp.full((), 1.0))

-    def tokenize(self, action, *, obs=None, train=False):
+    def tokenize(
+        self, action: jnp.ndarray, *, obs: jnp.ndarray | None = None, train: bool = False
+    ) -> tuple[jnp.ndarray, jnp.ndarray]:
         if self.bound is not None:
             action = jnp.clip(action, -self.bound, self.bound)

@@ -431,12 +438,14 @@ def tokenize(self, action, *, obs=None, train=False):

         return self.codebook.encode(x)

-    def detokenize(self, tokens, *, obs=None):
+    def detokenize(self, tokens: jnp.ndarray, *, obs: jnp.ndarray | None = None) -> jnp.ndarray:
         x = self.decoder(self.codebook.decode(tokens), state_conditioning=obs)
         mean = self.proj_mean(x)
         return mean * self.out_scale

-    def loss(self, action, *, obs=None, train=True):
+    def loss(
+        self, action: jnp.ndarray, *, obs: jnp.ndarray | None = None, train: bool = True
+    ) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]:
         # Encode
         x = self.proj(action)
         z = self.encoder(x, train=train, state_conditioning=obs)
@@ -456,7 +465,7 @@ def loss(self, action, *, obs=None, train=True):
             "mae": mae,
         }

-    def __call__(self, *args, **kwargs):
+    def __call__(self, *args: Any, **kwargs: Any) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]:
         """
         Dummy for .init
         """
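Finally, a hedged usage sketch for the annotated tokenizer API; constructor arguments and shapes are illustrative only, and the module may require additional fields or RNGs that are not visible in this diff.

```python
import jax
import jax.numpy as jnp

tokenizer = FsqAttentionTokenizer(target_codebook_size=2**10, data_dim=7)  # illustrative fields
actions = jnp.zeros((4, 7))  # dummy batch of actions

variables = tokenizer.init(jax.random.PRNGKey(0), actions)  # goes through the dummy __call__
tokens, _ = tokenizer.apply(variables, actions, method=FsqAttentionTokenizer.tokenize)
recon = tokenizer.apply(variables, tokens, method=FsqAttentionTokenizer.detokenize)
```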