Make JAX axis names deterministic across executions.#38161
Open
copybara-service[bot] wants to merge 1 commit into
Open
Make JAX axis names deterministic across executions.#38161copybara-service[bot] wants to merge 1 commit into
copybara-service[bot] wants to merge 1 commit into
Conversation
Previously, `_TempAxisName` used `id(obj)` (the memory address of the callable) to generate axis names. This resulted in non-deterministic names like: `<axis 0x5105796d2020>` that changed across different execution runs, making HLO fingerprints different across multiple runs. This CL introduces a global counter and a process-local cache to assign sequential IDs like: `<axis 0x0>`, `<axis 0x1>` to anonymous axes. The cache ensures that the same function object still receives the same axis name within a single run, while the counter ensures determinism across runs. PiperOrigin-RevId: 926066643
cfc4443 to
73e8f13
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Make JAX axis names deterministic across executions.
Previously,
_TempAxisNameusedid(obj)(the memory address of the callable) to generate axis names. This resulted in non-deterministic names like:<axis 0x5105796d2020>that changed across different execution runs, making HLO fingerprints different across multiple runs.This CL introduces a global counter and a process-local cache to assign sequential IDs like:
<axis 0x0>,<axis 0x1>to anonymous axes. The cache ensures that the same function object still receives the same axis name within a single run, while the counter ensures determinism across runs.