Commit 88cc71b

Updated Design.md and referred to other files for detailed docs
1 parent 9a9b1e3 commit 88cc71b

4 files changed: +42 -74 lines

docs/baseline.md

Lines changed: 7 additions & 0 deletions
@@ -23,3 +23,10 @@ One can observe that the smaller datasets (`Zinc12k` and `Tox21`) beneficiate fr
 | **Tox21** | GCN | 0.202 ± 0.005 | 0.773 ± 0.006 | 0.334 ± 0.03 | **0.176 ± 0.001** | **0.850 ± 0.006** | 0.446 ± 0.01 |
 | | GIN | 0.200 ± 0.002 | 0.789 ± 0.009 | 0.350 ± 0.01 | 0.176 ± 0.001 | 0.841 ± 0.005 | 0.454 ± 0.009 |
 | | GINE | 0.201 ± 0.007 | 0.783 ± 0.007 | 0.345 ± 0.02 | 0.177 ± 0.0008 | 0.836 ± 0.004 | **0.455 ± 0.008** |
+
+# LargeMix Baseline
+Coming soon!
+
+# UltraLarge Baseline
+Coming soon!
+

docs/contribute.md

Lines changed: 1 addition & 1 deletion
@@ -39,5 +39,5 @@ You can build and serve the documentation locally with:
 
 ```bash
 # Build and serve the doc
-make serve
+mkdocs serve
 ```

docs/design.md

Lines changed: 33 additions & 72 deletions
@@ -2,102 +2,63 @@
 
 ---
 
-### Diagram for data processing in molGPS.
-
-<img src="images/datamodule.png" alt="Data Processing Chart" width="100%" height="100%">
-
-
-
-### Diagram for Multi-task network in molGPS
-
-<img src="images/full_graph_network.png" alt="Full Graph Multi-task Network" width="100%" height="100%">
 
+The library is designed with three things in mind:
+- High modularity and configurability through *YAML* files
+- State-of-the-art GNNs, including positional encodings and graph Transformers
+- Massive multitasking across diverse and sparse datasets
 
+This page walks you through the aspects of the design that enable these goals.
 
+### Diagram for data processing in Graphium
 
+First, when working with molecules, there are many options for atomic and bond featurisation, which can be extracted from the periodic table, from empirical results, or from simulated 3D structures.
 
+Second, when working with graph Transformers, there are many options for the positional and structural encodings (PSE) that are fundamental in driving the accuracy and generalization of the models.
 
-**Section from the previous README:**
+With this in mind, we propose a versatile chemical and PSE encoding, alongside an encoder manager, that can be fully configured from the YAML files. The idea is to assign matching *input keys* to both the features and the encoders, then pool the outputs according to the *output keys*. This is summarized in the image below.
 
-### Data setup
-
-Then, you need to download the data needed to run the code. Right now, we have 2 sets of data folders, available [here](https://drive.google.com/drive/folders/1RrbNZkEE2rf41_iroa1LbIyegW00h3Ql?usp=sharing).
+<img src="images/datamodule.png" alt="Data Processing Chart" width="100%" height="100%">
 
-- **micro_ZINC** (Synthetic dataset)
-  - A small subset (1000 mols) of the ZINC dataset
-  - The score is the computed LogP minus the synthetic accessibility score SA
-  - The data must be downloaded to the folder `./graphium/data/micro_ZINC/`
 
-- **ZINC_bench_gnn** (Synthetic dataset)
-  - A subset (12000 mols) of the ZINC dataset
-  - The score is the computed LogP minus the synthetic accessibility score SA
-  - These are the same 12k molecules provided by the [Benchmarking-gnn](https://github.com/graphdeeplearning/benchmarking-gnns) repository.
-  - We provide the pre-processed graphs in `ZINC_bench_gnn/data_from_benchmark`
-  - We provide the SMILES in `ZINC_bench_gnn/smiles_score.csv`, with the train-val-test indexes in the file `indexes_train_val_test.csv`.
-    - The first 10k elements are the training set
-    - The next 1k are the validation set
-    - The last 1k are the test set
-  - The data must be downloaded to the folder `./graphium/data/ZINC_bench_gnn/`
 
-Then, you can run the main file to make sure that all the dependencies are correctly installed and that the code works as expected.
+### Diagram for Multi-task network in Graphium
 
-```bash
-python expts/main_micro_zinc.py
-```
+As mentioned, we want to perform massive multi-tasking so that we can work across a huge diversity of datasets. The idea is to use a combination of multiple task heads, where a different MLP is applied to each task's predictions. Each task can also have as many labels as desired, so labels can be grouped according to whether they should share weights and losses.
 
---
+The design is better explained in the image below. Notice how the *keys* output by GraphDict are used differently across the different GNN layers.
 
-**TODO: explain the internal design of Graphium so people can contribute to it more easily.**
+<img src="images/full_graph_network.png" alt="Full Graph Multi-task Network" width="100%" height="100%">
 
 ## Structure of the code
 
 The code is built to rapidly iterate on different architectures of neural networks (NN) and graph neural networks (GNN) with Pytorch. The main focus of this work is molecular tasks, and we use the package `rdkit` to transform molecular SMILES into graphs.
 
-### data_parser
-
-This folder contains tools that allow reading different kinds of molecular data files, such as `.csv` or `.xlsx` with SMILES data, or `.sdf` files with 3D data.
-
-
-### features
-
-Different utilities for molecules, such as a SMILES-to-adjacency-graph transformer, molecular property extraction, atomic properties, bond properties, ...
-
-**_The MolecularTransformer and AdjGraphTransformer come from ivbase, but I don't like them. I think we should replace them with something simpler and give more flexibility for combining one-hot embedding with physical-property embedding._**
-
-### trainer
-
-The trainer contains the interface to the `pytorch-lightning` library, with `PredictorModule` being the main class used for any NN model, either for regression or classification. It also contains some modifications to the logger from `pytorch-lightning` to enable more flexibility.
-
-### utils
-
-Any kind of utilities that can be used anywhere, including argument checkers and a configuration loader
-
-### visualization
-
-Plot visualization tools
-
-## Modifying the code
-
-### Adding a new GNN layer
-
-Any new GNN layer must inherit from the class `graphium.nn.base_graph_layer.BaseGraphLayer` and be implemented in the folder `graphium/nn/pyg_layers`, imported in the file `graphium/nn/architectures.py`, and, in the same file, added to the function `FeedForwardGraph._parse_gnn_layer`.
-
-To be used in the configuration file as a `graphium.model.layer_name`, it must also be implemented with some variable parameters in the file `expts/config_gnns.yaml`.
+Below is a list of the directories and their respective documentation:
 
-### Adding a new NN architecture
+- cli
+- [config](https://github.com/datamol-io/graphium/blob/main/graphium/config/README.md)
+- [data](https://github.com/datamol-io/graphium/blob/main/graphium/data/README.md)
+- [features](https://github.com/datamol-io/graphium/tree/main/graphium/features/README.md)
+- finetuning
+- [ipu](https://github.com/datamol-io/graphium/tree/main/graphium/ipu/README.md)
+- [nn](https://github.com/datamol-io/graphium/tree/main/graphium/nn/README.md)
+- [trainer](https://github.com/datamol-io/graphium/tree/main/graphium/trainer/README.md)
+- [utils](https://github.com/datamol-io/graphium/tree/main/graphium/utils/README.md)
+- [visualization](https://github.com/datamol-io/graphium/tree/main/graphium/visualization/README.md)
 
-All NN and GNN architectures compatible with the `pyg` library are provided in the file `graphium/nn/global_architectures.py`. When implementing a new architecture, it is highly recommended to inherit from `graphium.nn.architectures.FeedForwardNN` for regular neural networks, from `graphium.nn.global_architectures.FeedForwardGraph` for pyg neural networks, or from any of their sub-classes.
 
-### Changing the PredictorModule and loss function
+## Structure of the configs
 
-The `PredictorModule` is a general pytorch-lightning module that should work with any kind of `pytorch.nn.Module` or `pl.LightningModule`. The class defines a structure including models, loss functions, batch sizes, collate functions, metrics...
+Making the library very modular requires configuration files of more than 200 lines, which becomes intractable, especially when we only want minor changes between configurations.
 
-Some loss functions are already implemented in the PredictorModule, including `mse, bce, mae, cosine`, but some tasks will require more complex loss functions. One can add any new function in `graphium.trainer.predictor.PredictorModule._parse_loss_fun`.
+Hence, we use [hydra](https://hydra.cc/docs/intro/) to split the configurations into smaller, composable configuration files.
 
-### Changing the metrics used
+Examples of possibilities include:
 
-**_!WARNING! The metrics implementation was done for pytorch-lightning v0.8. There have been major changes to how the metrics are used and defined, so the whole implementation must change._**
+- Switching between accelerators (CPU, GPU and IPU)
+- Benchmarking different models on the same dataset
+- Fine-tuning a pre-trained model on a new dataset
 
-Our current code is compatible with the metrics defined by _pytorch-lightning_, which include a great set of metrics. We also added PearsonR and SpearmanR as they are important correlation metrics. You can define any new metric in the file `graphium/trainer/metrics.py`. The metric must inherit from `TensorMetric` and must be added to the dictionary `graphium.trainer.metrics.METRICS_DICT`.
+[In this document](https://github.com/datamol-io/graphium/tree/main/expts/hydra-configs#readme), we describe in detail how each of the above is achieved and how users can benefit from this design to get the most out of Graphium with as little configuration as possible.
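For readers unfamiliar with hydra, composition works roughly like this (a minimal sketch with hypothetical group and file names, not the actual layout of `expts/hydra-configs`): a top-level config pulls in one file per concern through the `defaults` list, and any entry can be overridden from the command line.

```yaml
# main.yaml: composes one small file per concern
defaults:
  - accelerator: gpu   # swap for cpu or ipu
  - model: gcn         # swap for gin, gine, ...
  - dataset: toymix
  - training: default
```

Switching accelerators then requires no file edits, e.g. `python main.py accelerator=ipu`.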
 
-To use the metric, you can easily add its name from `METRICS_DICT` in the yaml configuration file, at the address `metrics.metrics_dict`. Each metric has an underlying dictionary with a mandatory `threshold` key containing information on how to threshold the prediction/target before computing the metric. Any `kwargs` arguments of the metric must also be added.

graphium/nn/encoders/mlp_encoder.py

Lines changed: 1 addition & 1 deletion
@@ -166,7 +166,7 @@ def __init__(
     self,
     input_keys: List[str],
     output_keys: str,
-    in_dim: int or List[int],
+    in_dim: Union[int, List[int]],
     hidden_dim: int,
     out_dim: int,
     num_layers: int,
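The one-line fix above is worth a note: `int or List[int]` is not a type annotation but an ordinary boolean expression, which Python evaluates eagerly; since `int` is truthy, the annotation silently collapses to `int`. A minimal demonstration (standalone sketch, not Graphium code):

```python
from typing import List, Union

# `or` short-circuits: `int` is truthy, so `List[int]` is discarded entirely
bad = int or List[int]
print(bad)    # <class 'int'>

# Union keeps both alternatives, which is what type checkers expect
good = Union[int, List[int]]
print(good)   # typing.Union[int, typing.List[int]]
```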
