We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent bc92c67 commit c42230bCopy full SHA for c42230b
checkpoint/orbax/checkpoint/experimental/emergency/multihost.py
@@ -94,7 +94,7 @@ def process_index_from_device_id(device_id: int) -> int:
94
if num_slices > 1:
95
num_processes_per_slice = jax.process_count() // num_slices
96
# This is based on how Megascale device ids are assigned.
97
- # See platforms/xla/megascale/runtime/common/multi_slice_topology.h.
+ # See platforms/xla/megascale/common/multi_slice_topology.h.
98
slice_id = device_id // 100000 - 1
99
local_process_id = device_id % 100000 // jax.local_device_count()
100
return slice_id * num_processes_per_slice + local_process_id
0 commit comments