Skip to content

Commit b674bcd

Browse files
T5X Teamt5-copybara
authored andcommitted
Fix export on H100.
PiperOrigin-RevId: 932589815
1 parent cc342d4 commit b674bcd

1 file changed

Lines changed: 6 additions & 2 deletions

File tree

t5x/partitioning.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,11 @@ def bounds_from_last_device(last_device: jax.Device) -> HardwareMesh:
8080
# Must be passed the device at the highest-coordinate corner of the
8181
# relevant mesh, which is a requirement we know is satisfied by the last
8282
# device in jax.devices().
83-
if hasattr(last_device, 'coords') and len(last_device.coords) == 3:
83+
if (
84+
hasattr(last_device, 'coords')
85+
and len(last_device.coords) == 3
86+
and hasattr(last_device, 'core_on_chip')
87+
):
8488
x, y, z = last_device.coords
8589
return x + 1, y + 1, z + 1, last_device.core_on_chip + 1
8690
else:
@@ -91,7 +95,7 @@ def bounds_from_last_device(last_device: jax.Device) -> HardwareMesh:
9195

9296
def get_coords(device: jax.Device) -> HardwareMesh:
9397
"""Returns the coordinates of the given device."""
94-
if hasattr(device, 'coords'):
98+
if hasattr(device, 'coords') and hasattr(device, 'core_on_chip'):
9599
return (*device.coords, device.core_on_chip)
96100
return (device.process_index, device.id % jax.local_device_count())
97101

0 commit comments

Comments
 (0)