
Migrating GNN node classification tutorial to Keras3#2304

Open
LakshmiKalaKadali wants to merge 3 commits into keras-team:master from LakshmiKalaKadali:gnn_br

Conversation

@LakshmiKalaKadali
Contributor

This PR migrates the GNN node classification tutorial to Keras 3. The following changes were made to keep the network numerically stable:
  • Self-Loops: Integrated automatic addition of self-loops in the GNNNodeClassifier constructor, so nodes retain their own identity during message passing.
  • Degree Normalization: Implemented manual segment-based mean pooling (sum / degree) in the GraphConvLayer. This prevents feature explosion at high-degree nodes and stabilizes the validation loss.
  • Input Feature Normalization: Added row normalization of the binary word vectors in the model's initialization to ensure consistent feature scales. (Colab reference notebook)
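The first two stability fixes above can be sketched together. This is a minimal NumPy illustration of the idea (the PR itself implements it with keras.ops segment operations, and the exact function name here is hypothetical): append a self-loop edge for every node, segment-sum the incoming messages, then divide by each node's in-degree to get a mean.

```python
import numpy as np

def aggregate_mean(node_features, edges):
    """Mean-aggregate neighbour features into each node, with self-loops added.

    node_features: (num_nodes, dim) array.
    edges: tuple (src_indices, dst_indices) of equal-length int arrays.
    """
    num_nodes = node_features.shape[0]
    # Self-loops: every node also receives its own features,
    # so isolated nodes keep their identity and degree is never zero.
    loops = np.arange(num_nodes)
    src = np.concatenate([edges[0], loops])
    dst = np.concatenate([edges[1], loops])
    # Segment sum: scatter-add each source node's features into its destination.
    summed = np.zeros_like(node_features)
    np.add.at(summed, dst, node_features[src])
    # Degree normalization (sum / degree): divide by the message count
    # per node, preventing feature explosion at high-degree nodes.
    degree = np.zeros(num_nodes)
    np.add.at(degree, dst, 1.0)
    return summed / degree[:, None]
```

On a toy graph with edges 0→2 and 1→2, node 2 averages over nodes 0, 1, and itself, while nodes 0 and 1 receive only their self-loop and keep their original features.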

Thank You

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request undertakes a comprehensive migration of the GNN node classification tutorial to Keras 3, ensuring compatibility with the latest Keras API and enabling multi-backend support. Beyond the API update, significant architectural improvements were introduced to enhance the numerical stability of the GNN model, particularly concerning message passing and aggregation. The data preparation pipeline was also refined to include a dedicated validation set, leading to more robust model evaluation and training practices.

Highlights

  • Keras 3 Migration: The entire GNN node classification tutorial has been migrated to use the Keras 3 API, leveraging keras.ops for backend-agnostic tensor operations and updating layer configurations.
  • Numerical Stability Enhancements: Implemented several changes to improve numerical stability, including automatic addition of self-loops in the GNNNodeClassifier, manual segment-based mean pooling for degree normalization in GraphConvLayer, and row-normalization for input binary word vectors.
  • Improved Dataset Splitting: The dataset splitting logic was refactored to explicitly create stratified train, validation, and test sets, enhancing reproducibility and providing a dedicated validation set for model training and early stopping.
  • GraphConvLayer and GNNNodeClassifier Refinements: The custom GraphConvLayer and GNNNodeClassifier classes were extensively updated to be Keras 3 compatible, incorporating keras.ops for all tensor manipulations, adding training arguments to call methods, and introducing configurable edge weight normalization.
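The stratified splitting highlighted above can be sketched as follows. This is an illustrative version only, under the assumption of a per-class shuffle-and-slice scheme; the index names train_ids / val_ids / test_ids come from the changelog, while the split fractions and seed are placeholders.

```python
import numpy as np

def stratified_split(labels, train_frac=0.5, val_frac=0.25, seed=42):
    """Split indices into train/val/test, preserving class proportions."""
    rng = np.random.default_rng(seed)  # seeded for reproducibility
    train_ids, val_ids, test_ids = [], [], []
    for cls in np.unique(labels):
        # Shuffle the indices of this class, then slice by fraction.
        idx = rng.permutation(np.flatnonzero(labels == cls))
        n_train = int(len(idx) * train_frac)
        n_val = int(len(idx) * val_frac)
        train_ids.extend(idx[:n_train])
        val_ids.extend(idx[n_train : n_train + n_val])
        test_ids.extend(idx[n_train + n_val :])
    return np.array(train_ids), np.array(val_ids), np.array(test_ids)
```

The dedicated validation split is what allows EarlyStopping to monitor 'val_loss' instead of relying on validation_split.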


Changelog
  • examples/graph/gnn_citations.py
    • Added author credit for Keras 3 conversion.
    • Updated imports to use Keras 3 (keras instead of tensorflow.keras) and set JAX as the default backend.
    • Configured matplotlib to use a non-interactive backend (Agg).
    • Introduced keras.utils.set_random_seed and np.random.default_rng for reproducibility.
    • Adjusted the data_dir path for the Cora dataset extraction.
    • Added num_classes variable for clarity.
    • Refactored dataset splitting to create explicit stratified train, validation, and test sets using train_ids, val_ids, test_ids.
    • Modified run_experiment to accept explicit validation data, changed EarlyStopping monitor to 'val_loss', and removed validation_split.
    • Updated display_learning_curves to accept an optional title and changed legend labels from 'test' to 'val'.
    • Replaced tf.nn.gelu with the string literal 'gelu' in create_ffn for Keras 3 compatibility.
    • Refactored baseline model data preparation to use new train/val/test indices and keras.ops.
    • Updated create_baseline_model and run_experiment calls to align with new data splitting and Keras 3 API.
    • Replaced tf.convert_to_tensor and .numpy() with ops.convert_to_numpy and ops.convert_to_tensor for Keras 3 compatibility in probability calculations.
    • Converted tf.ones and tf.cast to ops.ones and ops.cast for Keras 3 compatibility.
    • Updated GraphConvLayer docstrings to reflect Keras 3 and keras.ops usage for aggregation.
    • Changed keras.layers.Input to layers.Input in create_gru.
    • Removed deprecated return_state and recurrent_dropout arguments from layers.GRU.
    • Extensively refactored GraphConvLayer to use keras.ops for all tensor operations, accept training argument, and simplify __init__ and update_fn logic.
    • Significantly refactored GNNNodeClassifier to be Keras 3 compatible, adding add_self_loops and edge_weight_normalization parameters, implementing self-loop and normalization logic using keras.ops, and updating call method to pass training argument.
    • Updated gnn_model instantiation with new parameters and adjusted input for shape printing.
    • Updated GNN model training and evaluation calls to use the new train/val/test indices and labels.
    • Refactored the inductive learning section to use keras.ops for tensor conversions and np.random.choice with replace=False for robust sampling.
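The row normalization mentioned in the changelog can be illustrated like this. A minimal sketch, assuming L1 normalization (each binary word-vector row rescaled to sum to 1, which gives the "consistent feature scales" the PR describes); the PR may use a different norm, and the function name is hypothetical.

```python
import numpy as np

def row_normalize(x, eps=1e-9):
    """Scale each row of a binary feature matrix to sum to 1.

    eps guards against division by zero for all-zero rows.
    """
    row_sums = x.sum(axis=1, keepdims=True)
    return x / np.maximum(row_sums, eps)
```

For example, a document containing two vocabulary words gets each set entry scaled to 0.5, so documents with many words no longer produce proportionally larger inputs than short ones.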

Contributor

@gemini-code-assist (bot) left a comment


Code Review

This pull request successfully migrates the GNN node classification tutorial to Keras 3, making it backend-agnostic by using keras.ops. Beyond the migration, you've introduced several excellent improvements that enhance the model's stability and the example's robustness, including adding self-loops, degree normalization, and a more rigorous data splitting and evaluation setup. My review includes a few minor suggestions to improve reproducibility by consistently using the seeded random number generator and to ensure the generated plots are saved correctly.

-ax2.legend(["train", "test"], loc="upper right")
+ax2.legend(["train", "val"], loc="upper right")
 ax2.set_xlabel("Epochs")
 ax2.set_ylabel("Accuracy")
Contributor


medium

With the switch to the 'Agg' backend for matplotlib, the learning curve plots are generated but not displayed or saved. To make them useful, you should save them to a file. This can be done by adding a plt.savefig() call at the end of the function.

    ax2.set_ylabel("Accuracy")
    if title:
        filename = f"{title.lower().replace(' ', '_')}_learning_curves.png"
        plt.savefig(filename)
    plt.close(fig)

token_probability = x_train_base.mean(axis=0)
instances = []
for _ in range(num_instances):
    probabilities = np.random.uniform(size=len(token_probability))
Contributor


medium

For reproducibility, it's better to use the rng object that was initialized with a fixed seed, instead of np.random.uniform. This will ensure that the generated random instances are the same across runs.

        probabilities = rng.uniform(size=len(token_probability))

Comment on lines +724 to +728
selected_paper_indices1 = np.random.choice(subject_papers, 5, replace=False)

selected_paper_indices2 = np.random.choice(
    papers.paper_id.to_numpy(), 2, replace=False
)
Contributor


medium

For reproducibility, it's better to use the rng object that was initialized with a fixed seed, instead of np.random.choice. This will ensure that the generated citations for new nodes are the same across runs.

    selected_paper_indices1 = rng.choice(subject_papers, 5, replace=False)

    selected_paper_indices2 = rng.choice(
        papers.paper_id.to_numpy(), 2, replace=False
    )

@jeffcarp previously approved these changes Mar 5, 2026
Member

@jeffcarp left a comment


Thanks! Can you generate the ipynb and md files?
https://github.com/keras-team/keras-io?tab=readme-ov-file#creating-a-new-example-starting-from-a-python-script

[Edit] and update formatting so black passes

@jeffcarp dismissed their stale review March 5, 2026 19:47

Removing review - needs extra files and reformat

@LakshmiKalaKadali changed the title from "Migrating GNN node classificationntutorial to Keras3" to "Migrating GNN node classification tutorial to Keras3" Mar 12, 2026