1212from torch import Tensor , nn
1313import torch
1414from torch_geometric .data import Data
15+ from omegaconf import DictConfig , OmegaConf
1516
1617# graphium imports
1718from graphium .data .utils import get_keys
@@ -592,6 +593,26 @@ def _check_bad_arguments(self):
592593 (self .in_dim_edges > 0 ) or (self .full_dims_edges is not None )
593594 ) and not self .layer_class .layer_supports_edges :
594595 raise ValueError (f"Cannot use edge features with class `{ self .layer_class } `" )
596+
597+ def get_nested_key (self , d , target_key ):
598+ """
599+ Get the value associated with a key in a nested dictionary.
600+
601+ Parameters:
602+ - d: The dictionary to search in
603+ - target_key: The key to search for
604+
605+ Returns:
606+ - The value associated with the key if found, None otherwise
607+ """
608+ if target_key in d :
609+ return d [target_key ]
610+ for key , value in d .items ():
611+ if isinstance (value , (dict , DictConfig )):
612+ nested_result = self .get_nested_key (value , target_key )
613+ if nested_result is not None :
614+ return nested_result
615+ return None
595616
596617 def _create_layers (self ):
597618 r"""
@@ -632,15 +653,18 @@ def _create_layers(self):
632653
633654 # Find the edge key-word arguments depending on the layer type and residual connection
634655 this_edge_kwargs = {}
656+ # import ipdb; ipdb.set_trace()
635657 if self .layer_class .layer_supports_edges and self .in_dim_edges > 0 :
636658 this_edge_kwargs ["in_dim_edges" ] = this_in_dim_edges
637659 if "out_dim_edges" in inspect .signature (self .layer_class .__init__ ).parameters .keys ():
638660 if self .full_dims_edges is not None :
639661 this_out_dim_edges = self .full_dims_edges [ii + 1 ]
640662 this_edge_kwargs ["out_dim_edges" ] = this_out_dim_edges
641663 else :
642- this_out_dim_edges = self .layer_kwargs .get ("out_dim_edges" )
664+ this_out_dim_edges = self .get_nested_key (self .layer_kwargs , "out_dim_edges" )
665+ this_edge_kwargs ["out_dim_edges" ] = this_out_dim_edges
643666 layer_out_dims_edges .append (this_out_dim_edges )
667+ # import ipdb; ipdb.set_trace()
644668
645669 # Create the GNN layer
646670 self .layers .append (
@@ -659,6 +683,7 @@ def _create_layers(self):
659683
660684 # Create the Virtual Node layer, except at the last layer
661685 if ii < len (residual_out_dims ):
686+ # import ipdb; ipdb.set_trace()
662687 self .virtual_node_layers .append (
663688 self .virtual_node_class (
664689 in_dim = this_out_dim * self .layers [- 1 ].out_dim_factor ,
0 commit comments