Skip to content

Migrating MPNN tutorial to Keras3#2289

Open
LakshmiKalaKadali wants to merge 3 commits intokeras-team:masterfrom
LakshmiKalaKadali:mpnn_br
Open

Migrating MPNN tutorial to Keras3#2289
LakshmiKalaKadali wants to merge 3 commits intokeras-team:masterfrom
LakshmiKalaKadali:mpnn_br

Conversation

@LakshmiKalaKadali
Copy link
Contributor

This PR migrates the Message Passing Neural Network (MPNN) molecular property prediction tutorial from Keras 2 to Keras 3. Replaced the legacy PartitionPadding and Transformer readout with a GatedReadout layer. This provides a more stable attention mechanism for small datasets like BBBP. gist

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @LakshmiKalaKadali, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly updates the Message Passing Neural Network (MPNN) molecular property prediction tutorial to be fully compatible with Keras 3. The core changes involve adapting the graph representation to be backend-agnostic using fixed-size buffers and masks, and overhauling the data loading process with keras.utils.PyDataset. A key architectural improvement is the replacement of the previous readout mechanism with a more stable GatedReadout layer. Additionally, training stability and performance are enhanced through the adoption of AdamW with a CosineDecay learning rate schedule and the introduction of pre-featurization to reduce CPU bottlenecks.

Highlights

  • Keras 3 Migration: The tutorial has been fully migrated to Keras 3, ensuring compatibility with JAX, PyTorch, and TensorFlow backends. This involved updating imports, API calls, and data handling mechanisms.
  • Backend-Agnostic Graph Representation: The graph generation process (smiles_to_graph) was refactored to use fixed-size NumPy arrays and masks instead of TensorFlow's RaggedTensor. This ensures static shapes required for JAX and PyTorch compatibility and introduces pre-featurization for improved training efficiency.
  • Updated Readout Mechanism: The legacy PartitionPadding and TransformerEncoderReadout layers have been replaced with a GatedReadout layer. This new readout combines mean and max pooling with a gating mechanism, providing a more stable attention mechanism, especially for smaller datasets like BBBP.
  • Optimized Data Loading: The data loading pipeline was re-implemented using keras.utils.PyDataset, replacing the tf.data.Dataset approach. This custom PyDataset handles batching, padding, and the creation of molecule_indicator and mask tensors for the new graph representation.
  • Enhanced Model Training: The model now utilizes the AdamW optimizer with a CosineDecay learning rate schedule, along with weight_decay and global_clipnorm for improved training stability and performance. LayerNormalization was also added within the MessagePassing layer.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • examples/graph/mpnn-molecular-graphs.py
    • Added Keras 3 conversion attribution to the file header.
    • Updated rdkit and graphviz installation commands in the tutorial text.
    • Modified imports to use keras and keras.ops instead of tensorflow and tensorflow.keras.
    • Removed TF_CPP_MIN_LOG_LEVEL environment variable and IPythonConsole import.
    • Introduced KERAS_BACKEND environment variable for Keras 3 backend selection.
    • Added global configuration constants for MAX_ATOMS, MAX_BONDS, BATCH_SIZE, EPOCHS, and LEARNING_RATE.
    • Replaced np.random.seed and tf.random.set_seed with keras.utils.set_random_seed.
    • Changed np.zeros dtype to float32 in Featurizer.encode and BondFeaturizer.encode.
    • Removed redundant __init__ method from AtomFeaturizer.
    • Updated atom_featurizer and bond_featurizer instantiation to remove allowable_sets keyword.
    • Refactored graph generation: Replaced graph_from_molecule and graphs_from_smiles with a new smiles_to_graph function.
    • Implemented fixed-size buffers and a mask in smiles_to_graph for Keras 3 backend compatibility.
    • Introduced a pre-featurization step for the entire dataset to improve training efficiency.
    • Replaced tf.data.Dataset and prepare_batch with a custom MPNNDataset class inheriting from keras.utils.PyDataset.
    • Updated dataset splitting logic to use the new MPNNDataset.
    • Modified EdgeNetwork to use keras.ops for tensor operations and adjusted num_segments for batch processing.
    • Updated MessagePassing layer to use keras.ops, LayerNormalization, and simplified padding logic.
    • Replaced PartitionPadding and TransformerEncoderReadout layers with a new GatedReadout layer.
    • Updated MPNNModel definition to accept mask input and incorporate the GatedReadout layer.
    • Added kernel_regularizer and Dropout to the final dense layers in MPNNModel.
    • Configured the model compilation with AdamW optimizer, CosineDecay learning rate schedule, weight_decay, and global_clipnorm.
    • Adjusted training epochs to use the global EPOCHS constant and verbose level.
    • Modified plotting code to save the AUC plot to a file.
    • Updated prediction logic to use ops.convert_to_numpy for y_pred.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request does an excellent job of migrating the Message Passing Neural Network (MPNN) tutorial from Keras 2 to Keras 3, making it backend-agnostic. The changes significantly improve the code quality by introducing a more stable GatedReadout layer, optimizing the data pipeline with pre-featurization, and adopting modern Keras 3 practices like keras.ops and PyDataset. The model definition and training configuration have also been enhanced. I've identified a couple of areas for improvement: some redundant code has been introduced during the refactoring, and there's a minor formatting issue in a docstring. Overall, this is a high-quality contribution.

