@@ -83,105 +83,50 @@ class TransformerEncoderLayer(eqx.Module):
8383 nhead : int
8484 mlp_hidden_size : int
8585
86- # Self-attention layers
87- self_attn_datapoints_q : eqx .nn .Linear
88- self_attn_datapoints_k : eqx .nn .Linear
89- self_attn_datapoints_v : eqx .nn .Linear
90- self_attn_datapoints_out : eqx .nn .Linear
91-
92- self_attn_features_q : eqx .nn .Linear
93- self_attn_features_k : eqx .nn .Linear
94- self_attn_features_v : eqx .nn .Linear
95- self_attn_features_out : eqx .nn .Linear
96-
97- # MLP layers
86+ self_attn_features : eqx .nn .MultiheadAttention
87+ self_attn_datapoints : eqx .nn .MultiheadAttention
88+
9889 linear1 : eqx .nn .Linear
9990 linear2 : eqx .nn .Linear
10091
101- # Layer norms
10292 norm1 : eqx .nn .LayerNorm
10393 norm2 : eqx .nn .LayerNorm
10494 norm3 : eqx .nn .LayerNorm
10595
def __init__(
    self,
    embedding_size: int,
    nhead: int,
    mlp_hidden_size: int,
    *,
    key: PRNGKeyArray,
) -> None:
    """Build the two attention modules, the two-layer MLP and the layer norms.

    Args:
        embedding_size: Used as the attention ``query_size``, as the MLP
            input/output width and as the shape of every layer norm.
        nhead: Number of heads passed to both attention modules.
        mlp_hidden_size: Width of the hidden layer between ``linear1`` and
            ``linear2``.
        key: PRNG key that is split to initialise the parameterised
            sub-modules.
    """
    # Five subkeys are drawn although only four are consumed; the split
    # count is left untouched so the derived subkeys stay the same.
    features_key, datapoints_key, linear1_key, linear2_key, _ = jax.random.split(key, 5)

    self.embedding_size = embedding_size
    self.nhead = nhead
    self.mlp_hidden_size = mlp_hidden_size

    # Both attention modules share one configuration and differ only in
    # the PRNG key used to initialise them.
    attention_config = {
        "num_heads": nhead,
        "query_size": embedding_size,
        "use_query_bias": True,
        "use_key_bias": True,
        "use_value_bias": True,
        "use_output_bias": True,
    }
    self.self_attn_features = eqx.nn.MultiheadAttention(**attention_config, key=features_key)
    self.self_attn_datapoints = eqx.nn.MultiheadAttention(**attention_config, key=datapoints_key)

    # Feed-forward block: embedding -> hidden -> embedding.
    self.linear1 = eqx.nn.Linear(embedding_size, mlp_hidden_size, key=linear1_key)
    self.linear2 = eqx.nn.Linear(mlp_hidden_size, embedding_size, key=linear2_key)

    # Three independent layer norms over the embedding dimension.
    self.norm1 = eqx.nn.LayerNorm(embedding_size)
    self.norm2 = eqx.nn.LayerNorm(embedding_size)
    self.norm3 = eqx.nn.LayerNorm(embedding_size)
