|
| 1 | +import jax |
| 2 | +import jax.numpy as jnp |
| 3 | +import pytest |
| 4 | + |
| 5 | +from cmonge.models.nn import ConditionalPerturbationNetwork |
| 6 | + |
| 7 | +# (context_bonds, dim_cond, num_contexts) |
| 8 | +CONTEXT_BOND_CONFIGS = [ |
| 9 | + pytest.param( |
| 10 | + ((0, 10), (10, 20)), |
| 11 | + 20, |
| 12 | + 2, |
| 13 | + id="non_overlapping_2_modalities", |
| 14 | + ), |
| 15 | + pytest.param( |
| 16 | + ((0, 10), (0, 10)), |
| 17 | + 10, |
| 18 | + 2, |
| 19 | + id="overlapping_2_modalities", |
| 20 | + ), |
| 21 | + pytest.param( |
| 22 | + ((0, 10), (10, 20), (20, 30)), |
| 23 | + 30, |
| 24 | + 3, |
| 25 | + id="non_overlapping_3_modalities", |
| 26 | + ), |
| 27 | +] |
| 28 | + |
| 29 | +DIM_DATA = 16 |
| 30 | +DIM_HIDDEN = [32, 32] |
| 31 | +DIM_COND_MAP = (8,) |
| 32 | +BATCH_SIZE = 4 |
| 33 | + |
| 34 | + |
| 35 | +def _make_model(context_bonds, attention_pooling, dropout_rate=0.1): |
| 36 | + return ConditionalPerturbationNetwork( |
| 37 | + dim_hidden=DIM_HIDDEN, |
| 38 | + dim_data=DIM_DATA, |
| 39 | + dim_cond=max(stop for _, stop in context_bonds), |
| 40 | + dim_cond_map=DIM_COND_MAP, |
| 41 | + embed_cond_equal=True, |
| 42 | + attention_pooling=attention_pooling, |
| 43 | + num_heads=4, |
| 44 | + dropout_rate=dropout_rate, |
| 45 | + context_entity_bonds=context_bonds, |
| 46 | + ) |
| 47 | + |
| 48 | + |
| 49 | +def _make_inputs(rng, dim_cond): |
| 50 | + rng_x, rng_c = jax.random.split(rng) |
| 51 | + x = jax.random.normal(rng_x, (BATCH_SIZE, DIM_DATA)) |
| 52 | + c = jax.random.normal(rng_c, (BATCH_SIZE, dim_cond)) |
| 53 | + return x, c |
| 54 | + |
| 55 | + |
| 56 | +class TestAttentionPooling: |
| 57 | + """Tests for attention pooling in ConditionalPerturbationNetwork.""" |
| 58 | + |
| 59 | + @pytest.mark.parametrize( |
| 60 | + "context_bonds,dim_cond,num_contexts", CONTEXT_BOND_CONFIGS |
| 61 | + ) |
| 62 | + def test_attention_pooling_forward_pass( |
| 63 | + self, context_bonds, dim_cond, num_contexts |
| 64 | + ): |
| 65 | + """Test that attention pooling produces correct output shape.""" |
| 66 | + model = _make_model(context_bonds, attention_pooling=True) |
| 67 | + rng = jax.random.PRNGKey(0) |
| 68 | + x, c = _make_inputs(rng, dim_cond) |
| 69 | + |
| 70 | + rng_params, rng_dropout = jax.random.split(rng) |
| 71 | + params = model.init({"params": rng_params, "dropout": rng_dropout}, x=x, c=c)[ |
| 72 | + "params" |
| 73 | + ] |
| 74 | + |
| 75 | + out = model.apply({"params": params}, x, c, num_contexts) |
| 76 | + assert out.shape == (BATCH_SIZE, DIM_DATA) |
| 77 | + assert not jnp.allclose(out, 0.0) |
| 78 | + |
| 79 | + @pytest.mark.parametrize( |
| 80 | + "context_bonds,dim_cond,num_contexts", CONTEXT_BOND_CONFIGS |
| 81 | + ) |
| 82 | + def test_both_pooling_modes_same_output_shape( |
| 83 | + self, context_bonds, dim_cond, num_contexts |
| 84 | + ): |
| 85 | + """Test that mean and attention pooling produce the same output shape.""" |
| 86 | + rng = jax.random.PRNGKey(42) |
| 87 | + x, c = _make_inputs(rng, dim_cond) |
| 88 | + |
| 89 | + model_mean = _make_model(context_bonds, attention_pooling=False) |
| 90 | + rng_p1, rng_d1, rng_p2, rng_d2 = jax.random.split(rng, 4) |
| 91 | + params_mean = model_mean.init({"params": rng_p1, "dropout": rng_d1}, x=x, c=c)[ |
| 92 | + "params" |
| 93 | + ] |
| 94 | + out_mean = model_mean.apply({"params": params_mean}, x, c, num_contexts) |
| 95 | + |
| 96 | + model_attn = _make_model(context_bonds, attention_pooling=True) |
| 97 | + params_attn = model_attn.init({"params": rng_p2, "dropout": rng_d2}, x=x, c=c)[ |
| 98 | + "params" |
| 99 | + ] |
| 100 | + out_attn = model_attn.apply({"params": params_attn}, x, c, num_contexts) |
| 101 | + |
| 102 | + assert out_mean.shape == out_attn.shape == (BATCH_SIZE, DIM_DATA) |
| 103 | + |
| 104 | + @pytest.mark.parametrize( |
| 105 | + "context_bonds,dim_cond,num_contexts", CONTEXT_BOND_CONFIGS |
| 106 | + ) |
| 107 | + def test_dropout_deterministic_vs_stochastic( |
| 108 | + self, context_bonds, dim_cond, num_contexts |
| 109 | + ): |
| 110 | + """Test that deterministic=False produces different outputs across runs |
| 111 | + while deterministic=True is consistent.""" |
| 112 | + model = _make_model(context_bonds, attention_pooling=True, dropout_rate=0.5) |
| 113 | + rng = jax.random.PRNGKey(7) |
| 114 | + x, c = _make_inputs(rng, dim_cond) |
| 115 | + |
| 116 | + rng_params, rng_dropout = jax.random.split(rng) |
| 117 | + params = model.init({"params": rng_params, "dropout": rng_dropout}, x=x, c=c)[ |
| 118 | + "params" |
| 119 | + ] |
| 120 | + |
| 121 | + # Deterministic mode: two calls should be identical |
| 122 | + out_eval_1 = model.apply( |
| 123 | + {"params": params}, x, c, num_contexts, deterministic=True |
| 124 | + ) |
| 125 | + out_eval_2 = model.apply( |
| 126 | + {"params": params}, x, c, num_contexts, deterministic=True |
| 127 | + ) |
| 128 | + assert jnp.allclose(out_eval_1, out_eval_2) |
| 129 | + |
| 130 | + # Stochastic mode: two calls with different dropout keys should differ |
| 131 | + key1, key2 = jax.random.split(jax.random.PRNGKey(99)) |
| 132 | + out_train_1 = model.apply( |
| 133 | + {"params": params}, |
| 134 | + x, |
| 135 | + c, |
| 136 | + num_contexts, |
| 137 | + deterministic=False, |
| 138 | + rngs={"dropout": key1}, |
| 139 | + ) |
| 140 | + out_train_2 = model.apply( |
| 141 | + {"params": params}, |
| 142 | + x, |
| 143 | + c, |
| 144 | + num_contexts, |
| 145 | + deterministic=False, |
| 146 | + rngs={"dropout": key2}, |
| 147 | + ) |
| 148 | + assert not jnp.allclose(out_train_1, out_train_2) |
0 commit comments