File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -80,7 +80,11 @@ def bounds_from_last_device(last_device: jax.Device) -> HardwareMesh:
8080 # Must be passed the device at the highest-coordinate corner of the
8181 # relevant mesh, which is a requirement we know is satisfied by the last
8282 # device in jax.devices().
83- if hasattr (last_device , 'coords' ) and len (last_device .coords ) == 3 :
83+ if (
84+ hasattr (last_device , 'coords' )
85+ and len (last_device .coords ) == 3
86+ and hasattr (last_device , 'core_on_chip' )
87+ ):
8488 x , y , z = last_device .coords
8589 return x + 1 , y + 1 , z + 1 , last_device .core_on_chip + 1
8690 else :
@@ -91,7 +95,7 @@ def bounds_from_last_device(last_device: jax.Device) -> HardwareMesh:
9195
9296def get_coords (device : jax .Device ) -> HardwareMesh :
9397 """Returns the coordinates of the given device."""
94- if hasattr (device , 'coords' ):
98+ if hasattr (device , 'coords' ) and hasattr ( device , 'core_on_chip' ) :
9599 return (* device .coords , device .core_on_chip )
96100 return (device .process_index , device .id % jax .local_device_count ())
97101
You can’t perform that action at this time.
0 commit comments