Skip to content

Commit e6a70cc

Browse files
T5X Teamt5-copybara
authored andcommitted
Sort devices by their implicit order instead of explicitly by id. IDs may be randomly generated, so it's better to rely on the implicit order, which is currently based on (process index, id).
PiperOrigin-RevId: 650294623
1 parent efce74c commit e6a70cc

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

t5x/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def create_global_mesh(mesh_shape, axis_names):
233233
size = np.prod(mesh_shape)
234234
if len(jax.devices()) < size:
235235
raise unittest.SkipTest(f'Test requires {size} global devices.')
236-
devices = sorted(jax.devices(), key=lambda d: d.id)
236+
devices = sorted(jax.devices())
237237
mesh_devices = np.array(devices[:size]).reshape(mesh_shape)
238238
global_mesh = Mesh(mesh_devices, axis_names)
239239
return global_mesh

0 commit comments

Comments
 (0)