Skip to content

Commit 41a1172

Browse files
Merge pull request #457 from datamol-io/pipeline_integration
Pipeline integration + Virtual Nodes Edges bug fix
2 parents 8be0f9f + daf011c commit 41a1172

File tree

5 files changed

+53
-1
lines changed

5 files changed

+53
-1
lines changed
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
type: ipu
2+
ipu_config:
3+
- deviceIterations(60) # IPU would require large batches to be ready for the model.
4+
# 60 for PCQM4mv2
5+
# 30 for largemix
6+
- replicationFactor(4)
7+
# - enableProfiling("graph_analyser") # The folder where the profile will be stored
8+
# - enableExecutableCaching("pop_compiler_cache")
9+
- TensorLocations.numIOTiles(128)
10+
- _Popart.set("defaultBufferingDepth", 96)
11+
- Precision.enableStochasticRounding(True)
12+
13+
ipu_inference_config:
14+
# set device iteration and replication factor to 1 during inference
15+
# gradient accumulation was set to 1 in the code
16+
- deviceIterations(60)
17+
- replicationFactor(1)
18+
- Precision.enableStochasticRounding(False)
19+
20+
accelerator_kwargs:
21+
_accelerator: "ipu"
22+
gnn_layers_per_ipu: [4, 4, 4, 4]

expts/hydra-configs/model/mpnn.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,5 @@ architecture:
2222
attn_type: "none" # "full-attention", "none"
2323
# biased_attention: false
2424
attn_kwargs: null
25+
virtual_node: 'sum'
26+
use_virtual_edges: true

graphium/config/_loader.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,10 @@ def load_architecture(
260260
graph_output_nn_kwargs=graph_output_nn_kwargs,
261261
task_heads_kwargs=task_heads_kwargs,
262262
)
263+
# Get accelerator_kwargs if they exist
264+
accelerator_kwargs = config["accelerator"].get("accelerator_kwargs", None)
265+
if accelerator_kwargs is not None:
266+
model_kwargs["accelerator_kwargs"] = accelerator_kwargs
263267

264268
if model_class is FullGraphFinetuningNetwork:
265269
finetuning_head_kwargs = config["finetuning"].pop("finetuning_head", None)

graphium/config/zinc_default_multitask_pyg.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,5 @@ architecture: # The parameters for the full graph network are taken from `co
181181
dropout: 0.2
182182
normalization: none
183183
residual_type: none
184+
accelerator:
185+
type: cpu

graphium/nn/architectures/global_architectures.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from torch import Tensor, nn
1313
import torch
1414
from torch_geometric.data import Data
15+
from omegaconf import DictConfig, OmegaConf
1516

1617
# graphium imports
1718
from graphium.data.utils import get_keys
@@ -593,6 +594,26 @@ def _check_bad_arguments(self):
593594
) and not self.layer_class.layer_supports_edges:
594595
raise ValueError(f"Cannot use edge features with class `{self.layer_class}`")
595596

def get_nested_key(self, d, target_key):
    """
    Look up ``target_key`` anywhere in a nested mapping.

    Checks the top level first, then descends depth-first into any
    value that is itself a ``dict`` or ``DictConfig``.

    Parameters:
        d: The mapping to search in.
        target_key: The key to search for.

    Returns:
        The value associated with the first occurrence of the key,
        or None when the key is absent at every depth.
    """
    if target_key in d:
        return d[target_key]
    for child in d.values():
        if isinstance(child, (dict, DictConfig)):
            found = self.get_nested_key(child, target_key)
            if found is not None:
                return found
    return None
596617
def _create_layers(self):
597618
r"""
598619
Create all the necessary layers for the network.
@@ -639,7 +660,8 @@ def _create_layers(self):
639660
this_out_dim_edges = self.full_dims_edges[ii + 1]
640661
this_edge_kwargs["out_dim_edges"] = this_out_dim_edges
641662
else:
642-
this_out_dim_edges = self.layer_kwargs.get("out_dim_edges")
663+
this_out_dim_edges = self.get_nested_key(self.layer_kwargs, "out_dim_edges")
664+
this_edge_kwargs["out_dim_edges"] = this_out_dim_edges
643665
layer_out_dims_edges.append(this_out_dim_edges)
644666

645667
# Create the GNN layer

0 commit comments

Comments
 (0)