133129
134- def _multihead_attention_features (
135- self ,
136- query : Float [Array , "seq_len embed_dim" ],
137- key : Float [Array , "seq_len embed_dim" ],
138- value : Float [Array , "seq_len embed_dim" ],
139- ) -> Float [Array , "seq_len embed_dim" ]:
140- """Compute multi-head attention for features."""
141- seq_len , embed_dim = query .shape
142- head_dim = embed_dim // self .nhead
143-
144- q = jax .vmap (self .self_attn_features_q )(query )
145- k = jax .vmap (self .self_attn_features_k )(key )
146- v = jax .vmap (self .self_attn_features_v )(value )
147-
148- # Reshape for multi-head: (seq_len, nhead, head_dim)
149- q = q .reshape (seq_len , self .nhead , head_dim )
150- k = k .reshape (key .shape [0 ], self .nhead , head_dim )
151- v = v .reshape (value .shape [0 ], self .nhead , head_dim )
152-
153- attn_out = jax .nn .dot_product_attention (q , k , v , mask = None , implementation = "xla" )
154-
155- attn_out = attn_out .reshape (seq_len , embed_dim )
156-
157- return jax .vmap (self .self_attn_features_out )(attn_out )
158-
159- def _multihead_attention_datapoints (
160- self ,
161- query : Float [Array , "seq_len embed_dim" ],
162- key : Float [Array , "seq_len embed_dim" ],
163- value : Float [Array , "seq_len embed_dim" ],
164- mask : Float [Array , "..." ] | None = None ,
165- ) -> Float [Array , "seq_len embed_dim" ]:
166- """Compute multi-head attention for datapoints."""
167- seq_len , embed_dim = query .shape
168- head_dim = embed_dim // self .nhead
169-
170- q = jax .vmap (self .self_attn_datapoints_q )(query )
171- k = jax .vmap (self .self_attn_datapoints_k )(key )
172- v = jax .vmap (self .self_attn_datapoints_v )(value )
173-
174- # Reshape for multi-head: (seq_len, nhead, head_dim)
175- q = q .reshape (seq_len , self .nhead , head_dim )
176- k = k .reshape (key .shape [0 ], self .nhead , head_dim )
177- v = v .reshape (value .shape [0 ], self .nhead , head_dim )
178-
179- attn_out = jax .nn .dot_product_attention (q , k , v , mask = mask , implementation = "xla" )
180-
181- attn_out = attn_out .reshape (seq_len , embed_dim )
182-
183- return jax .vmap (self .self_attn_datapoints_out )(attn_out )
184-
185130 def __call__ (
186131 self ,
187132 src : Float [Array , "num_rows num_features_plus_target embedding_size" ],
@@ -200,16 +145,16 @@ def __call__(
200145 Returns:
201146 (num_rows, num_features+1, embedding_size)
202147 """
203- src_features = jax .vmap (self ._multihead_attention_features )(src , src , src ) + src
148+ src_features = jax .vmap (self .self_attn_features )(src , src , src ) + src
204149 src = jax .vmap (jax .vmap (self .norm1 ))(src_features )
205150
206151 src = jnp .transpose (src , (1 , 0 , 2 ))
207152
208- mask = train_mask [None , :] # (1, rows_size) - broadcasts to (nhead, rows, rows)
153+ num_rows = src .shape [1 ]
154+ mask = jnp .broadcast_to (train_mask , (num_rows , num_rows ))
209155
210- mha = partial (self ._multihead_attention_datapoints , mask = mask )
211- src_attended = jax .vmap (mha )(src , src , src )
212- src = src_attended + src
156+ masked_mha = partial (self .self_attn_datapoints , mask = mask )
157+ src = jax .vmap (masked_mha )(src , src , src ) + src
213158
214159 src = jnp .transpose (src , (1 , 0 , 2 ))
215160
@@ -291,7 +236,6 @@ def __call__(
291236 Returns:
292237 logits of shape (test_size, num_outputs) for test datapoints only
293238 """
294- # Ensure y_src has the right shape
295239 if len (y_src .shape ) < len (x_src .shape ):
296240 y_src = y_src [..., None ]
297241
@@ -305,7 +249,6 @@ def __call__(
305249
306250 output = self .decoder (src [:, - 1 , :])
307251
308- # Mask out train predictions, keep only test predictions
309252 test_mask = (~ train_mask )[:, None ] # (num_rows, 1)
310253 output = output * test_mask
311254
@@ -346,27 +289,22 @@ def predict_proba(self, X_test: np.ndarray) -> np.ndarray:
346289 """
347290 x = jnp .concatenate ((self .X_train , X_test ))
348291
349- # Pad features to fixed size (10) to avoid recompilation
350292 num_features = x .shape [1 ]
351293 if x .shape [1 ] < 10 :
352294 padding = jnp .zeros ((x .shape [0 ], 10 - num_features ))
353295 x = jnp .concatenate ([x , padding ], axis = 1 )
354296
355- # Pad targets with mean imputation for test positions
356- mean = self .y_train .mean () # Scalar mean of training targets
297+ mean = self .y_train .mean ()
357298 num_test = len (X_test )
358- padding = np .full (num_test , mean ) # (num_test,) filled with mean
359- y = jnp .concatenate ([self .y_train , padding ]) # (num_total,)
299+ padding = np .full (num_test , mean )
300+ y = jnp .concatenate ([self .y_train , padding ])
360301
361302 num_train = len (self .X_train )
362303 train_mask = jnp .arange (len (x )) < num_train
363304
364305 out = predict (self .model , x , y , train_mask = train_mask )
365306
366- # Extract only test predictions (train predictions are zeroed out)
367307 out = out [num_train :]
368-
369- # Slice to keep only valid classes
370308 out = out [:, : self .num_classes ]
371309
372310 probabilities = jax .nn .softmax (out , axis = 1 )
0 commit comments