Skip to content

Commit 0b14293

Browse files
Feature/attention pooling (#33)
1 parent 4215af5 commit 0b14293

File tree

5 files changed

+214
-6
lines changed

5 files changed

+214
-6
lines changed

cmonge/models/nn.py

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -314,14 +314,21 @@ class ConditionalPerturbationNetwork(BasePotential):
314314
embed_cond_equal: bool = (
315315
False # Whether all context variables should be treated as set or not
316316
)
317+
attention_pooling: bool = False
318+
num_heads: int = 4
319+
dropout_rate: float = 0.1
317320
context_entity_bonds: Iterable[Tuple[int, int]] = (
318321
(0, 10),
319322
(0, 11),
320323
) # Start/stop index per modality
321324

322325
@nn.compact
323326
def __call__(
324-
self, x: jnp.ndarray, c: jnp.ndarray, num_contexts: int = 2
327+
self,
328+
x: jnp.ndarray,
329+
c: jnp.ndarray,
330+
num_contexts: int = 2,
331+
deterministic: bool = True,
325332
) -> jnp.ndarray: # noqa: D102
326333
"""
327334
Args:
@@ -379,8 +386,45 @@ def __call__(
379386
)
380387
layer = nn.Dense(dim_cond_map[0], use_bias=True)
381388
embeddings = [self.act_fn(layer(context)) for context in contexts]
382-
# Average along stacked dimension (alternatives like summing are possible)
383-
cond_embedding = jnp.mean(jnp.stack(embeddings), axis=0)
389+
390+
if self.attention_pooling:
391+
stacked_embeddings = jnp.stack(embeddings, axis=1) # (Batch, N, Dim)
392+
393+
# Input Dropout
394+
stacked_embeddings = nn.Dropout(
395+
rate=self.dropout_rate, deterministic=deterministic
396+
)(stacked_embeddings)
397+
398+
# Multi-Head Attention Scores
399+
att_layer = nn.Dense(
400+
self.num_heads, use_bias=True, name="AttentionScores"
401+
)
402+
scores = att_layer(stacked_embeddings) # (Batch, N, Heads)
403+
weights = jax.nn.softmax(scores, axis=1)
404+
405+
# Attention Weights Dropout
406+
weights = nn.Dropout(
407+
rate=self.dropout_rate, deterministic=deterministic
408+
)(weights)
409+
410+
# Weighted Pooling: (B, N, D), (B, N, H) -> (B, H, D)
411+
weighted_sum = jnp.einsum("bnd,bnh->bhd", stacked_embeddings, weights)
412+
413+
# Flatten and Project
414+
cond_embedding = weighted_sum.reshape(
415+
weighted_sum.shape[0], -1
416+
) # (B, H*D)
417+
cond_embedding = nn.Dense(
418+
dim_cond_map[0], use_bias=True, name="AttentionOutput"
419+
)(cond_embedding)
420+
421+
# Output Dropout
422+
cond_embedding = nn.Dropout(
423+
rate=self.dropout_rate, deterministic=deterministic
424+
)(cond_embedding)
425+
else:
426+
# Average along stacked dimension (alternatives like summing are possible)
427+
cond_embedding = jnp.mean(jnp.stack(embeddings), axis=0)
384428

385429
z = jnp.concatenate((x, cond_embedding), axis=1)
386430
if self.layer_norm:
@@ -403,7 +447,12 @@ def create_train_state(
403447
"""Create initial `TrainState`."""
404448
c = jnp.ones((1, self.dim_cond)) # (n_batch, embed_dim)
405449
x = jnp.ones((1, self.dim_data)) # (n_batch, data_dim)
406-
params = self.init(rng, x=x, c=c)["params"]
450+
451+
# Split rng for dropout keys during init
452+
rng, rng_dropout = jax.random.split(rng)
453+
init_rngs = {"params": rng, "dropout": rng_dropout}
454+
455+
params = self.init(init_rngs, x=x, c=c)["params"]
407456
return PotentialTrainState.create(
408457
apply_fn=self.apply,
409458
params=params,

cmonge/tests/models/__init__.py

Whitespace-only changes.

NOTE: the 148-line diff that follows belongs to a separate, newly added test file whose name was lost in extraction (presumably something like cmonge/tests/models/test_nn.py), not to the whitespace-only __init__.py above.

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import jax
2+
import jax.numpy as jnp
3+
import pytest
4+
5+
from cmonge.models.nn import ConditionalPerturbationNetwork
6+
7+
# Parametrization cases for the tests below.
# Each case is (context_bonds, dim_cond, num_contexts), where context_bonds
# holds one (start, stop) index pair per modality.
CONTEXT_BOND_CONFIGS = [
    pytest.param(((0, 10), (10, 20)), 20, 2, id="non_overlapping_2_modalities"),
    pytest.param(((0, 10), (0, 10)), 10, 2, id="overlapping_2_modalities"),
    pytest.param(
        ((0, 10), (10, 20), (20, 30)), 30, 3, id="non_overlapping_3_modalities"
    ),
]
28+
29+
# Shared fixture dimensions for this test module.
BATCH_SIZE = 4  # rows per generated batch (see _make_inputs)
DIM_DATA = 16  # width of the data input x
DIM_HIDDEN = [32, 32]  # forwarded to the network as dim_hidden
DIM_COND_MAP = (8,)  # forwarded to the network as dim_cond_map
33+
34+
35+
def _make_model(context_bonds, attention_pooling, dropout_rate=0.1):
    """Construct a ConditionalPerturbationNetwork for the given context bonds.

    The conditioning dimension is derived from the largest stop index across
    the per-modality (start, stop) bonds, so it always covers every modality.
    """
    cond_dim = max(stop for _, stop in context_bonds)
    return ConditionalPerturbationNetwork(
        dim_hidden=DIM_HIDDEN,
        dim_data=DIM_DATA,
        dim_cond=cond_dim,
        dim_cond_map=DIM_COND_MAP,
        embed_cond_equal=True,
        attention_pooling=attention_pooling,
        num_heads=4,
        dropout_rate=dropout_rate,
        context_entity_bonds=context_bonds,
    )
47+
48+
49+
def _make_inputs(rng, dim_cond):
    """Draw a random data batch ``x`` and condition batch ``c`` from *rng*."""
    key_data, key_cond = jax.random.split(rng)
    data = jax.random.normal(key_data, (BATCH_SIZE, DIM_DATA))
    cond = jax.random.normal(key_cond, (BATCH_SIZE, dim_cond))
    return data, cond
54+
55+
56+
class TestAttentionPooling:
    """Behavioral checks for attention pooling in ConditionalPerturbationNetwork."""

    @pytest.mark.parametrize(
        "context_bonds,dim_cond,num_contexts", CONTEXT_BOND_CONFIGS
    )
    def test_attention_pooling_forward_pass(
        self, context_bonds, dim_cond, num_contexts
    ):
        """Attention pooling yields a (batch, dim_data) output that is non-trivial."""
        model = _make_model(context_bonds, attention_pooling=True)
        root_key = jax.random.PRNGKey(0)
        x, c = _make_inputs(root_key, dim_cond)

        # Init needs both a "params" and a "dropout" stream.
        key_params, key_dropout = jax.random.split(root_key)
        variables = model.init(
            {"params": key_params, "dropout": key_dropout}, x=x, c=c
        )
        output = model.apply({"params": variables["params"]}, x, c, num_contexts)

        assert output.shape == (BATCH_SIZE, DIM_DATA)
        assert not jnp.allclose(output, 0.0)

    @pytest.mark.parametrize(
        "context_bonds,dim_cond,num_contexts", CONTEXT_BOND_CONFIGS
    )
    def test_both_pooling_modes_same_output_shape(
        self, context_bonds, dim_cond, num_contexts
    ):
        """Mean pooling and attention pooling agree on the output shape."""
        root_key = jax.random.PRNGKey(42)
        x, c = _make_inputs(root_key, dim_cond)
        key_p1, key_d1, key_p2, key_d2 = jax.random.split(root_key, 4)

        # Run the same batch through both pooling variants with their own keys.
        outputs = {}
        for use_attention, key_p, key_d in (
            (False, key_p1, key_d1),
            (True, key_p2, key_d2),
        ):
            model = _make_model(context_bonds, attention_pooling=use_attention)
            params = model.init(
                {"params": key_p, "dropout": key_d}, x=x, c=c
            )["params"]
            outputs[use_attention] = model.apply(
                {"params": params}, x, c, num_contexts
            )

        assert (
            outputs[False].shape == outputs[True].shape == (BATCH_SIZE, DIM_DATA)
        )

    @pytest.mark.parametrize(
        "context_bonds,dim_cond,num_contexts", CONTEXT_BOND_CONFIGS
    )
    def test_dropout_deterministic_vs_stochastic(
        self, context_bonds, dim_cond, num_contexts
    ):
        """deterministic=True is reproducible; deterministic=False varies by key."""
        # High dropout rate so stochastic runs are very unlikely to coincide.
        model = _make_model(context_bonds, attention_pooling=True, dropout_rate=0.5)
        root_key = jax.random.PRNGKey(7)
        x, c = _make_inputs(root_key, dim_cond)

        key_params, key_dropout = jax.random.split(root_key)
        params = model.init(
            {"params": key_params, "dropout": key_dropout}, x=x, c=c
        )["params"]

        def run(deterministic, dropout_key=None):
            # A dropout rng is only supplied for stochastic passes.
            extra = (
                {} if dropout_key is None else {"rngs": {"dropout": dropout_key}}
            )
            return model.apply(
                {"params": params},
                x,
                c,
                num_contexts,
                deterministic=deterministic,
                **extra,
            )

        # Two deterministic passes must match exactly.
        assert jnp.allclose(run(True), run(True))

        # Two stochastic passes with distinct dropout keys must differ.
        key_a, key_b = jax.random.split(jax.random.PRNGKey(99))
        assert not jnp.allclose(run(False, key_a), run(False, key_b))

cmonge/trainers/conditional_monge_trainer.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,13 @@ def train(self, datamodule: ConditionalDataModule):
149149
else self.generate_batch(datamodule, "valid")
150150
)
151151

152+
self.key, step_key = jax.random.split(self.key)
153+
152154
self.state_neural_net, grads, current_logs = self.step_fn(
153155
self.state_neural_net,
154156
grads=grads,
155157
train_batch=train_batch,
158+
dropout_key=step_key,
156159
valid_batch=valid_batch,
157160
is_logging_step=is_logging_step,
158161
is_gradient_acc_step=is_gradient_acc_step,
@@ -176,14 +179,20 @@ def loss_fn(
176179
apply_fn: Callable,
177180
batch: Dict[str, jnp.ndarray],
178181
n_contexts: int,
182+
dropout_key: Optional[jnp.ndarray] = None,
179183
) -> Tuple[float, Dict[str, float]]:
180184
"""Loss function."""
181185
# map samples with the fitted map
186+
kwargs = {}
187+
if dropout_key is not None:
188+
kwargs = {"deterministic": False, "rngs": {"dropout": dropout_key}}
189+
182190
mapped_samples = apply_fn(
183191
{"params": params},
184192
batch["source"],
185193
batch["condition"],
186194
n_contexts,
195+
**kwargs,
187196
)
188197

189198
# compute the loss
@@ -200,11 +209,12 @@ def loss_fn(
200209

201210
return val_tot_loss, loss_logs
202211

203-
@functools.partial(jax.jit, static_argnums=[4, 5, 6, 7])
212+
@functools.partial(jax.jit, static_argnums=[5, 6, 7, 8])
204213
def step_fn(
205214
state_neural_net: train_state.TrainState,
206215
grads: frozen_dict.FrozenDict,
207216
train_batch: Dict[str, jnp.ndarray],
217+
dropout_key: jnp.ndarray,
208218
valid_batch: Optional[Dict[str, jnp.ndarray]] = None,
209219
is_logging_step: bool = False,
210220
is_gradient_acc_step: bool = False,
@@ -219,6 +229,7 @@ def step_fn(
219229
state_neural_net.apply_fn,
220230
train_batch,
221231
n_train_contexts,
232+
dropout_key,
222233
)
223234
# Accumulate gradients
224235
grads = tree_map(lambda g, step_g: g + step_g, grads, step_grads)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "cmonge"
3-
version = "0.1.2"
3+
version = "0.1.3"
44
description = "Extension of the Monge Gap to learn conditional optimal transport maps"
55
authors = ["Alice Driessen <adr@zurich.ibm.com>", "Benedek Harsanyi <hben.0204@gmail.com>", "Jannis Born <jab@zurich.ibm.com>"]
66
readme = "README.md"

0 commit comments

Comments
 (0)