Skip to content

Migrating MPNN tutorial to Keras3#2289

Open
LakshmiKalaKadali wants to merge 5 commits intokeras-team:masterfrom
LakshmiKalaKadali:mpnn_br
Open

Migrating MPNN tutorial to Keras3#2289
LakshmiKalaKadali wants to merge 5 commits intokeras-team:masterfrom
LakshmiKalaKadali:mpnn_br

Conversation

@LakshmiKalaKadali
Copy link
Copy Markdown
Contributor

This PR migrates the Message Passing Neural Network (MPNN) molecular property prediction tutorial from Keras 2 to Keras 3. Replaced the legacy PartitionPadding and Transformer readout with a GatedReadout layer. This provides a more stable attention mechanism for small datasets like BBBP. gist

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello @LakshmiKalaKadali, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly updates the Message Passing Neural Network (MPNN) molecular property prediction tutorial to be fully compatible with Keras 3. The core changes involve adapting the graph representation to be backend-agnostic using fixed-size buffers and masks, and overhauling the data loading process with keras.utils.PyDataset. A key architectural improvement is the replacement of the previous readout mechanism with a more stable GatedReadout layer. Additionally, training stability and performance are enhanced through the adoption of AdamW with a CosineDecay learning rate schedule and the introduction of pre-featurization to reduce CPU bottlenecks.

Highlights

  • Keras 3 Migration: The tutorial has been fully migrated to Keras 3, ensuring compatibility with JAX, PyTorch, and TensorFlow backends. This involved updating imports, API calls, and data handling mechanisms.
  • Backend-Agnostic Graph Representation: The graph generation process (smiles_to_graph) was refactored to use fixed-size NumPy arrays and masks instead of TensorFlow's RaggedTensor. This ensures static shapes required for JAX and PyTorch compatibility and introduces pre-featurization for improved training efficiency.
  • Updated Readout Mechanism: The legacy PartitionPadding and TransformerEncoderReadout layers have been replaced with a GatedReadout layer. This new readout combines mean and max pooling with a gating mechanism, providing a more stable attention mechanism, especially for smaller datasets like BBBP.
  • Optimized Data Loading: The data loading pipeline was re-implemented using keras.utils.PyDataset, replacing the tf.data.Dataset approach. This custom PyDataset handles batching, padding, and the creation of molecule_indicator and mask tensors for the new graph representation.
  • Enhanced Model Training: The model now utilizes the AdamW optimizer with a CosineDecay learning rate schedule, along with weight_decay and global_clipnorm for improved training stability and performance. LayerNormalization was also added within the MessagePassing layer.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • examples/graph/mpnn-molecular-graphs.py
    • Added Keras 3 conversion attribution to the file header.
    • Updated rdkit and graphviz installation commands in the tutorial text.
    • Modified imports to use keras and keras.ops instead of tensorflow and tensorflow.keras.
    • Removed TF_CPP_MIN_LOG_LEVEL environment variable and IPythonConsole import.
    • Introduced KERAS_BACKEND environment variable for Keras 3 backend selection.
    • Added global configuration constants for MAX_ATOMS, MAX_BONDS, BATCH_SIZE, EPOCHS, and LEARNING_RATE.
    • Replaced np.random.seed and tf.random.set_seed with keras.utils.set_random_seed.
    • Changed np.zeros dtype to float32 in Featurizer.encode and BondFeaturizer.encode.
    • Removed redundant __init__ method from AtomFeaturizer.
    • Updated atom_featurizer and bond_featurizer instantiation to remove allowable_sets keyword.
    • Refactored graph generation: Replaced graph_from_molecule and graphs_from_smiles with a new smiles_to_graph function.
    • Implemented fixed-size buffers and a mask in smiles_to_graph for Keras 3 backend compatibility.
    • Introduced a pre-featurization step for the entire dataset to improve training efficiency.
    • Replaced tf.data.Dataset and prepare_batch with a custom MPNNDataset class inheriting from keras.utils.PyDataset.
    • Updated dataset splitting logic to use the new MPNNDataset.
    • Modified EdgeNetwork to use keras.ops for tensor operations and adjusted num_segments for batch processing.
    • Updated MessagePassing layer to use keras.ops, LayerNormalization, and simplified padding logic.
    • Replaced PartitionPadding and TransformerEncoderReadout layers with a new GatedReadout layer.
    • Updated MPNNModel definition to accept mask input and incorporate the GatedReadout layer.
    • Added kernel_regularizer and Dropout to the final dense layers in MPNNModel.
    • Configured the model compilation with AdamW optimizer, CosineDecay learning rate schedule, weight_decay, and global_clipnorm.
    • Adjusted training epochs to use the global EPOCHS constant and verbose level.
    • Modified plotting code to save the AUC plot to a file.
    • Updated prediction logic to use ops.convert_to_numpy for y_pred.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request does an excellent job of migrating the Message Passing Neural Network (MPNN) tutorial from Keras 2 to Keras 3, making it backend-agnostic. The changes significantly improve the code quality by introducing a more stable GatedReadout layer, optimizing the data pipeline with pre-featurization, and adopting modern Keras 3 practices like keras.ops and PyDataset. The model definition and training configuration have also been enhanced. I've identified a couple of areas for improvement: some redundant code has been introduced during the refactoring, and there's a minor formatting issue in a docstring. Overall, this is a high-quality contribution.

Comment thread examples/graph/mpnn-molecular-graphs.py Outdated
Comment thread examples/graph/mpnn-molecular-graphs.py Outdated
@LakshmiKalaKadali LakshmiKalaKadali changed the title Migratimg MPNN tutorial to Keras3 Migrating MPNN tutorial to Keras3 Feb 17, 2026
Copy link
Copy Markdown
Collaborator

@sachinprasadhs sachinprasadhs left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Left the review comments.

Below are some of the comments from the exisiting.

  1. Remove unnecessary underscores from "_ message passing neural network_ (MPNN) to predict graph properties. Specifically, we will
    implement an MPNN to predict a molecular property known as
    blood-brain barrier permeability (BBBP)." section.

Comment thread examples/graph/mpnn-molecular-graphs.py Outdated
Comment thread examples/graph/mpnn-molecular-graphs.py Outdated
Comment thread examples/graph/mpnn-molecular-graphs.py Outdated
Comment thread examples/graph/mpnn-molecular-graphs.py Outdated
Comment thread examples/graph/mpnn-molecular-graphs.py Outdated
Comment thread examples/graph/mpnn-molecular-graphs.py Outdated
Comment thread examples/graph/mpnn-molecular-graphs.py Outdated
Comment thread examples/graph/mpnn-molecular-graphs.py Outdated
Comment thread examples/graph/mpnn-molecular-graphs.py Outdated
Copy link
Copy Markdown
Collaborator

@sachinprasadhs sachinprasadhs left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, Thanks

Copy link
Copy Markdown
Collaborator

@sachinprasadhs sachinprasadhs left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait, none of the previous comments looks addressed! everything is marked as resolved but none are addressed!

Copy link
Copy Markdown
Collaborator

@sachinprasadhs sachinprasadhs left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just one small comment.
Make the change and generate other files with the changed code.
Rest all looks good.

In this tutorial, we will implement a type of graph neural network (GNN) known as
_ message passing neural network_ (MPNN) to predict graph properties. Specifically, we will
In this tutorial, we will implement a type of graph neural network (GNN)
known as _ message passing neural network_ (MPNN) to predict graph
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change it to _message passing neural network_ by removing space after first _

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants