Skip to content

Commit c42230b

Browse files
author
Orbax Authors
committed
Internal
PiperOrigin-RevId: 867385615
1 parent bc92c67 commit c42230b

File tree

1 file changed

+1
-1
lines changed
  • checkpoint/orbax/checkpoint/experimental/emergency

1 file changed

+1
-1
lines changed

checkpoint/orbax/checkpoint/experimental/emergency/multihost.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def process_index_from_device_id(device_id: int) -> int:
9494
if num_slices > 1:
9595
num_processes_per_slice = jax.process_count() // num_slices
9696
# This is based on how Megascale device ids are assigned.
97-
# See platforms/xla/megascale/runtime/common/multi_slice_topology.h.
97+
# See platforms/xla/megascale/common/multi_slice_topology.h.
9898
slice_id = device_id // 100000 - 1
9999
local_process_id = device_id % 100000 // jax.local_device_count()
100100
return slice_id * num_processes_per_slice + local_process_id

0 commit comments

Comments
 (0)