Skip to content

Commit 718b658

Browse files
committed
Handle meta device better
1 parent b2fa19a commit 718b658

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

torchax/torchax/ops/jaten.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5192,6 +5192,6 @@ def update_indices(i, _indices):
51925192
@op(torch.ops.aten.linear)
51935193
def linear(input, weight, bias=None):
51945194
res = input @ jnp.transpose(weight)
5195-
if bias:
5195+
if bias is not None:
51965196
res += bias
51975197
return res

torchax/torchax/tensor.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -295,16 +295,20 @@ def get_as_jax_device(self, device: Any):
295295

296296
if isinstance(device, torch.device):
297297
device = str(device)
298-
if (self.config.use_torch_native_for_cpu_tensor and
299-
not device.startswith('jax') and not device.startswith('cuda')):
300-
return None
301298

302-
if not self.config.treat_cuda_as_jax_device and device.startswith('cuda'):
303-
return None
304-
305-
if device == 'cpu':
299+
print('device ', device)
300+
if (not self.config.use_torch_native_for_cpu_tensor and
301+
device.startswith('cpu')):
306302
return jax.devices('cpu')[0]
307-
return jax.local_devices()[0]
303+
304+
if self.config.treat_cuda_as_jax_device and device.startswith('cuda'):
305+
return jax.local_devices()[0]
306+
307+
if device.startswith('jax'):
308+
return jax.local_devices()[0]
309+
310+
return None # fallback to torch
311+
308312

309313

310314
def load_ops(self):

0 commit comments

Comments
 (0)