Skip to content

Commit cfdd7b5

Browse files
committed
Clean up
1 parent d494878 commit cfdd7b5

14 files changed

Lines changed: 2029 additions & 2891 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,3 +166,4 @@ wandb
166166
__pycache__/
167167
*.pth
168168
*.h5
169+
nanotabpfn_summary.png

README.md

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
# nanoTabPFN
22

3+
> [!NOTE]
4+
> This is a reimplementation of [nanoTabPFN](https://github.com/automl/nanoTabPFN) in JAX; please cite the original work. To avoid recompilation for different numbers of input features, padding was added.
5+
36
Train your own small TabPFN in less than 500 LOC and a few minutes.
47

58
The purpose of this repository is to be a good starting point for students and researchers that are interested in learning about how TabPFN works under the hood.
69

7-
Clone the repository, afterwards install dependencies via:
8-
```
9-
pip install numpy torch schedulefree h5py scikit-learn openml seaborn
10-
```
11-
1210
### Our Code
1311

1412
- `model.py` contains the implementation of the architecture and a sklearn-like interface in less than 200 lines of code.

experiment.ipynb

Lines changed: 211 additions & 2868 deletions
Large diffs are not rendered by default.

model.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,6 @@ def __call__(
200200
Returns:
201201
(num_rows, num_features+1, embedding_size)
202202
"""
203-
print("recompile call nanopfn transformer block")
204-
205203
src_features = jax.vmap(self._multihead_attention_features)(src, src, src) + src
206204
src = jax.vmap(jax.vmap(self.norm1))(src_features)
207205

@@ -293,8 +291,6 @@ def __call__(
293291
Returns:
294292
logits of shape (test_size, num_outputs) for test datapoints only
295293
"""
296-
print("recompile call nanopfn model")
297-
298294
# Ensure y_src has the right shape
299295
if len(y_src.shape) < len(x_src.shape):
300296
y_src = y_src[..., None]

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ disable_error_code = ["import-untyped"]
4141
[tool.ruff]
4242
line-length = 119
4343
indent-width = 4
44-
extend-exclude = ["experiment.ipynb"]
44+
extend-exclude = ["experiment.ipynb", "tests", "torch_impl"]
4545

4646
[tool.ruff.pydocstyle]
4747
convention = "google"

tests/test_classifier.py

Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
"""Numerical validation tests for full NanoTabPFNModel: JAX vs PyTorch."""
2+
3+
import equinox as eqx
4+
import jax
5+
import jax.numpy as jnp
6+
import numpy as np
7+
import pytest
8+
import torch
9+
10+
from model import NanoTabPFNModel as JAXNanoTabPFN
11+
from torch_impl.model import NanoTabPFNModel as TorchNanoTabPFN
12+
13+
14+
def _copy_torch_linear_to_jax(jax_linear: eqx.nn.Linear, torch_linear: torch.nn.Linear) -> eqx.nn.Linear:
    """Return a copy of ``jax_linear`` carrying the parameters of ``torch_linear``.

    Args:
        jax_linear: Equinox linear layer whose ``weight`` and ``bias`` are replaced.
        torch_linear: PyTorch linear layer providing the ``weight`` and ``bias`` values.

    Returns:
        A new Equinox linear layer; the input layer is not modified.
    """
    # Go through NumPy on the host so the tensors are detached from autograd
    # before they are turned into JAX arrays.
    new_weight = jnp.array(torch_linear.weight.detach().cpu().numpy())
    new_bias = jnp.array(torch_linear.bias.detach().cpu().numpy())

    # Replace both leaves in a single out-of-place update.
    return eqx.tree_at(lambda layer: (layer.weight, layer.bias), jax_linear, (new_weight, new_bias))
21+
22+
23+
def _copy_torch_layernorm_to_jax(jax_ln: eqx.nn.LayerNorm, torch_ln: torch.nn.LayerNorm) -> eqx.nn.LayerNorm:
    """Return a copy of ``jax_ln`` carrying the affine parameters of ``torch_ln``.

    Args:
        jax_ln: Equinox layer norm whose ``weight`` and ``bias`` are replaced.
        torch_ln: PyTorch layer norm providing the ``weight`` and ``bias`` values.

    Returns:
        A new Equinox layer norm; the input layer is not modified.
    """
    # Detach each parameter, move it to host memory and wrap it as a JAX array.
    scale, shift = (jnp.array(param.detach().cpu().numpy()) for param in (torch_ln.weight, torch_ln.bias))
    return eqx.tree_at(lambda ln: (ln.weight, ln.bias), jax_ln, (scale, shift))
30+
31+
32+
def _copy_torch_mha_to_jax(jax_layer, torch_mha, is_features: bool = True):
    """Copy PyTorch MultiheadAttention parameters into the JAX block's attention projections.

    The PyTorch module stores the query, key and value projections stacked along the
    first axis of ``in_proj_weight`` / ``in_proj_bias``; they are sliced apart here in
    blocks of ``embed_dim`` rows and written to the separate ``*_q``, ``*_k``, ``*_v``
    and ``*_out`` linear layers of the JAX block.

    Args:
        jax_layer: JAX transformer block that owns the attention projection layers.
        torch_mha: PyTorch attention module to read the parameters from.
        is_features: If True, write to the ``self_attn_features_*`` layers, otherwise
            to the ``self_attn_datapoints_*`` layers.

    Returns:
        A new JAX block with the selected attention parameters replaced.
    """

    def _to_numpy(tensor):
        return tensor.detach().cpu().numpy()

    dim = torch_mha.embed_dim
    packed_weight = _to_numpy(torch_mha.in_proj_weight)
    packed_bias = _to_numpy(torch_mha.in_proj_bias)

    # (weight, bias) per projection, keyed by the attribute-name suffix on the JAX side.
    projections = {
        "q": (packed_weight[:dim, :], packed_bias[:dim]),
        "k": (packed_weight[dim : 2 * dim, :], packed_bias[dim : 2 * dim]),
        "v": (packed_weight[2 * dim :, :], packed_bias[2 * dim :]),
        "out": (_to_numpy(torch_mha.out_proj.weight), _to_numpy(torch_mha.out_proj.bias)),
    }

    prefix = "self_attn_features" if is_features else "self_attn_datapoints"
    for suffix, (weight, bias) in projections.items():
        attr = f"{prefix}_{suffix}"
        jax_layer = eqx.tree_at(
            lambda block, attr=attr: (getattr(block, attr).weight, getattr(block, attr).bias),
            jax_layer,
            (jnp.array(weight), jnp.array(bias)),
        )

    return jax_layer
69+
70+
71+
def _copy_model_weights(jax_model: JAXNanoTabPFN, torch_model: TorchNanoTabPFN) -> JAXNanoTabPFN:
    """Transfer every parameter of the PyTorch model into the JAX model.

    The encoders, each transformer block (both attention modules, the two MLP
    linears and the three layer norms) and the decoder are copied in that order.

    Args:
        jax_model: JAX model to copy weights into.
        torch_model: PyTorch model to copy weights from.

    Returns:
        JAX model with copied weights.
    """
    # Feature and target encoders each wrap a single linear layer.
    for encoder in ("feature_encoder", "target_encoder"):
        copied = _copy_torch_linear_to_jax(
            getattr(jax_model, encoder).linear_layer,
            getattr(torch_model, encoder).linear_layer,
        )
        jax_model = eqx.tree_at(lambda m, encoder=encoder: getattr(m, encoder).linear_layer, jax_model, copied)

    for idx, source_block in enumerate(torch_model.transformer_blocks):
        block = jax_model.transformer_blocks[idx]

        block = _copy_torch_mha_to_jax(block, source_block.self_attention_between_features, is_features=True)
        block = _copy_torch_mha_to_jax(block, source_block.self_attention_between_datapoints, is_features=False)

        for name in ("linear1", "linear2"):
            copied = _copy_torch_linear_to_jax(getattr(block, name), getattr(source_block, name))
            block = eqx.tree_at(lambda m, name=name: getattr(m, name), block, copied)

        for name in ("norm1", "norm2", "norm3"):
            copied = _copy_torch_layernorm_to_jax(getattr(block, name), getattr(source_block, name))
            block = eqx.tree_at(lambda m, name=name: getattr(m, name), block, copied)

        # Put the fully updated block back at its position in the model.
        jax_model = eqx.tree_at(lambda m, idx=idx: m.transformer_blocks[idx], jax_model, block)

    for name in ("linear1", "linear2"):
        copied = _copy_torch_linear_to_jax(getattr(jax_model.decoder, name), getattr(torch_model.decoder, name))
        jax_model = eqx.tree_at(lambda m, name=name: getattr(m.decoder, name), jax_model, copied)

    return jax_model
140+
141+
142+
@pytest.fixture
def full_model_setup() -> dict:
    """Create matched JAX and PyTorch full models with test data.

    All sources of randomness are seeded: the NumPy global RNG for the data, the
    PyTorch global RNG for the reference model's initial weights, and an explicit
    JAX key for the JAX model. The JAX model's parameters are then overwritten with
    the PyTorch ones, so both models hold the same weights.

    Returns:
        Dictionary containing models, test data, and configuration.
    """
    embedding_size = 16
    num_attention_heads = 4
    mlp_hidden_size = 32
    num_layers = 2
    num_outputs = 2

    num_rows = 8
    num_features = 3
    train_test_split_index = 5

    np.random.seed(42)
    x_np = np.random.randn(num_rows, num_features).astype(np.float32)
    y_np = np.random.randint(0, num_outputs, size=(num_rows,)).astype(np.float32)

    # Seed torch before building the reference model: without it the weights differ
    # on every run, so a tolerance or argmax mismatch could not be reproduced.
    torch.manual_seed(0)
    torch_model = TorchNanoTabPFN(embedding_size, num_attention_heads, mlp_hidden_size, num_layers, num_outputs)
    torch_model.eval()

    key = jax.random.PRNGKey(0)
    jax_model = JAXNanoTabPFN(embedding_size, num_attention_heads, mlp_hidden_size, num_layers, num_outputs, key=key)
    jax_model = _copy_model_weights(jax_model, torch_model)

    return {
        "torch_model": torch_model,
        "jax_model": jax_model,
        "x_np": x_np,
        "y_np": y_np,
        "num_rows": num_rows,
        "num_outputs": num_outputs,
        "train_test_split_index": train_test_split_index,
    }
179+
180+
181+
def test_full_model_output_shape(full_model_setup: dict) -> None:
    """Check the logit shapes returned by the PyTorch and JAX models.

    The PyTorch model is expected to return one row per test datapoint, the JAX
    model one row per datapoint (train and test).
    """
    cfg = full_model_setup
    split = cfg["train_test_split_index"]
    n_rows, n_outputs = cfg["num_rows"], cfg["num_outputs"]
    features, targets = cfg["x_np"], cfg["y_np"]

    # PyTorch side: batched inputs, training labels only, split passed as an index.
    torch_inputs = (
        torch.from_numpy(features).unsqueeze(0),
        torch.from_numpy(targets[:split]).unsqueeze(0),
    )
    with torch.no_grad():
        torch_logits = cfg["torch_model"](torch_inputs, split).squeeze(0)

    # JAX side: unbatched inputs, all labels, split passed as a boolean row mask.
    is_train = jnp.arange(n_rows) < split
    jax_logits = cfg["jax_model"](jnp.array(features), jnp.array(targets), train_mask=is_train)

    assert torch_logits.shape == (n_rows - split, n_outputs)
    assert jax_logits.shape == (n_rows, n_outputs)
203+
204+
205+
def test_full_model_test_outputs_match(full_model_setup: dict) -> None:
    """Check that the JAX logits for the test rows equal the PyTorch logits (atol 1e-3)."""
    cfg = full_model_setup
    split = cfg["train_test_split_index"]
    features, targets = cfg["x_np"], cfg["y_np"]

    # Reference logits from PyTorch (batched call, batch axis dropped afterwards).
    torch_inputs = (
        torch.from_numpy(features).unsqueeze(0),
        torch.from_numpy(targets[:split]).unsqueeze(0),
    )
    with torch.no_grad():
        reference = cfg["torch_model"](torch_inputs, split).squeeze(0)
    expected = reference.cpu().numpy()

    is_train = jnp.arange(cfg["num_rows"]) < split
    jax_logits = cfg["jax_model"](jnp.array(features), jnp.array(targets), train_mask=is_train)

    # Only the rows after the split are compared against the PyTorch output.
    actual = np.array(jax_logits)[split:]

    np.testing.assert_allclose(actual, expected, atol=1e-3)
227+
228+
229+
def test_full_model_train_outputs_zeroed(full_model_setup: dict) -> None:
    """Check that the JAX model's output rows for training samples are zero (atol 1e-10)."""
    cfg = full_model_setup
    split = cfg["train_test_split_index"]

    is_train = jnp.arange(cfg["num_rows"]) < split
    jax_logits = cfg["jax_model"](jnp.array(cfg["x_np"]), jnp.array(cfg["y_np"]), train_mask=is_train)

    # Rows before the split index are the training samples.
    train_rows = np.array(jax_logits)[:split]

    np.testing.assert_allclose(train_rows, 0.0, atol=1e-10)
242+
243+
244+
def test_full_model_predictions_match(full_model_setup: dict) -> None:
    """Check that both models pick the same class (argmax of logits) for every test row."""
    cfg = full_model_setup
    split = cfg["train_test_split_index"]
    features, targets = cfg["x_np"], cfg["y_np"]

    # Predicted classes from the PyTorch reference.
    torch_inputs = (
        torch.from_numpy(features).unsqueeze(0),
        torch.from_numpy(targets[:split]).unsqueeze(0),
    )
    with torch.no_grad():
        reference = cfg["torch_model"](torch_inputs, split).squeeze(0)
    expected_classes = np.argmax(reference.cpu().numpy(), axis=-1)

    # Predicted classes from the JAX model, restricted to the rows after the split.
    is_train = jnp.arange(cfg["num_rows"]) < split
    jax_logits = cfg["jax_model"](jnp.array(features), jnp.array(targets), train_mask=is_train)
    actual_classes = np.argmax(np.array(jax_logits)[split:], axis=-1)

    np.testing.assert_array_equal(actual_classes, expected_classes)
267+
268+
269+
def test_full_model_logits_detailed(full_model_setup: dict) -> None:
    """Compare the logits of the first test sample between JAX and PyTorch (atol 1e-3)."""
    cfg = full_model_setup
    split = cfg["train_test_split_index"]
    features, targets = cfg["x_np"], cfg["y_np"]

    torch_inputs = (
        torch.from_numpy(features).unsqueeze(0),
        torch.from_numpy(targets[:split]).unsqueeze(0),
    )
    with torch.no_grad():
        reference = cfg["torch_model"](torch_inputs, split).squeeze(0)
    expected = reference.cpu().numpy()

    is_train = jnp.arange(cfg["num_rows"]) < split
    jax_logits = cfg["jax_model"](jnp.array(features), jnp.array(targets), train_mask=is_train)
    actual = np.array(jax_logits)[split:]

    # Index 0 on both sides is the first row after the train/test split.
    np.testing.assert_allclose(actual[0], expected[0], atol=1e-3)

0 commit comments

Comments
 (0)