Conversation

@hertschuh (Contributor) commented Sep 3, 2025

This was run on a cloud TPU v6e-1.

Also tweaked some comments in the JAX DistributedEmbedding example.

@gemini-code-assist (bot) left a comment

Summary of Changes

Hello @hertschuh, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request adds a comprehensive example demonstrating distributed embedding training on TensorFlow using TPU SparseCore. It provides a practical guide for leveraging keras_rs.layers.DistributedEmbedding for large-scale recommendation systems, complementing the existing JAX-based example.

Highlights

  • New TensorFlow DistributedEmbedding Example: Introduces a new example (distributed_embedding_tf.py) demonstrating the use of keras_rs.layers.DistributedEmbedding for movie ranking on TensorFlow with TPU SparseCore.
  • Jupyter Notebook and Markdown Versions: Accompanying Jupyter Notebook (.ipynb) and Markdown (.md) versions of the example are added for easier consumption and documentation.
  • Example Integration: The new TensorFlow example is integrated into the rs_master.py script, making it discoverable alongside existing examples.

@gemini-code-assist (bot) left a comment

Code Review

This pull request adds a new example demonstrating the use of keras_rs.layers.DistributedEmbedding with TensorFlow on TPUs. The example is well-structured and provides a clear walkthrough of setting up the TPU strategy, preparing the dataset, configuring the distributed embedding layer, and training a ranking model. I've found a couple of potential issues related to the configuration of FeatureConfig and the usage of the embedding layer's output, which could lead to runtime errors. My detailed comments are below.
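
For readers skimming the thread, here is a minimal, hedged sketch of the kind of configuration under review, assuming the keras_rs.layers.TableConfig / FeatureConfig / DistributedEmbedding API as I understand it; the table name, vocabulary size, embedding dimension, and batch size below are illustrative placeholders, not the example's actual values (see the merged example file for those).

import keras_rs

# Illustrative table and feature configuration (placeholder sizes and names).
movie_table = keras_rs.layers.TableConfig(
    name="movie_table",
    vocabulary_size=2048,    # placeholder
    embedding_dim=32,        # placeholder
    optimizer="adagrad",
    placement="sparsecore",  # place the table on the TPU SparseCore
)
feature_configs = {
    "movie_id": keras_rs.layers.FeatureConfig(
        name="movie_id",
        table=movie_table,
        input_shape=(256,),      # (batch_size,), placeholder
        output_shape=(256, 32),  # (batch_size, embedding_dim), placeholder
    ),
}

# `strategy` is the tf.distribute.TPUStrategy created later in the diff.
# The layer's output is a structure matching `feature_configs`, e.g. the
# dense embeddings are consumed downstream as embeddings["movie_id"].
with strategy.scope():
    embedding_layer = keras_rs.layers.DistributedEmbedding(feature_configs)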

@abheesht17 (Collaborator) left a comment

Thanks for the PR! It's great news that this works on TPUs :)


"""shell
pip install -U -q tensorflow-tpu==2.19.1
pip install -q keras-rs
@abheesht17 (Collaborator)

I think you'd mentioned this before: should we add optional extras (the square-bracket syntax) to the keras-rs setup files, like so: pip install -q keras-rs[tpu]? Or pip install -q keras-rs[dist-emb-tpu], or something similar?

@hertschuh (Contributor, Author) commented Sep 3, 2025

That's a good question; we should experiment with this. Part of the complication is the combinatorics of backends and hardware. Another issue I have run into is that some packages clash over the versions they require for shared dependencies (protobuf or keras, for instance).

I think we would have to do something like the following (a hedged packaging sketch follows this list):

  • keras-rs[tf-tpu] (adds tensorflow-tpu==2.19.1)
  • keras-rs[jax-tpu] (adds jax-tpu-embedding and jax[tpu])
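
A hedged sketch of what such extras could look like, assuming keras-rs declares them via setuptools-style extras_require; the extra names and pins are taken from the bullets above, and the actual keras-rs packaging may be structured differently (e.g. via pyproject.toml).

# Hypothetical packaging sketch, not part of this PR.
from setuptools import find_packages, setup

setup(
    name="keras-rs",
    packages=find_packages(),
    extras_require={
        # `pip install keras-rs[tf-tpu]` would pull the TF TPU wheel.
        "tf-tpu": ["tensorflow-tpu==2.19.1"],
        # `pip install keras-rs[jax-tpu]` would pull the JAX TPU stack.
        "jax-tpu": ["jax-tpu-embedding", "jax[tpu]"],
    },
)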

SparseCore chips of all the available TPUs.
"""

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
@abheesht17 (Collaborator)

[knowledge] I'm assuming this won't work in the multi-host case, right? Since tpu = "local"

@hertschuh (Contributor, Author) commented Sep 3, 2025

Correct. Is it just a question of passing the name of the TPU cluster? If so, I can add a variable and a comment explaining how to do it. But I haven't tested it.

@abheesht17 (Collaborator)

Yeah, I think you can pass the GCP zone, project, etc. too. I just asked this question for knowledge. We don't have to add it since we haven't tested it yet. What do you think?
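
For reference, a hedged sketch of what pointing the resolver at a named (potentially multi-host) TPU might look like; the TPU name, zone, and project below are hypothetical placeholders, and this path was not tested as part of the PR.

import tensorflow as tf

# Hypothetical multi-host setup: resolve a named Cloud TPU slice instead of
# the local one. All three values below are placeholders.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
    tpu="my-tpu-name",         # name of the Cloud TPU / TPU slice
    zone="us-east5-b",         # GCP zone hosting the TPU
    project="my-gcp-project",  # GCP project id
)
tf.config.experimental_connect_to_cluster(resolver)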

@hertschuh (Contributor, Author)

I can't get my hands on a multi-host v6e right now. We can submit as-is and revisit.

Comment on lines +73 to +85
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
tpu_metadata = resolver.get_tpu_system_metadata()

device_assignment = tf.tpu.experimental.DeviceAssignment.build(
    topology, num_replicas=tpu_metadata.num_cores
)
strategy = tf.distribute.TPUStrategy(
    resolver, experimental_device_assignment=device_assignment
)
@abheesht17 (Collaborator)

Curious why we have to do this instead of using MirroredStrategy? Maybe we can add a note here for the reader as to why this is necessary? What do you think?

@hertschuh (Contributor, Author)

Well, that's what the text lines 64-69 are for. If you're on TPU, you have to use a TPUStrategy.

From https://www.tensorflow.org/api_docs/python/tf/distribute/MirroredStrategy:

"This strategy is typically used for training on one machine with multiple GPUs. For TPUs, use tf.distribute.TPUStrategy."

I could add a link to the TensorFlow documentation.
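
For context, a minimal sketch of how the TPUStrategy above is typically consumed once created, assuming a Keras model; `build_ranking_model` and `train_dataset` are placeholders, not names from the example.

# Variables created under the strategy scope are placed according to the
# TPUStrategy and device assignment built above.
with strategy.scope():
    model = build_ranking_model()  # hypothetical helper, placeholder only
    model.compile(optimizer="adagrad", loss="mse")

# `train_dataset` stands in for the example's tf.data input pipeline.
model.fit(train_dataset, epochs=1)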

@abheesht17 (Collaborator)

I could add a link to the TensorFlow documentation.

Yeah, let's add it?

@hertschuh (Contributor, Author)

Done

@abheesht17 (Collaborator) left a comment

LGTM, replied to some comments

@hertschuh merged commit 8340b57 into keras-team:master on Sep 5, 2025 (3 checks passed).
@hertschuh deleted the tf_tpu_embedding branch on September 5, 2025 at 00:47.