|
| 1 | +"""Numerical validation tests for full NanoTabPFNModel: JAX vs PyTorch.""" |
| 2 | + |
| 3 | +import equinox as eqx |
| 4 | +import jax |
| 5 | +import jax.numpy as jnp |
| 6 | +import numpy as np |
| 7 | +import pytest |
| 8 | +import torch |
| 9 | + |
| 10 | +from model import NanoTabPFNModel as JAXNanoTabPFN |
| 11 | +from torch_impl.model import NanoTabPFNModel as TorchNanoTabPFN |
| 12 | + |
| 13 | + |
| 14 | +def _copy_torch_linear_to_jax(jax_linear: eqx.nn.Linear, torch_linear: torch.nn.Linear) -> eqx.nn.Linear: |
| 15 | + """Copy weights and biases from PyTorch Linear to JAX Linear.""" |
| 16 | + weight = torch_linear.weight.detach().cpu().numpy() |
| 17 | + bias = torch_linear.bias.detach().cpu().numpy() |
| 18 | + jax_linear = eqx.tree_at(lambda m: m.weight, jax_linear, jnp.array(weight)) |
| 19 | + jax_linear = eqx.tree_at(lambda m: m.bias, jax_linear, jnp.array(bias)) |
| 20 | + return jax_linear |
| 21 | + |
| 22 | + |
| 23 | +def _copy_torch_layernorm_to_jax(jax_ln: eqx.nn.LayerNorm, torch_ln: torch.nn.LayerNorm) -> eqx.nn.LayerNorm: |
| 24 | + """Copy weights and biases from PyTorch LayerNorm to JAX LayerNorm.""" |
| 25 | + weight = torch_ln.weight.detach().cpu().numpy() |
| 26 | + bias = torch_ln.bias.detach().cpu().numpy() |
| 27 | + jax_ln = eqx.tree_at(lambda m: m.weight, jax_ln, jnp.array(weight)) |
| 28 | + jax_ln = eqx.tree_at(lambda m: m.bias, jax_ln, jnp.array(bias)) |
| 29 | + return jax_ln |
| 30 | + |
| 31 | + |
| 32 | +def _copy_torch_mha_to_jax(jax_layer, torch_mha, is_features: bool = True): |
| 33 | + """Copy PyTorch MultiheadAttention weights to JAX attention layers.""" |
| 34 | + embed_dim = torch_mha.embed_dim |
| 35 | + |
| 36 | + in_proj_weight = torch_mha.in_proj_weight.detach().cpu().numpy() |
| 37 | + in_proj_bias = torch_mha.in_proj_bias.detach().cpu().numpy() |
| 38 | + out_proj_weight = torch_mha.out_proj.weight.detach().cpu().numpy() |
| 39 | + out_proj_bias = torch_mha.out_proj.bias.detach().cpu().numpy() |
| 40 | + |
| 41 | + w_q = in_proj_weight[:embed_dim, :] |
| 42 | + w_k = in_proj_weight[embed_dim : 2 * embed_dim, :] |
| 43 | + w_v = in_proj_weight[2 * embed_dim :, :] |
| 44 | + |
| 45 | + b_q = in_proj_bias[:embed_dim] |
| 46 | + b_k = in_proj_bias[embed_dim : 2 * embed_dim] |
| 47 | + b_v = in_proj_bias[2 * embed_dim :] |
| 48 | + |
| 49 | + if is_features: |
| 50 | + jax_layer = eqx.tree_at(lambda m: m.self_attn_features_q.weight, jax_layer, jnp.array(w_q)) |
| 51 | + jax_layer = eqx.tree_at(lambda m: m.self_attn_features_q.bias, jax_layer, jnp.array(b_q)) |
| 52 | + jax_layer = eqx.tree_at(lambda m: m.self_attn_features_k.weight, jax_layer, jnp.array(w_k)) |
| 53 | + jax_layer = eqx.tree_at(lambda m: m.self_attn_features_k.bias, jax_layer, jnp.array(b_k)) |
| 54 | + jax_layer = eqx.tree_at(lambda m: m.self_attn_features_v.weight, jax_layer, jnp.array(w_v)) |
| 55 | + jax_layer = eqx.tree_at(lambda m: m.self_attn_features_v.bias, jax_layer, jnp.array(b_v)) |
| 56 | + jax_layer = eqx.tree_at(lambda m: m.self_attn_features_out.weight, jax_layer, jnp.array(out_proj_weight)) |
| 57 | + jax_layer = eqx.tree_at(lambda m: m.self_attn_features_out.bias, jax_layer, jnp.array(out_proj_bias)) |
| 58 | + else: |
| 59 | + jax_layer = eqx.tree_at(lambda m: m.self_attn_datapoints_q.weight, jax_layer, jnp.array(w_q)) |
| 60 | + jax_layer = eqx.tree_at(lambda m: m.self_attn_datapoints_q.bias, jax_layer, jnp.array(b_q)) |
| 61 | + jax_layer = eqx.tree_at(lambda m: m.self_attn_datapoints_k.weight, jax_layer, jnp.array(w_k)) |
| 62 | + jax_layer = eqx.tree_at(lambda m: m.self_attn_datapoints_k.bias, jax_layer, jnp.array(b_k)) |
| 63 | + jax_layer = eqx.tree_at(lambda m: m.self_attn_datapoints_v.weight, jax_layer, jnp.array(w_v)) |
| 64 | + jax_layer = eqx.tree_at(lambda m: m.self_attn_datapoints_v.bias, jax_layer, jnp.array(b_v)) |
| 65 | + jax_layer = eqx.tree_at(lambda m: m.self_attn_datapoints_out.weight, jax_layer, jnp.array(out_proj_weight)) |
| 66 | + jax_layer = eqx.tree_at(lambda m: m.self_attn_datapoints_out.bias, jax_layer, jnp.array(out_proj_bias)) |
| 67 | + |
| 68 | + return jax_layer |
| 69 | + |
| 70 | + |
| 71 | +def _copy_model_weights(jax_model: JAXNanoTabPFN, torch_model: TorchNanoTabPFN) -> JAXNanoTabPFN: |
| 72 | + """Copy all weights from PyTorch model to JAX model. |
| 73 | +
|
| 74 | + Args: |
| 75 | + jax_model: JAX model to copy weights into. |
| 76 | + torch_model: PyTorch model to copy weights from. |
| 77 | +
|
| 78 | + Returns: |
| 79 | + JAX model with copied weights. |
| 80 | + """ |
| 81 | + jax_model = eqx.tree_at( |
| 82 | + lambda m: m.feature_encoder.linear_layer, |
| 83 | + jax_model, |
| 84 | + _copy_torch_linear_to_jax(jax_model.feature_encoder.linear_layer, torch_model.feature_encoder.linear_layer), |
| 85 | + ) |
| 86 | + |
| 87 | + jax_model = eqx.tree_at( |
| 88 | + lambda m: m.target_encoder.linear_layer, |
| 89 | + jax_model, |
| 90 | + _copy_torch_linear_to_jax(jax_model.target_encoder.linear_layer, torch_model.target_encoder.linear_layer), |
| 91 | + ) |
| 92 | + |
| 93 | + for i, torch_block in enumerate(torch_model.transformer_blocks): |
| 94 | + jax_block = jax_model.transformer_blocks[i] |
| 95 | + |
| 96 | + jax_block = _copy_torch_mha_to_jax(jax_block, torch_block.self_attention_between_features, is_features=True) |
| 97 | + jax_block = _copy_torch_mha_to_jax(jax_block, torch_block.self_attention_between_datapoints, is_features=False) |
| 98 | + |
| 99 | + jax_block = eqx.tree_at( |
| 100 | + lambda m: m.linear1, |
| 101 | + jax_block, |
| 102 | + _copy_torch_linear_to_jax(jax_block.linear1, torch_block.linear1), |
| 103 | + ) |
| 104 | + jax_block = eqx.tree_at( |
| 105 | + lambda m: m.linear2, |
| 106 | + jax_block, |
| 107 | + _copy_torch_linear_to_jax(jax_block.linear2, torch_block.linear2), |
| 108 | + ) |
| 109 | + |
| 110 | + jax_block = eqx.tree_at( |
| 111 | + lambda m: m.norm1, |
| 112 | + jax_block, |
| 113 | + _copy_torch_layernorm_to_jax(jax_block.norm1, torch_block.norm1), |
| 114 | + ) |
| 115 | + jax_block = eqx.tree_at( |
| 116 | + lambda m: m.norm2, |
| 117 | + jax_block, |
| 118 | + _copy_torch_layernorm_to_jax(jax_block.norm2, torch_block.norm2), |
| 119 | + ) |
| 120 | + jax_block = eqx.tree_at( |
| 121 | + lambda m: m.norm3, |
| 122 | + jax_block, |
| 123 | + _copy_torch_layernorm_to_jax(jax_block.norm3, torch_block.norm3), |
| 124 | + ) |
| 125 | + |
| 126 | + jax_model = eqx.tree_at(lambda m: m.transformer_blocks[i], jax_model, jax_block) |
| 127 | + |
| 128 | + jax_model = eqx.tree_at( |
| 129 | + lambda m: m.decoder.linear1, |
| 130 | + jax_model, |
| 131 | + _copy_torch_linear_to_jax(jax_model.decoder.linear1, torch_model.decoder.linear1), |
| 132 | + ) |
| 133 | + jax_model = eqx.tree_at( |
| 134 | + lambda m: m.decoder.linear2, |
| 135 | + jax_model, |
| 136 | + _copy_torch_linear_to_jax(jax_model.decoder.linear2, torch_model.decoder.linear2), |
| 137 | + ) |
| 138 | + |
| 139 | + return jax_model |
| 140 | + |
| 141 | + |
| 142 | +@pytest.fixture |
| 143 | +def full_model_setup() -> dict: |
| 144 | + """Create matched JAX and PyTorch full models with test data. |
| 145 | +
|
| 146 | + Returns: |
| 147 | + Dictionary containing models, test data, and configuration. |
| 148 | + """ |
| 149 | + embedding_size = 16 |
| 150 | + num_attention_heads = 4 |
| 151 | + mlp_hidden_size = 32 |
| 152 | + num_layers = 2 |
| 153 | + num_outputs = 2 |
| 154 | + |
| 155 | + num_rows = 8 |
| 156 | + num_features = 3 |
| 157 | + train_test_split_index = 5 |
| 158 | + |
| 159 | + np.random.seed(42) |
| 160 | + x_np = np.random.randn(num_rows, num_features).astype(np.float32) |
| 161 | + y_np = np.random.randint(0, num_outputs, size=(num_rows,)).astype(np.float32) |
| 162 | + |
| 163 | + torch_model = TorchNanoTabPFN(embedding_size, num_attention_heads, mlp_hidden_size, num_layers, num_outputs) |
| 164 | + torch_model.eval() |
| 165 | + |
| 166 | + key = jax.random.PRNGKey(0) |
| 167 | + jax_model = JAXNanoTabPFN(embedding_size, num_attention_heads, mlp_hidden_size, num_layers, num_outputs, key=key) |
| 168 | + jax_model = _copy_model_weights(jax_model, torch_model) |
| 169 | + |
| 170 | + return { |
| 171 | + "torch_model": torch_model, |
| 172 | + "jax_model": jax_model, |
| 173 | + "x_np": x_np, |
| 174 | + "y_np": y_np, |
| 175 | + "num_rows": num_rows, |
| 176 | + "num_outputs": num_outputs, |
| 177 | + "train_test_split_index": train_test_split_index, |
| 178 | + } |
| 179 | + |
| 180 | + |
| 181 | +def test_full_model_output_shape(full_model_setup: dict) -> None: |
| 182 | + """Test that JAX and PyTorch full models produce expected output shapes.""" |
| 183 | + setup = full_model_setup |
| 184 | + train_test_split_index = setup["train_test_split_index"] |
| 185 | + |
| 186 | + x_torch = torch.from_numpy(setup["x_np"]) |
| 187 | + y_train_torch = torch.from_numpy(setup["y_np"][:train_test_split_index]) |
| 188 | + |
| 189 | + with torch.no_grad(): |
| 190 | + out_torch = setup["torch_model"]( |
| 191 | + (x_torch.unsqueeze(0), y_train_torch.unsqueeze(0)), train_test_split_index |
| 192 | + ).squeeze(0) |
| 193 | + |
| 194 | + x_jax = jnp.array(setup["x_np"]) |
| 195 | + y_jax = jnp.array(setup["y_np"]) |
| 196 | + train_mask_jax = jnp.arange(setup["num_rows"]) < train_test_split_index |
| 197 | + |
| 198 | + out_jax = setup["jax_model"](x_jax, y_jax, train_mask=train_mask_jax) |
| 199 | + |
| 200 | + num_test = setup["num_rows"] - train_test_split_index |
| 201 | + assert out_torch.shape == (num_test, setup["num_outputs"]) |
| 202 | + assert out_jax.shape == (setup["num_rows"], setup["num_outputs"]) |
| 203 | + |
| 204 | + |
| 205 | +def test_full_model_test_outputs_match(full_model_setup: dict) -> None: |
| 206 | + """Test that test sample outputs match between JAX and PyTorch.""" |
| 207 | + setup = full_model_setup |
| 208 | + train_test_split_index = setup["train_test_split_index"] |
| 209 | + |
| 210 | + x_torch = torch.from_numpy(setup["x_np"]) |
| 211 | + y_train_torch = torch.from_numpy(setup["y_np"][:train_test_split_index]) |
| 212 | + |
| 213 | + with torch.no_grad(): |
| 214 | + out_torch = setup["torch_model"]( |
| 215 | + (x_torch.unsqueeze(0), y_train_torch.unsqueeze(0)), train_test_split_index |
| 216 | + ).squeeze(0) |
| 217 | + out_torch_np = out_torch.cpu().numpy() |
| 218 | + |
| 219 | + x_jax = jnp.array(setup["x_np"]) |
| 220 | + y_jax = jnp.array(setup["y_np"]) |
| 221 | + train_mask_jax = jnp.arange(setup["num_rows"]) < train_test_split_index |
| 222 | + |
| 223 | + out_jax = setup["jax_model"](x_jax, y_jax, train_mask=train_mask_jax) |
| 224 | + out_jax_test = np.array(out_jax)[train_test_split_index:] |
| 225 | + |
| 226 | + np.testing.assert_allclose(out_jax_test, out_torch_np, atol=1e-3) |
| 227 | + |
| 228 | + |
| 229 | +def test_full_model_train_outputs_zeroed(full_model_setup: dict) -> None: |
| 230 | + """Test that JAX model zeros out training sample outputs.""" |
| 231 | + setup = full_model_setup |
| 232 | + train_test_split_index = setup["train_test_split_index"] |
| 233 | + |
| 234 | + x_jax = jnp.array(setup["x_np"]) |
| 235 | + y_jax = jnp.array(setup["y_np"]) |
| 236 | + train_mask_jax = jnp.arange(setup["num_rows"]) < train_test_split_index |
| 237 | + |
| 238 | + out_jax = setup["jax_model"](x_jax, y_jax, train_mask=train_mask_jax) |
| 239 | + out_jax_train = np.array(out_jax)[:train_test_split_index] |
| 240 | + |
| 241 | + np.testing.assert_allclose(out_jax_train, 0.0, atol=1e-10) |
| 242 | + |
| 243 | + |
| 244 | +def test_full_model_predictions_match(full_model_setup: dict) -> None: |
| 245 | + """Test that argmax predictions match between JAX and PyTorch.""" |
| 246 | + setup = full_model_setup |
| 247 | + train_test_split_index = setup["train_test_split_index"] |
| 248 | + |
| 249 | + x_torch = torch.from_numpy(setup["x_np"]) |
| 250 | + y_train_torch = torch.from_numpy(setup["y_np"][:train_test_split_index]) |
| 251 | + |
| 252 | + with torch.no_grad(): |
| 253 | + out_torch = setup["torch_model"]( |
| 254 | + (x_torch.unsqueeze(0), y_train_torch.unsqueeze(0)), train_test_split_index |
| 255 | + ).squeeze(0) |
| 256 | + pred_torch = np.argmax(out_torch.cpu().numpy(), axis=-1) |
| 257 | + |
| 258 | + x_jax = jnp.array(setup["x_np"]) |
| 259 | + y_jax = jnp.array(setup["y_np"]) |
| 260 | + train_mask_jax = jnp.arange(setup["num_rows"]) < train_test_split_index |
| 261 | + |
| 262 | + out_jax = setup["jax_model"](x_jax, y_jax, train_mask=train_mask_jax) |
| 263 | + out_jax_test = np.array(out_jax)[train_test_split_index:] |
| 264 | + pred_jax = np.argmax(out_jax_test, axis=-1) |
| 265 | + |
| 266 | + np.testing.assert_array_equal(pred_jax, pred_torch) |
| 267 | + |
| 268 | + |
| 269 | +def test_full_model_logits_detailed(full_model_setup: dict) -> None: |
| 270 | + """Test detailed logit comparison for first test sample.""" |
| 271 | + setup = full_model_setup |
| 272 | + train_test_split_index = setup["train_test_split_index"] |
| 273 | + |
| 274 | + x_torch = torch.from_numpy(setup["x_np"]) |
| 275 | + y_train_torch = torch.from_numpy(setup["y_np"][:train_test_split_index]) |
| 276 | + |
| 277 | + with torch.no_grad(): |
| 278 | + out_torch = setup["torch_model"]( |
| 279 | + (x_torch.unsqueeze(0), y_train_torch.unsqueeze(0)), train_test_split_index |
| 280 | + ).squeeze(0) |
| 281 | + out_torch_np = out_torch.cpu().numpy() |
| 282 | + |
| 283 | + x_jax = jnp.array(setup["x_np"]) |
| 284 | + y_jax = jnp.array(setup["y_np"]) |
| 285 | + train_mask_jax = jnp.arange(setup["num_rows"]) < train_test_split_index |
| 286 | + |
| 287 | + out_jax = setup["jax_model"](x_jax, y_jax, train_mask=train_mask_jax) |
| 288 | + out_jax_test = np.array(out_jax)[train_test_split_index:] |
| 289 | + |
| 290 | + np.testing.assert_allclose(out_jax_test[0], out_torch_np[0], atol=1e-3) |
0 commit comments