Skip to content

Commit 3578940

Browse files
lsy323Siyuan Liu
and
Siyuan Liu
authored
[torchax][distributed] Use local device for default jax device (#8667)
Co-authored-by: Siyuan Liu <[email protected]>
1 parent 9ae017e commit 3578940

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

torchax/torchax/tensor.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ def get_as_jax_device(self, device: Any):
304304

305305
if device == 'cpu':
306306
return jax.devices('cpu')[0]
307-
return jax.devices()[0]
307+
return jax.local_devices()[0]
308308

309309

310310
def load_ops(self):

0 commit comments

Comments
 (0)