Conversation

@hertschuh (Contributor)

This was run on a cloud TPU v6e-1.

@hertschuh (Contributor Author)

@cantonios

@abheesht17 (Collaborator) left a comment

Thanks for the PR!

@divyashreepathihalli (Collaborator)

/gemini review

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request adds a new example demonstrating the use of DistributedEmbedding with the JAX backend on TPUs. The example is provided as a Python script, a Jupyter notebook, and a markdown file for documentation. The changes are a valuable addition. My review focuses on improving the clarity of the documentation and fixing a few potential correctness issues in the example code. The main points are correcting misleading references to TensorFlow in a JAX example, fixing incorrect batch sizes in the distributed configuration, and ensuring the model is in the correct mode during evaluation.
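
For readers without the diff handy, here is a rough sketch of the kind of setup the example demonstrates, assuming the TableConfig/FeatureConfig API documented at https://keras.io/keras_rs/api/embedding_layers/distributed_embedding/ ; the constants and feature names below are illustrative, not taken from the PR.

# Minimal sketch of DistributedEmbedding with the JAX backend on TPU.
# All names/constants are assumptions for illustration, not from the PR.
import os
os.environ["KERAS_BACKEND"] = "jax"  # must be set before importing keras

import keras
import keras_rs

BATCH_SIZE = 1024            # global batch size across all TPU replicas
VOCABULARY_SIZE = 100_000
EMBEDDING_DIMENSION = 32

# Shard the input data across the TPU devices; the embedding tables
# themselves are placed and sharded by DistributedEmbedding.
keras.distribution.set_distribution(keras.distribution.DataParallel())

# One table and one feature that looks ids up in that table. The default
# placement lets the layer decide where the table lives.
movie_table = keras_rs.layers.TableConfig(
    name="movie_table",
    vocabulary_size=VOCABULARY_SIZE,
    embedding_dim=EMBEDDING_DIMENSION,
)
feature_configs = {
    "movie_id": keras_rs.layers.FeatureConfig(
        name="movie_id",
        table=movie_table,
        input_shape=(BATCH_SIZE,),
        output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),
    ),
}
embedding_layer = keras_rs.layers.DistributedEmbedding(feature_configs)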

@hertschuh force-pushed the jax_tpu_embedding branch 2 times, most recently from 3ae685d to aeab875, on August 4, 2025 at 23:16

@abheesht17 (Collaborator) left a comment

LGTM, thanks!

Comment on lines +181 to +182
input_shape=(BATCH_SIZE,),
output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),
@abheesht17 (Collaborator) commented on Aug 5, 2025

Oh, one question. Is this BATCH_SIZE or PER_REPLICA_BATCH_SIZE? The documentation says PER_REPLICA_BATCH_SIZE: https://keras.io/keras_rs/api/embedding_layers/distributed_embedding/#configuration

@hertschuh (Contributor Author) replied

I will update the documentation. In the end, it's the global batch size.
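
In other words, with this configuration the shapes passed to FeatureConfig use the global batch size, i.e. the per-replica batch size multiplied by the number of devices. A short illustration (the per-replica value is an assumption, not from the PR):

import jax

PER_REPLICA_BATCH_SIZE = 256
# The global batch size is what goes into FeatureConfig's
# input_shape / output_shape, per the clarification above.
BATCH_SIZE = PER_REPLICA_BATCH_SIZE * jax.device_count()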

@hertschuh merged commit da45b18 into keras-team:master on Aug 5, 2025
3 checks passed
@hertschuh deleted the jax_tpu_embedding branch on August 5, 2025 at 16:07