Skip to content

Commit e202f27

Browse files
committed
Use equinox mha
1 parent cfdd7b5 commit e202f27

10 files changed

Lines changed: 468 additions & 506 deletions

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,4 +166,3 @@ wandb
166166
__pycache__/
167167
*.pth
168168
*.h5
169-
nanotabpfn_summary.png

README.md

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ The purpose of this repository is to be a good starting point for students and r
1313
- `train.py` implements a simple training loop and prior dump data loader in under 200 lines
1414
- `experiment.ipynb` will recreate the experiment from the paper
1515

16-
1716
### Pretrain your own nanoTabPFN
1817

1918
To pretrain your own nanoTabPFN, you need to first download a prior data dump from [here](http://ml.informatik.uni-freiburg.de/research-artifacts/nanoTabPFN/300k_150x5_2.h5), then run `train.py`.
@@ -22,19 +21,25 @@ To pretrain your own nanoTabPFN, you need to first download a prior data dump fr
2221
cd nanoTabPFN
2322

2423
# download data dump
25-
curl http://ml.informatik.uni-freiburg.de/research-artifacts/nanoTabPFN/300k_150x5_2.h5 --output 300k_150x5_2.h5
24+
curl http://ml.informatik.uni-freiburg.de/research-artifacts/nanoTabPFN/300k_150x5_2.h5 --output 300k_150x5.h5
2625

2726
python train.py
2827
```
2928

29+
### Evaluation
30+
31+
![Evaluation results](nanotabpfn_summary.png)
32+
33+
We repeated training for two seeds only; the paper uses 20. To reproduce these results, run `experiment.ipynb`.
34+
3035
#### Step-by-step explanation:
3136

3237
First, we import our code from `model.py` and `train.py`:
3338
```py
3439
from model import NanoTabPFNModel
3540
from model import NanoTabPFNClassifier
3641
from train import PriorDumpDataLoader
37-
from train import train, get_default_device
42+
from train import train
3843
```
3944
Then we instantiate our model
4045
```py
@@ -56,17 +61,15 @@ prior = PriorDumpDataLoader(
5661
```
5762
Now we can train our model:
5863
```py
59-
device = get_default_device()
6064
model, _ = train(
6165
model,
6266
prior,
6367
lr = 4e-3,
64-
device = device
6568
)
6669
```
6770
and finally we can instantiate our classifier:
6871
```py
69-
clf = NanoTabPFNClassifier(model, device)
72+
clf = NanoTabPFNClassifier(model)
7073
```
7174
and use its `.fit`, `.predict` and `.predict_proba`:
7275
```py

experiment.ipynb

Lines changed: 202 additions & 202 deletions
Large diffs are not rendered by default.

model.py

Lines changed: 34 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -83,105 +83,50 @@ class TransformerEncoderLayer(eqx.Module):
8383
nhead: int
8484
mlp_hidden_size: int
8585

86-
# Self-attention layers
87-
self_attn_datapoints_q: eqx.nn.Linear
88-
self_attn_datapoints_k: eqx.nn.Linear
89-
self_attn_datapoints_v: eqx.nn.Linear
90-
self_attn_datapoints_out: eqx.nn.Linear
91-
92-
self_attn_features_q: eqx.nn.Linear
93-
self_attn_features_k: eqx.nn.Linear
94-
self_attn_features_v: eqx.nn.Linear
95-
self_attn_features_out: eqx.nn.Linear
96-
97-
# MLP layers
86+
self_attn_features: eqx.nn.MultiheadAttention
87+
self_attn_datapoints: eqx.nn.MultiheadAttention
88+
9889
linear1: eqx.nn.Linear
9990
linear2: eqx.nn.Linear
10091

101-
# Layer norms
10292
norm1: eqx.nn.LayerNorm
10393
norm2: eqx.nn.LayerNorm
10494
norm3: eqx.nn.LayerNorm
10595

10696
def __init__(self, embedding_size: int, nhead: int, mlp_hidden_size: int, *, key: PRNGKeyArray) -> None:
107-
keys = jax.random.split(key, 11)
97+
keys = jax.random.split(key, 5)
10898

10999
self.embedding_size = embedding_size
110100
self.nhead = nhead
111101
self.mlp_hidden_size = mlp_hidden_size
112102

113-
# Self-attention between datapoints
114-
self.self_attn_datapoints_q = eqx.nn.Linear(embedding_size, embedding_size, key=keys[0])
115-
self.self_attn_datapoints_k = eqx.nn.Linear(embedding_size, embedding_size, key=keys[1])
116-
self.self_attn_datapoints_v = eqx.nn.Linear(embedding_size, embedding_size, key=keys[2])
117-
self.self_attn_datapoints_out = eqx.nn.Linear(embedding_size, embedding_size, key=keys[3])
118-
119-
# Self-attention between features
120-
self.self_attn_features_q = eqx.nn.Linear(embedding_size, embedding_size, key=keys[4])
121-
self.self_attn_features_k = eqx.nn.Linear(embedding_size, embedding_size, key=keys[5])
122-
self.self_attn_features_v = eqx.nn.Linear(embedding_size, embedding_size, key=keys[6])
123-
self.self_attn_features_out = eqx.nn.Linear(embedding_size, embedding_size, key=keys[7])
124-
125-
# MLP
126-
self.linear1 = eqx.nn.Linear(embedding_size, mlp_hidden_size, key=keys[8])
127-
self.linear2 = eqx.nn.Linear(mlp_hidden_size, embedding_size, key=keys[9])
103+
self.self_attn_features = eqx.nn.MultiheadAttention(
104+
num_heads=nhead,
105+
query_size=embedding_size,
106+
use_query_bias=True,
107+
use_key_bias=True,
108+
use_value_bias=True,
109+
use_output_bias=True,
110+
key=keys[0],
111+
)
112+
113+
self.self_attn_datapoints = eqx.nn.MultiheadAttention(
114+
num_heads=nhead,
115+
query_size=embedding_size,
116+
use_query_bias=True,
117+
use_key_bias=True,
118+
use_value_bias=True,
119+
use_output_bias=True,
120+
key=keys[1],
121+
)
122+
123+
self.linear1 = eqx.nn.Linear(embedding_size, mlp_hidden_size, key=keys[2])
124+
self.linear2 = eqx.nn.Linear(mlp_hidden_size, embedding_size, key=keys[3])
128125

129-
# Layer norms
130126
self.norm1 = eqx.nn.LayerNorm(embedding_size)
131127
self.norm2 = eqx.nn.LayerNorm(embedding_size)
132128
self.norm3 = eqx.nn.LayerNorm(embedding_size)
133129

134-
def _multihead_attention_features(
135-
self,
136-
query: Float[Array, "seq_len embed_dim"],
137-
key: Float[Array, "seq_len embed_dim"],
138-
value: Float[Array, "seq_len embed_dim"],
139-
) -> Float[Array, "seq_len embed_dim"]:
140-
"""Compute multi-head attention for features."""
141-
seq_len, embed_dim = query.shape
142-
head_dim = embed_dim // self.nhead
143-
144-
q = jax.vmap(self.self_attn_features_q)(query)
145-
k = jax.vmap(self.self_attn_features_k)(key)
146-
v = jax.vmap(self.self_attn_features_v)(value)
147-
148-
# Reshape for multi-head: (seq_len, nhead, head_dim)
149-
q = q.reshape(seq_len, self.nhead, head_dim)
150-
k = k.reshape(key.shape[0], self.nhead, head_dim)
151-
v = v.reshape(value.shape[0], self.nhead, head_dim)
152-
153-
attn_out = jax.nn.dot_product_attention(q, k, v, mask=None, implementation="xla")
154-
155-
attn_out = attn_out.reshape(seq_len, embed_dim)
156-
157-
return jax.vmap(self.self_attn_features_out)(attn_out)
158-
159-
def _multihead_attention_datapoints(
160-
self,
161-
query: Float[Array, "seq_len embed_dim"],
162-
key: Float[Array, "seq_len embed_dim"],
163-
value: Float[Array, "seq_len embed_dim"],
164-
mask: Float[Array, "..."] | None = None,
165-
) -> Float[Array, "seq_len embed_dim"]:
166-
"""Compute multi-head attention for datapoints."""
167-
seq_len, embed_dim = query.shape
168-
head_dim = embed_dim // self.nhead
169-
170-
q = jax.vmap(self.self_attn_datapoints_q)(query)
171-
k = jax.vmap(self.self_attn_datapoints_k)(key)
172-
v = jax.vmap(self.self_attn_datapoints_v)(value)
173-
174-
# Reshape for multi-head: (seq_len, nhead, head_dim)
175-
q = q.reshape(seq_len, self.nhead, head_dim)
176-
k = k.reshape(key.shape[0], self.nhead, head_dim)
177-
v = v.reshape(value.shape[0], self.nhead, head_dim)
178-
179-
attn_out = jax.nn.dot_product_attention(q, k, v, mask=mask, implementation="xla")
180-
181-
attn_out = attn_out.reshape(seq_len, embed_dim)
182-
183-
return jax.vmap(self.self_attn_datapoints_out)(attn_out)
184-
185130
def __call__(
186131
self,
187132
src: Float[Array, "num_rows num_features_plus_target embedding_size"],
@@ -200,16 +145,16 @@ def __call__(
200145
Returns:
201146
(num_rows, num_features+1, embedding_size)
202147
"""
203-
src_features = jax.vmap(self._multihead_attention_features)(src, src, src) + src
148+
src_features = jax.vmap(self.self_attn_features)(src, src, src) + src
204149
src = jax.vmap(jax.vmap(self.norm1))(src_features)
205150

206151
src = jnp.transpose(src, (1, 0, 2))
207152

208-
mask = train_mask[None, :] # (1, rows_size) - broadcasts to (nhead, rows, rows)
153+
num_rows = src.shape[1]
154+
mask = jnp.broadcast_to(train_mask, (num_rows, num_rows))
209155

210-
mha = partial(self._multihead_attention_datapoints, mask=mask)
211-
src_attended = jax.vmap(mha)(src, src, src)
212-
src = src_attended + src
156+
masked_mha = partial(self.self_attn_datapoints, mask=mask)
157+
src = jax.vmap(masked_mha)(src, src, src) + src
213158

214159
src = jnp.transpose(src, (1, 0, 2))
215160

@@ -291,7 +236,6 @@ def __call__(
291236
Returns:
292237
logits of shape (test_size, num_outputs) for test datapoints only
293238
"""
294-
# Ensure y_src has the right shape
295239
if len(y_src.shape) < len(x_src.shape):
296240
y_src = y_src[..., None]
297241

@@ -305,7 +249,6 @@ def __call__(
305249

306250
output = self.decoder(src[:, -1, :])
307251

308-
# Mask out train predictions, keep only test predictions
309252
test_mask = (~train_mask)[:, None] # (num_rows, 1)
310253
output = output * test_mask
311254

@@ -346,27 +289,22 @@ def predict_proba(self, X_test: np.ndarray) -> np.ndarray:
346289
"""
347290
x = jnp.concatenate((self.X_train, X_test))
348291

349-
# Pad features to fixed size (10) to avoid recompilation
350292
num_features = x.shape[1]
351293
if x.shape[1] < 10:
352294
padding = jnp.zeros((x.shape[0], 10 - num_features))
353295
x = jnp.concatenate([x, padding], axis=1)
354296

355-
# Pad targets with mean imputation for test positions
356-
mean = self.y_train.mean() # Scalar mean of training targets
297+
mean = self.y_train.mean()
357298
num_test = len(X_test)
358-
padding = np.full(num_test, mean) # (num_test,) filled with mean
359-
y = jnp.concatenate([self.y_train, padding]) # (num_total,)
299+
padding = np.full(num_test, mean)
300+
y = jnp.concatenate([self.y_train, padding])
360301

361302
num_train = len(self.X_train)
362303
train_mask = jnp.arange(len(x)) < num_train
363304

364305
out = predict(self.model, x, y, train_mask=train_mask)
365306

366-
# Extract only test predictions (train predictions are zeroed out)
367307
out = out[num_train:]
368-
369-
# Slice to keep only valid classes
370308
out = out[:, : self.num_classes]
371309

372310
probabilities = jax.nn.softmax(out, axis=1)

nanotabpfn_summary.png

386 KB
Loading

0 commit comments

Comments
 (0)