@LakshmiKalaKadali LakshmiKalaKadali changed the title Migratimg MPNN tutorial to Keras3 Migrating MPNN tutorial to Keras3 Feb 17, 2026
Copy link
Collaborator

@sachinprasadhs sachinprasadhs left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Left the review comments.

Below are some of the comments from the exisiting.

  1. Remove unnecessary underscores from "_ message passing neural network_ (MPNN) to predict graph properties. Specifically, we will
    implement an MPNN to predict a molecular property known as
    blood-brain barrier permeability (BBBP)." section.

Comment on lines +239 to +242
1. `molecule_from_smiles`: This takes a SMILES string as input and returns an RDKit molecule object. This process remains handled by RDKit on the CPU.
2. `smiles_to_graph`: This takes a SMILES string and returns a graph represented as a four-tuple: (atom_features, bond_features, pair_indices, mask).
The original implementation utilized tf.RaggedTensor, which is exclusive to TensorFlow. To remain backend-agnostic and support JAX and PyTorch, we now use fixed-size buffers (MAX_ATOMS and MAX_BONDS). We also introduce a mask—a boolean array that allows the model to distinguish between valid chemical data and zero-padding.
Finally, implemented a pre-featurization step. Instead of featurizing during the training loop (which creates a CPU bottleneck), we process all SMILES once and store them in a list of NumPy arrays. This allows the GPU backends to run at 100% efficiency.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Restrict the line length to 80 characters for better readability, also, add extra line space after each item

"""
### Test the functions

We can now inspect a sample molecule and its corresponding graph representation. Note that the output shapes are now constant (e.g., 70 atoms and 150 bonds), ensuring compatibility across all Keras 3 backends.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Limit 80 char per line

Comment on lines +386 to +390
a = np.zeros((self.batch_size, MAX_ATOMS, atom_featurizer.dim), dtype="float32")
b = np.zeros((self.batch_size, MAX_BONDS, bond_featurizer.dim), dtype="float32")
p = np.zeros((self.batch_size, MAX_BONDS, 2), dtype="int32")
m = np.zeros((self.batch_size, MAX_ATOMS), dtype="float32")
y = np.zeros((self.batch_size, 1), dtype="float32")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Avoid char level variable name, always set proper variable name.

Comment on lines +457 to 458
The Message Passing Neural Network (MPNN) architecture implemented in this tutorial consists of three stages: message passing, readout, and classification. The message passing step is the core of the model, enabling information to flow through the molecular graph. It consists of two main components:

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

80 char length per line

Comment on lines +528 to 534
In this tutorial, we utilize a Gated Readout combined with Hybrid Pooling (Mean and Max).
This approach is highly stable and fully compatible with JAX, PyTorch, and TensorFlow. The process works as follows:
Gating Mechanism: Each node state passes through a learned gating function (using sigmoid and tanh activations). This allows the model to "decide" which atoms are most important for the molecular property being predicted.
Masking: We use the mask generated in our data pipeline to ensure that padded (zero) atoms do not contribute to the final graph embedding.
Hybrid Segment Pooling: Instead of physically partitioning the tensors, we use the molecule_indicator (batch index) to logically group atoms. We calculate both the Mean and the Max of the node states for each molecule.
Concatenation: The mean and max features are concatenated to form a robust, fixed-size graph-level representation.
"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Char length limit to 80 and add a bullet pointer for each item.

Comment on lines +303 to +330
print("Pre-featurizing Dataset...")
processed_data = []
for s in tqdm(df.smiles.values):
graph = smiles_to_graph(s)
if graph is None:
# Placeholder for failed molecules to maintain index alignment
processed_data.append(
(
np.zeros((MAX_ATOMS, atom_featurizer.dim)),
np.zeros((MAX_BONDS, bond_featurizer.dim)),
np.zeros((MAX_BONDS, 2), dtype="int32"),
np.zeros((MAX_ATOMS,)),
)
)
else:
processed_data.append(graph)

print("Pre-featurizing Dataset...")
processed_data = [
smiles_to_graph(s)
or (
np.zeros((MAX_ATOMS, atom_featurizer.dim)),
np.zeros((MAX_BONDS, bond_featurizer.dim)),
np.zeros((MAX_BONDS, 2), dtype="int32"),
np.zeros((MAX_ATOMS,)),
)
for s in tqdm(df.smiles.values)
]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There seem to be duplicate blocks, please check and retain only one block.

csv_path = keras.utils.get_file(
"BBBP.csv", "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/BBBP.csv"
)
df = pd.read_csv(csv_path, usecols=[1, 2, 3])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use column names for better readability and easier reference.

Comment on lines +493 to +496
return ops.segment_sum(
messages,
ops.cast(pair_idx[:, 0], "int32"),
num_segments=BATCH_SIZE * MAX_ATOMS,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of providing fixed value for num_segments, can't we give derived value like previous code, something like num_segments= ops.shape(atom_feat)[0]

Comment on lines +346 to 348
molecule
"""
"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a mistake here? Seems to be extra here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants