Skip to content

Commit 20434b6

Browse files
Fix to allow use edges for MPNN layer
1 parent 4db06f0 commit 20434b6

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

expts/hydra-configs/model/mpnn.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,5 @@ architecture:
2222
attn_type: "none" # "full-attention", "none"
2323
# biased_attention: false
2424
attn_kwargs: null
25+
virtual_node: 'sum'
26+
use_virtual_edges: true

graphium/nn/architectures/global_architectures.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from torch import Tensor, nn
1313
import torch
1414
from torch_geometric.data import Data
15+
from omegaconf import DictConfig, OmegaConf
1516

1617
# graphium imports
1718
from graphium.data.utils import get_keys
@@ -592,6 +593,26 @@ def _check_bad_arguments(self):
592593
(self.in_dim_edges > 0) or (self.full_dims_edges is not None)
593594
) and not self.layer_class.layer_supports_edges:
594595
raise ValueError(f"Cannot use edge features with class `{self.layer_class}`")
596+
597+
def get_nested_key(self, d, target_key):
598+
"""
599+
Get the value associated with a key in a nested dictionary.
600+
601+
Parameters:
602+
- d: The dictionary to search in
603+
- target_key: The key to search for
604+
605+
Returns:
606+
- The value associated with the key if found, None otherwise
607+
"""
608+
if target_key in d:
609+
return d[target_key]
610+
for key, value in d.items():
611+
if isinstance(value, (dict, DictConfig)):
612+
nested_result = self.get_nested_key(value, target_key)
613+
if nested_result is not None:
614+
return nested_result
615+
return None
595616

596617
def _create_layers(self):
597618
r"""
@@ -632,15 +653,18 @@ def _create_layers(self):
632653

633654
# Find the edge key-word arguments depending on the layer type and residual connection
634655
this_edge_kwargs = {}
656+
# import ipdb; ipdb.set_trace()
635657
if self.layer_class.layer_supports_edges and self.in_dim_edges > 0:
636658
this_edge_kwargs["in_dim_edges"] = this_in_dim_edges
637659
if "out_dim_edges" in inspect.signature(self.layer_class.__init__).parameters.keys():
638660
if self.full_dims_edges is not None:
639661
this_out_dim_edges = self.full_dims_edges[ii + 1]
640662
this_edge_kwargs["out_dim_edges"] = this_out_dim_edges
641663
else:
642-
this_out_dim_edges = self.layer_kwargs.get("out_dim_edges")
664+
this_out_dim_edges = self.get_nested_key(self.layer_kwargs, "out_dim_edges")
665+
this_edge_kwargs["out_dim_edges"] = this_out_dim_edges
643666
layer_out_dims_edges.append(this_out_dim_edges)
667+
# import ipdb; ipdb.set_trace()
644668

645669
# Create the GNN layer
646670
self.layers.append(
@@ -659,6 +683,7 @@ def _create_layers(self):
659683

660684
# Create the Virtual Node layer, except at the last layer
661685
if ii < len(residual_out_dims):
686+
# import ipdb; ipdb.set_trace()
662687
self.virtual_node_layers.append(
663688
self.virtual_node_class(
664689
in_dim=this_out_dim * self.layers[-1].out_dim_factor,

0 commit comments

Comments
 (0)