
Commit f698df4

Merge pull request #494 from datamol-io/fix_mup_attn
Fix mup for the layers with AttentionLayerMup
2 parents: 92281ba + 4045fcf

File tree: 2 files changed, +29 −11 lines

graphium/nn/architectures/global_architectures.py

Lines changed: 6 additions & 0 deletions
@@ -1338,6 +1338,12 @@ def _recursive_divide_dim(x: collections.abc.Mapping):
                     _recursive_divide_dim(v)
                 elif k in ["in_dim", "out_dim", "in_dim_edges", "out_dim_edges"]:
                     x[k] = round(v / divide_factor)
+                elif k in ["embed_dim"]:
+                    num_heads = x.get("num_heads", 1)
+                    x[k] = round(v / divide_factor)
+                    assert (
+                        x[k] % num_heads == 0
+                    ), f"embed_dim={x[k]} is not divisible by num_heads={num_heads}"

         _recursive_divide_dim(kwargs["layer_kwargs"])
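
For context, here is a minimal standalone sketch of the divide logic this hunk touches: when muP scales the model down by `divide_factor`, `embed_dim` must remain divisible by `num_heads`, otherwise the attention layer cannot split the embedding across heads. The function name, example dict, and `divide_factor` default below are illustrative assumptions, not the library code.

    import collections.abc

    # Sketch (not the library code): recursively scale dimension-like keys
    # in a nested kwargs dict, mirroring the muP logic patched above.
    def recursive_divide_dim(x, divide_factor: float = 2.0):
        for k, v in x.items():
            if isinstance(v, collections.abc.Mapping):
                recursive_divide_dim(v, divide_factor)
            elif k in ("in_dim", "out_dim", "in_dim_edges", "out_dim_edges"):
                x[k] = round(v / divide_factor)
            elif k == "embed_dim":
                # embed_dim feeds multi-head attention, so the scaled value
                # must still split evenly across the heads.
                num_heads = x.get("num_heads", 1)
                x[k] = round(v / divide_factor)
                assert x[k] % num_heads == 0, (
                    f"embed_dim={x[k]} is not divisible by num_heads={num_heads}"
                )

    layer_kwargs = {"in_dim": 128, "attn_kwargs": {"embed_dim": 128, "num_heads": 4}}
    recursive_divide_dim(layer_kwargs)
    print(layer_kwargs)  # {'in_dim': 64, 'attn_kwargs': {'embed_dim': 64, 'num_heads': 4}}

The point of asserting early is that a value such as embed_dim=100 with num_heads=8 would otherwise only fail later, inside the attention module, with a less helpful error.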

graphium/nn/pyg_layers/gps_pyg.py

Lines changed: 23 additions & 11 deletions
@@ -68,6 +68,7 @@ def __init__(
        precision: str = "32",
        biased_attention_key: Optional[str] = None,
        attn_kwargs=None,
+       force_consistent_in_dim: bool = True,
        droppath_rate_attn: float = 0.0,
        droppath_rate_ffn: float = 0.0,
        hidden_dim_scaling: float = 4.0,
@@ -93,12 +94,6 @@ def __init__(
            out_dim:
                Output node feature dimensions of the layer

-           in_dim:
-               Input edge feature dimensions of the layer
-
-           out_dim:
-               Output edge feature dimensions of the layer
-
            in_dim_edges:
                input edge-feature dimensions of the layer

@@ -134,6 +129,11 @@ def __init__(
            attn_kwargs:
                kwargs for attention layer

+           force_consistent_in_dim:
+               whether to force the `embed_dim` to be the same as the `in_dim` for the attention and mpnn.
+               The argument is only valid if `attn_type` is not None. If `embed_dim` is not provided,
+               it will be set to `in_dim` by default, so this parameter won't have an effect.
+
            droppath_rate_attn:
                stochastic depth drop rate for attention layer https://arxiv.org/abs/1603.09382

@@ -208,7 +208,9 @@ def __init__(
        self.biased_attention_key = biased_attention_key
        # Initialize the MPNN and Attention layers
        self.mpnn = self._parse_mpnn_layer(mpnn_type, mpnn_kwargs)
-       self.attn_layer = self._parse_attn_layer(attn_type, self.biased_attention_key, attn_kwargs)
+       self.attn_layer = self._parse_attn_layer(
+           attn_type, self.biased_attention_key, attn_kwargs, force_consistent_in_dim=force_consistent_in_dim
+       )

        self.output_scale = output_scale
        self.use_edges = True if self.in_dim_edges is not None else False
@@ -251,8 +253,6 @@ def forward(self, batch: Batch) -> Batch:
        """
        # pe, feat, edge_index, edge_feat = batch.pos_enc_feats_sign_flip, batch.feat, batch.edge_index, batch.edge_feat
        feat = batch.feat
-       if self.use_edges:
-           edges_feat_in = batch.edge_feat

        feat_in = feat  # for first residual connection

@@ -323,26 +323,38 @@ def _parse_mpnn_layer(self, mpnn_type, mpnn_kwargs: Dict[str, Any]) -> Optional[
        return mpnn_layer

    def _parse_attn_layer(
-       self, attn_type, biased_attention_key: str, attn_kwargs: Dict[str, Any]
+       self,
+       attn_type,
+       biased_attention_key: str,
+       attn_kwargs: Dict[str, Any],
+       force_consistent_in_dim: bool = True,
    ) -> Optional[Module]:
        """
        parse the input attention layer and check if it is valid
        Parameters:
            attn_type: type of the attention layer
            biased_attention_key: key for the attenion bias
+           attn_kwargs: kwargs for the attention layer
+           force_consistent_in_dim: whether to force the `embed_dim` to be the same as the `in_dim`
+
        Returns:
            attn_layer: the attention layer
        """

        # Set the default values for the Attention layer
        if attn_kwargs is None:
            attn_kwargs = {}
-       attn_kwargs.setdefault("embed_dim", self.in_dim)
        attn_kwargs.setdefault("num_heads", 1)
        attn_kwargs.setdefault("dropout", self.dropout)
        attn_kwargs.setdefault("batch_first", True)
        self.attn_kwargs = attn_kwargs

+       # Force the `embed_dim` to be the same as the `in_dim`
+       attn_kwargs.setdefault("embed_dim", self.in_dim)
+       if force_consistent_in_dim:
+           embed_dim = attn_kwargs["embed_dim"]
+           assert embed_dim == self.in_dim, f"embed_dim={embed_dim} must be equal to in_dim={self.in_dim}"
+
        # Initialize the Attention layer
        attn_layer, attn_class = None, None
        if attn_type is not None:
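
To summarize the behavioural change, here is a minimal sketch of the kwargs resolution performed in `_parse_attn_layer`: `embed_dim` now defaults to the node `in_dim`, and when `force_consistent_in_dim` is enabled an explicitly supplied `embed_dim` must match `in_dim`. The standalone helper name `resolve_attn_kwargs` below is hypothetical, used only for illustration outside the class.

    from typing import Any, Dict, Optional

    # Sketch (not the library code) of the consistency check added above.
    def resolve_attn_kwargs(
        in_dim: int,
        attn_kwargs: Optional[Dict[str, Any]] = None,
        force_consistent_in_dim: bool = True,
        dropout: float = 0.0,
    ) -> Dict[str, Any]:
        if attn_kwargs is None:
            attn_kwargs = {}
        attn_kwargs.setdefault("num_heads", 1)
        attn_kwargs.setdefault("dropout", dropout)
        attn_kwargs.setdefault("batch_first", True)

        # Default embed_dim to in_dim; if the caller passed a different value,
        # the optional check rejects the mismatch instead of letting it fail
        # later inside the attention module.
        attn_kwargs.setdefault("embed_dim", in_dim)
        if force_consistent_in_dim:
            embed_dim = attn_kwargs["embed_dim"]
            assert embed_dim == in_dim, f"embed_dim={embed_dim} must be equal to in_dim={in_dim}"
        return attn_kwargs

    print(resolve_attn_kwargs(in_dim=64))  # embed_dim defaults to 64
    print(resolve_attn_kwargs(in_dim=64, attn_kwargs={"embed_dim": 64, "num_heads": 4}))
    # resolve_attn_kwargs(in_dim=64, attn_kwargs={"embed_dim": 32})  # would raise AssertionError

Setting the default after the other `setdefault` calls (rather than before, as in the removed line) keeps the check next to the place where `embed_dim` is decided, which is what the hunk above does.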
