Add DistributedEmbedding example for TPU on JAX. #2132
Conversation
abheesht17 left a comment:
Thanks for the PR!
/gemini review
Code Review
This pull request adds a new example demonstrating the use of DistributedEmbedding with the JAX backend on TPUs. The example is provided as a Python script, a Jupyter notebook, and a markdown file for documentation. The changes are a valuable addition. My review focuses on improving the clarity of the documentation and fixing a few potential correctness issues in the example code. The main points are correcting misleading references to TensorFlow in a JAX example, fixing incorrect batch sizes in the distributed configuration, and ensuring the model is in the correct mode during evaluation.
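For orientation, here is a minimal sketch of the kind of configuration such an example sets up, assuming the `keras_rs` API described in the library's documentation (`TableConfig`, `FeatureConfig`, `DistributedEmbedding`); the names and sizes below are illustrative, not taken from the PR:

```python
import keras_rs

VOCABULARY_SIZE = 10_000
EMBEDDING_DIMENSION = 32
BATCH_SIZE = 256  # global batch size

# One embedding table; on TPU, DistributedEmbedding shards it across devices.
item_table = keras_rs.layers.TableConfig(
    name="item_table",
    vocabulary_size=VOCABULARY_SIZE,
    embedding_dim=EMBEDDING_DIMENSION,
)

# One input feature that performs lookups in that table.
item_feature = keras_rs.layers.FeatureConfig(
    name="item_id",
    table=item_table,
    input_shape=(BATCH_SIZE,),
    output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),
)

# The layer accepts a (possibly nested) structure of feature configs and
# returns embeddings in the same structure.
embedding_layer = keras_rs.layers.DistributedEmbedding({"item_id": item_feature})
```

The resulting layer is then used like any other Keras layer inside the model that the example trains on the TPU.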
Force-pushed from 3ae685d to aeab875.
This was run on a cloud TPU v6e-1.
Force-pushed from aeab875 to b8d7564.
abheesht17 left a comment:
LGTM, thanks!
```python
input_shape=(BATCH_SIZE,),
output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),
```
Oh, one question. Is this BATCH_SIZE or PER_REPLICA_BATCH_SIZE? The documentation says PER_REPLICA_BATCH_SIZE: https://keras.io/keras_rs/api/embedding_layers/distributed_embedding/#configuration
I will update the documentation. In the end, it's the global batch size.
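To illustrate the reply above with a hedged sketch (the `PER_REPLICA_BATCH_SIZE` name and the sizes are illustrative, not from the PR): since the configured shapes use the global batch size, on a data-parallel JAX setup it is the per-replica batch size multiplied by the number of devices, and on a single-chip v6e-1 the two coincide.

```python
import jax
import keras_rs

EMBEDDING_DIMENSION = 32
PER_REPLICA_BATCH_SIZE = 256

# The shapes passed to FeatureConfig use the global batch size (per the reply
# above); on a single-device v6e-1, this equals the per-replica batch size.
BATCH_SIZE = PER_REPLICA_BATCH_SIZE * jax.device_count()

item_table = keras_rs.layers.TableConfig(
    name="item_table",
    vocabulary_size=10_000,
    embedding_dim=EMBEDDING_DIMENSION,
)
item_feature = keras_rs.layers.FeatureConfig(
    name="item_id",
    table=item_table,
    input_shape=(BATCH_SIZE,),  # global, not per-replica
    output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),
)
```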