@@ -68,6 +68,7 @@ def __init__(
         precision: str = "32",
         biased_attention_key: Optional[str] = None,
         attn_kwargs=None,
+        force_consistent_in_dim: bool = True,
         droppath_rate_attn: float = 0.0,
         droppath_rate_ffn: float = 0.0,
         hidden_dim_scaling: float = 4.0,
@@ -93,12 +94,6 @@ def __init__(
             out_dim:
                 Output node feature dimensions of the layer
 
-            in_dim:
-                Input edge feature dimensions of the layer
-
-            out_dim:
-                Output edge feature dimensions of the layer
-
             in_dim_edges:
                 input edge-feature dimensions of the layer
 
@@ -134,6 +129,11 @@ def __init__(
             attn_kwargs:
                 kwargs for attention layer
 
+            force_consistent_in_dim:
+                Whether to force the `embed_dim` to be the same as the `in_dim` for both the attention layer and the MPNN.
+                The argument is only valid if `attn_type` is not None. If `embed_dim` is not provided,
+                it defaults to `in_dim`, so this parameter has no effect in that case.
+
             droppath_rate_attn:
                 stochastic depth drop rate for attention layer https://arxiv.org/abs/1603.09382
 
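A minimal, self-contained sketch of the rule documented above, written outside the class. The helper name `resolve_embed_dim` is hypothetical and not part of this diff; it only mirrors the `setdefault`/`assert` logic added to `_parse_attn_layer` further down.

```python
from typing import Any, Dict, Optional


def resolve_embed_dim(
    in_dim: int,
    attn_kwargs: Optional[Dict[str, Any]] = None,
    force_consistent_in_dim: bool = True,
) -> int:
    """Illustrative only: embed_dim defaults to in_dim, and when the flag is
    set, a user-supplied embed_dim must match in_dim."""
    attn_kwargs = dict(attn_kwargs or {})
    attn_kwargs.setdefault("embed_dim", in_dim)
    embed_dim = attn_kwargs["embed_dim"]
    if force_consistent_in_dim:
        assert embed_dim == in_dim, f"embed_dim={embed_dim} must be equal to in_dim={in_dim}"
    return embed_dim


assert resolve_embed_dim(64) == 64  # embed_dim omitted: defaults to in_dim, the flag is a no-op
assert resolve_embed_dim(64, {"embed_dim": 64}) == 64  # explicit but consistent: passes
assert resolve_embed_dim(64, {"embed_dim": 128}, force_consistent_in_dim=False) == 128  # opt out of the check
# resolve_embed_dim(64, {"embed_dim": 128})  # with the default flag this raises AssertionError
```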
@@ -208,7 +208,9 @@ def __init__(
         self.biased_attention_key = biased_attention_key
         # Initialize the MPNN and Attention layers
         self.mpnn = self._parse_mpnn_layer(mpnn_type, mpnn_kwargs)
-        self.attn_layer = self._parse_attn_layer(attn_type, self.biased_attention_key, attn_kwargs)
+        self.attn_layer = self._parse_attn_layer(
+            attn_type, self.biased_attention_key, attn_kwargs, force_consistent_in_dim=force_consistent_in_dim
+        )
 
         self.output_scale = output_scale
         self.use_edges = True if self.in_dim_edges is not None else False
@@ -251,8 +253,6 @@ def forward(self, batch: Batch) -> Batch:
         """
         # pe, feat, edge_index, edge_feat = batch.pos_enc_feats_sign_flip, batch.feat, batch.edge_index, batch.edge_feat
         feat = batch.feat
-        if self.use_edges:
-            edges_feat_in = batch.edge_feat
 
         feat_in = feat  # for first residual connection
 
@@ -323,26 +323,38 @@ def _parse_mpnn_layer(self, mpnn_type, mpnn_kwargs: Dict[str, Any]) -> Optional[
         return mpnn_layer
 
     def _parse_attn_layer(
-        self, attn_type, biased_attention_key: str, attn_kwargs: Dict[str, Any]
+        self,
+        attn_type,
+        biased_attention_key: str,
+        attn_kwargs: Dict[str, Any],
+        force_consistent_in_dim: bool = True,
     ) -> Optional[Module]:
         """
         parse the input attention layer and check if it is valid
         Parameters:
             attn_type: type of the attention layer
             biased_attention_key: key for the attention bias
+            attn_kwargs: kwargs for the attention layer
+            force_consistent_in_dim: whether to force the `embed_dim` to be the same as the `in_dim`
+
         Returns:
             attn_layer: the attention layer
         """
 
         # Set the default values for the Attention layer
         if attn_kwargs is None:
             attn_kwargs = {}
-        attn_kwargs.setdefault("embed_dim", self.in_dim)
         attn_kwargs.setdefault("num_heads", 1)
         attn_kwargs.setdefault("dropout", self.dropout)
         attn_kwargs.setdefault("batch_first", True)
         self.attn_kwargs = attn_kwargs
 
+        # Force the `embed_dim` to be the same as the `in_dim`
+        attn_kwargs.setdefault("embed_dim", self.in_dim)
+        if force_consistent_in_dim:
+            embed_dim = attn_kwargs["embed_dim"]
+            assert embed_dim == self.in_dim, f"embed_dim={embed_dim} must be equal to in_dim={self.in_dim}"
+
         # Initialize the Attention layer
         attn_layer, attn_class = None, None
         if attn_type is not None:
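As a usage note, forcing `embed_dim == in_dim` keeps the attention output the same width as the node features, so the residual additions in `forward` stay shape-compatible. The sketch below assumes the attention class resolves to `torch.nn.MultiheadAttention`; the diff cuts off before the class is actually selected, so treat that choice (and the tensor shapes) as illustrative.

```python
import torch
from torch.nn import MultiheadAttention

# Illustrative check mirroring the hunk above: fill in the default kwargs,
# force embed_dim to equal in_dim, then build the attention layer.
in_dim = 64
attn_kwargs = {"num_heads": 4, "dropout": 0.1, "batch_first": True}
attn_kwargs.setdefault("embed_dim", in_dim)
assert attn_kwargs["embed_dim"] == in_dim

attn_layer = MultiheadAttention(**attn_kwargs)
feat = torch.randn(8, 30, in_dim)    # (batch, nodes, in_dim) since batch_first=True
attn_out, _ = attn_layer(feat, feat, feat)
assert attn_out.shape == feat.shape  # residual `feat_in + attn_out` remains valid
```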