Skip to content

Make JAX axis names deterministic across executions.#38161

Open
copybara-service[bot] wants to merge 1 commit into
mainfrom
test_926066643
Open

Make JAX axis names deterministic across executions.#38161
copybara-service[bot] wants to merge 1 commit into
mainfrom
test_926066643

Conversation

@copybara-service
Copy link
Copy Markdown

Make JAX axis names deterministic across executions.

Previously, _TempAxisName used id(obj) (the memory address of the callable) to generate axis names. This resulted in non-deterministic names like: <axis 0x5105796d2020> that changed across different execution runs, making HLO fingerprints different across multiple runs.
This CL introduces a global counter and a process-local cache to assign sequential IDs like: <axis 0x0>, <axis 0x1> to anonymous axes. The cache ensures that the same function object still receives the same axis name within a single run, while the counter ensures determinism across runs.

Previously, `_TempAxisName` used `id(obj)` (the memory address of the callable) to generate axis names. This resulted in non-deterministic names like: `<axis 0x5105796d2020>` that changed across different execution runs, making HLO fingerprints different across multiple runs.
This CL introduces a global counter and a process-local cache to assign sequential IDs like: `<axis 0x0>`, `<axis 0x1>` to anonymous axes. The cache ensures that the same function object still receives the same axis name within a single run, while the counter ensures determinism across runs.

PiperOrigin-RevId: 926066643
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant