Migrating MPNN tutorial to Keras 3 #2289
LakshmiKalaKadali wants to merge 3 commits into keras-team:master
Conversation
Summary of Changes

Hello @LakshmiKalaKadali, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly updates the Message Passing Neural Network (MPNN) molecular property prediction tutorial to be fully compatible with Keras 3. The core changes involve adapting the graph representation to be backend-agnostic using fixed-size buffers and masks, and overhauling the data loading process.
Code Review
This pull request does an excellent job of migrating the Message Passing Neural Network (MPNN) tutorial from Keras 2 to Keras 3, making it backend-agnostic. The changes significantly improve the code quality by introducing a more stable GatedReadout layer, optimizing the data pipeline with pre-featurization, and adopting modern Keras 3 practices like keras.ops and PyDataset. The model definition and training configuration have also been enhanced. I've identified a couple of areas for improvement: some redundant code has been introduced during the refactoring, and there's a minor formatting issue in a docstring. Overall, this is a high-quality contribution.
sachinprasadhs
left a comment
Thanks! Left the review comments.
Below are some of the comments from the existing version.
- Remove unnecessary underscores from the "_message passing neural network_ (MPNN) to predict graph properties. Specifically, we will implement an MPNN to predict a molecular property known as blood-brain barrier permeability (BBBP)." section.
```
1. `molecule_from_smiles`: This takes a SMILES string as input and returns an RDKit molecule object. This process remains handled by RDKit on the CPU.
2. `smiles_to_graph`: This takes a SMILES string and returns a graph represented as a four-tuple: (atom_features, bond_features, pair_indices, mask).

The original implementation utilized tf.RaggedTensor, which is exclusive to TensorFlow. To remain backend-agnostic and support JAX and PyTorch, we now use fixed-size buffers (MAX_ATOMS and MAX_BONDS). We also introduce a mask, a boolean array that allows the model to distinguish between valid chemical data and zero-padding.

Finally, we implemented a pre-featurization step. Instead of featurizing during the training loop (which creates a CPU bottleneck), we process all SMILES once and store them in a list of NumPy arrays. This allows the GPU backends to run at full efficiency.
```
Restrict the line length to 80 characters for better readability; also, add an extra blank line after each item.
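The fixed-size-buffer-and-mask idea quoted above can be sketched as follows. Note this is a minimal illustration: `MAX_ATOMS` and the feature dimension are placeholder values, and `pad_atoms` is a hypothetical helper (the tutorial's real `smiles_to_graph` builds its features with RDKit).

```python
import numpy as np

# Illustrative placeholder sizes; the tutorial defines MAX_ATOMS and the
# atom feature dimension elsewhere.
MAX_ATOMS, ATOM_DIM = 70, 29

def pad_atoms(atom_rows):
    """Hypothetical helper: pad per-atom features into a fixed-size buffer
    and return a validity mask distinguishing real atoms from padding."""
    num_atoms = len(atom_rows)
    features = np.zeros((MAX_ATOMS, ATOM_DIM), dtype="float32")
    features[:num_atoms] = atom_rows
    mask = np.zeros((MAX_ATOMS,), dtype="float32")
    mask[:num_atoms] = 1.0  # 1.0 marks valid chemical data, 0.0 marks padding
    return features, mask

# A molecule with 9 atoms lands in a constant-shape (70, 29) buffer.
feats, mask = pad_atoms(np.ones((9, ATOM_DIM), dtype="float32"))
```

Because every molecule produces the same output shapes, the arrays batch cleanly on any Keras 3 backend without ragged tensors.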
```
"""
### Test the functions

We can now inspect a sample molecule and its corresponding graph representation. Note that the output shapes are now constant (e.g., 70 atoms and 150 bonds), ensuring compatibility across all Keras 3 backends.
```
Limit to 80 characters per line.
```python
a = np.zeros((self.batch_size, MAX_ATOMS, atom_featurizer.dim), dtype="float32")
b = np.zeros((self.batch_size, MAX_BONDS, bond_featurizer.dim), dtype="float32")
p = np.zeros((self.batch_size, MAX_BONDS, 2), dtype="int32")
m = np.zeros((self.batch_size, MAX_ATOMS), dtype="float32")
y = np.zeros((self.batch_size, 1), dtype="float32")
```
Avoid single-character variable names; always use descriptive variable names.
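As a sketch of this suggestion, the same buffers with descriptive names (the names themselves are assumptions, and the constants stand in for the tutorial's batch size and featurizer dimensions):

```python
import numpy as np

# Placeholder sizes standing in for the tutorial's constants and
# featurizer dimensions.
BATCH_SIZE, MAX_ATOMS, MAX_BONDS = 32, 70, 150
ATOM_DIM, BOND_DIM = 29, 7

# Same buffers as the quoted code, with descriptive names instead of a/b/p/m/y:
atom_features = np.zeros((BATCH_SIZE, MAX_ATOMS, ATOM_DIM), dtype="float32")
bond_features = np.zeros((BATCH_SIZE, MAX_BONDS, BOND_DIM), dtype="float32")
pair_indices = np.zeros((BATCH_SIZE, MAX_BONDS, 2), dtype="int32")
atom_masks = np.zeros((BATCH_SIZE, MAX_ATOMS), dtype="float32")
labels = np.zeros((BATCH_SIZE, 1), dtype="float32")
```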
```
The Message Passing Neural Network (MPNN) architecture implemented in this tutorial consists of three stages: message passing, readout, and classification. The message passing step is the core of the model, enabling information to flow through the molecular graph. It consists of two main components:
```
Limit to 80 characters per line.
```
In this tutorial, we utilize a Gated Readout combined with Hybrid Pooling (Mean and Max).
This approach is highly stable and fully compatible with JAX, PyTorch, and TensorFlow. The process works as follows:
Gating Mechanism: Each node state passes through a learned gating function (using sigmoid and tanh activations). This allows the model to "decide" which atoms are most important for the molecular property being predicted.
Masking: We use the mask generated in our data pipeline to ensure that padded (zero) atoms do not contribute to the final graph embedding.
Hybrid Segment Pooling: Instead of physically partitioning the tensors, we use the molecule_indicator (batch index) to logically group atoms. We calculate both the Mean and the Max of the node states for each molecule.
Concatenation: The mean and max features are concatenated to form a robust, fixed-size graph-level representation.
"""
```
Limit line length to 80 characters and add a bullet point for each item.
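A minimal NumPy sketch of the gated-readout-with-hybrid-pooling steps described in the quoted text: sigmoid/tanh gating, masking out padded atoms, per-molecule mean and max pooling via a batch indicator, then concatenation. The random weights are stand-ins for learned parameters; the tutorial's actual layer is the Keras `GatedReadout`.

```python
import numpy as np

def gated_readout_sketch(node_states, mask, molecule_indicator, num_molecules, rng):
    """Sketch of a gated readout with hybrid mean/max segment pooling.

    The gating weights are random stand-ins for learned parameters.
    """
    dim = node_states.shape[-1]
    w_gate = rng.normal(size=(dim, dim))  # hypothetical learned weights
    w_feat = rng.normal(size=(dim, dim))
    gate = 1.0 / (1.0 + np.exp(-(node_states @ w_gate)))  # sigmoid gate
    feat = np.tanh(node_states @ w_feat)
    gated = gate * feat * mask[:, None]  # padded atoms contribute nothing
    means, maxima = [], []
    for mol in range(num_molecules):
        per_mol = gated[molecule_indicator == mol]  # logical grouping by index
        means.append(per_mol.mean(axis=0))
        maxima.append(per_mol.max(axis=0))
    # Concatenate mean and max pooling into one fixed-size graph embedding.
    return np.concatenate([np.stack(means), np.stack(maxima)], axis=-1)

rng = np.random.default_rng(0)
node_states = rng.normal(size=(6, 4))          # 6 atoms, 4 features each
mask = np.ones(6)                              # all atoms valid here
molecule_indicator = np.array([0, 0, 0, 1, 1, 1])  # two molecules of 3 atoms
graph_embeddings = gated_readout_sketch(node_states, mask, molecule_indicator, 2, rng)
```

Each molecule ends up with a fixed-size embedding of twice the node-state width (mean features concatenated with max features), regardless of how many atoms it has.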
```python
print("Pre-featurizing Dataset...")
processed_data = []
for s in tqdm(df.smiles.values):
    graph = smiles_to_graph(s)
    if graph is None:
        # Placeholder for failed molecules to maintain index alignment
        processed_data.append(
            (
                np.zeros((MAX_ATOMS, atom_featurizer.dim)),
                np.zeros((MAX_BONDS, bond_featurizer.dim)),
                np.zeros((MAX_BONDS, 2), dtype="int32"),
                np.zeros((MAX_ATOMS,)),
            )
        )
    else:
        processed_data.append(graph)
```

```python
print("Pre-featurizing Dataset...")
processed_data = [
    smiles_to_graph(s)
    or (
        np.zeros((MAX_ATOMS, atom_featurizer.dim)),
        np.zeros((MAX_BONDS, bond_featurizer.dim)),
        np.zeros((MAX_BONDS, 2), dtype="int32"),
        np.zeros((MAX_ATOMS,)),
    )
    for s in tqdm(df.smiles.values)
]
```
There seem to be duplicate blocks, please check and retain only one block.
```python
csv_path = keras.utils.get_file(
    "BBBP.csv", "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/BBBP.csv"
)
df = pd.read_csv(csv_path, usecols=[1, 2, 3])
```
Use column names for better readability and easier reference.
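A sketch of selecting columns by name rather than by position, using a hypothetical miniature CSV in place of the downloaded BBBP file (the real header names must be checked against the actual file before making this change):

```python
import io
import pandas as pd

# Hypothetical miniature stand-in for the downloaded BBBP.csv; verify the
# real header names before switching from positional usecols.
csv_text = "num,name,p_np,smiles\n1,ExampleMol,1,CCO\n"
df = pd.read_csv(io.StringIO(csv_text), usecols=["name", "p_np", "smiles"])
```

Named `usecols` also guards against silent breakage if the column order in the source file ever changes.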
```python
return ops.segment_sum(
    messages,
    ops.cast(pair_idx[:, 0], "int32"),
    num_segments=BATCH_SIZE * MAX_ATOMS,
)
```
Instead of providing a fixed value for num_segments, can't we use a derived value as in the previous code, something like num_segments=ops.shape(atom_feat)[0]?
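For illustration, a NumPy stand-in for `keras.ops.segment_sum` showing the derived `num_segments` the comment suggests (toy data, not the tutorial's actual tensors):

```python
import numpy as np

def segment_sum(data, segment_ids, num_segments):
    """NumPy stand-in for keras.ops.segment_sum, for illustration only."""
    out = np.zeros((num_segments,) + data.shape[1:], dtype=data.dtype)
    np.add.at(out, segment_ids, data)  # scatter-add rows into their segments
    return out

messages = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
pair_idx = np.array([[0, 1], [0, 2], [1, 2]])  # (source, target) edge indices
atom_feat = np.zeros((3, 2))                   # toy per-atom features
# Deriving the segment count from the atom tensor, as the comment suggests:
num_segments = atom_feat.shape[0]
aggregated = segment_sum(messages, pair_idx[:, 0], num_segments)
```

One caveat worth checking in review: backends with static-shape compilation (e.g. JAX) need `num_segments` to be a compile-time constant, which may be why the PR uses the fixed `BATCH_SIZE * MAX_ATOMS` value.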
```python
molecule
"""
"""
```
Is this a mistake? The duplicated `"""` seems to be extra here.
This PR migrates the Message Passing Neural Network (MPNN) molecular property prediction tutorial from Keras 2 to Keras 3. It replaces the legacy PartitionPadding and Transformer readout with a GatedReadout layer, which provides a more stable attention mechanism for small datasets like BBBP. gist