Skip to content

Commit d5bf687

Browse files
committed
update modes
Distinguish whether the env was enabled via the context manager or via `enable_globally`. The reason: if a user calls `torch.ones`, we want to respect the requested device, especially meta. However, if this op is called as part of a lowering, then we need to replace it with the jax device. Currently there is no way for C++ decomposed ops to know whether a tensor is on the jax device.
1 parent f57b3e3 commit d5bf687

File tree

4 files changed

+50
-14
lines changed

4 files changed

+50
-14
lines changed

torchax/test/test_ops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,8 @@ def setUp(self):
184184
self.env = torchax.default_env()
185185
torchax.enable_accuracy_mode()
186186
#self.env.config.debug_accuracy_for_each_op = True
187+
self.env.config.debug_print_each_op = True
188+
self.env.config.debug_print_each_op_operands = True
187189
torch.manual_seed(0)
188190
self.old_var = self.env.config.use_torch_native_for_cpu_tensor
189191
self.env.config.use_torch_native_for_cpu_tensor = False

torchax/torchax/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,13 +66,12 @@ def jax_func(states, inputs):
6666
return states, jax_func
6767

6868
def enable_globally():
69-
global env
70-
env = default_env().__enter__()
69+
env = default_env().enable_torch_modes()
7170
return env
7271

7372
def disable_globally():
7473
global env
75-
default_env().__exit__(None, None, None)
74+
default_env().disable_torch_modes()
7675

7776
@contextlib.contextmanager
7877
def disable_temporarily():

torchax/torchax/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ class Configuration:
66
debug_print_each_op: bool = False
77
debug_accuracy_for_each_op: bool = False
88
debug_mixed_tensor: bool = False
9+
debug_print_each_op_operands: bool = False
910
use_int32_for_index: bool = False
1011

1112
# Flash attention

torchax/torchax/tensor.py

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,21 @@ def debug_accuracy(func, args, kwargs, current_output):
207207

208208
return True
209209

210+
def _make_debug_msg(is_dispatch, log_args, func, args, kwargs):
211+
def _display(a):
212+
if isinstance(a, torch.Tensor):
213+
return f'Tensor of {type(a)}: {a.dtype}{a.shape}'
214+
elif isinstance(a, jax.Array):
215+
return f'Jax Array of {type(a)}: {a.dtype}{a.shape}'
216+
else:
217+
return str(a)
218+
219+
kwargs = kwargs or {}
220+
title = 'DISPATCH' if is_dispatch else 'FUNCTION'
221+
args_msg = 'args: ' + ','.join(_display(a) for a in args) if log_args else ''
222+
kwargs_msg = 'kwargs: ' + ','.join(f'{key}: {_display(a)}' for key, a in kwargs.items()) if log_args else ''
223+
return f'{title}: {_name_of_func(func)} {args_msg} ~ {kwargs_msg}'
224+
210225

211226
class XLAFunctionMode(torch.overrides.TorchFunctionMode):
212227
"""Context manager that dispatches torch function calls to JAX."""
@@ -219,7 +234,12 @@ def __torch_function__(self,
219234
types,
220235
args=(),
221236
kwargs=None) -> torch.Tensor:
222-
with log_nested(self.env, f'FUNCTION: {_name_of_func(func)}'):
237+
message = f'FUNCTION: {_name_of_func(func)}'
238+
if self.env.config.debug_print_each_op_operands:
239+
message = message + 'f'
240+
message = _make_debug_msg(False, self.env.config.debug_print_each_op_operands,
241+
func, args, kwargs)
242+
with log_nested(self.env, message):
223243
try:
224244
return self.env.dispatch(func, types, args, kwargs)
225245
except OperatorNotFound:
@@ -237,7 +257,9 @@ def __init__(self, env):
237257
self.env = env
238258

239259
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
240-
with log_nested(self.env, f'DISPATCH: {_name_of_func(func)}'):
260+
message = _make_debug_msg(True, self.env.config.debug_print_each_op_operands,
261+
func, args, kwargs)
262+
with log_nested(self.env, message):
241263
if isinstance(func, torch._ops.OpOverloadPacket):
242264
with self:
243265
return func(*args, **kwargs)
@@ -295,6 +317,8 @@ def __init__(self, configuration=None):
295317
self._mesh = None
296318
self.config = configuration or config.Configuration()
297319

320+
self._manually_entered = False
321+
self.enabled = False
298322
self._jax_devices = set(['jax', 'jax_cpu', 'xla'])
299323

300324
def get_as_jax_device(self, device: Any):
@@ -342,10 +366,9 @@ def _to_copy(self, the_tensor, new_dtype, new_device):
342366
if new_dtype is not None and new_dtype != arr.dtype:
343367
arr = arr.astype(mappings.t2j_dtype(new_dtype))
344368
if new_device is not None:
345-
jax_device = self.get_as_jax_device(new_device)
346-
if jax_device:
347-
arr = jax.device_put(arr, jax_device)
348-
else:
369+
# convert xla tensor to other device
370+
# only supported is CPU
371+
if str(new_device).startswith('cpu'):
349372
# converting to a non-jax device: let torch native handle it
350373
torch_tensor = j2t(arr) if isinstance(the_tensor, Tensor) else arr
351374
with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
@@ -376,7 +399,8 @@ def get_and_rotate_prng_key(self, generator: Optional[torch.Generator]=None):
376399
def _handle_tensor_constructor(self, func, args, kwargs):
377400
device = kwargs.get('device')
378401
jax_device = self.get_as_jax_device(device)
379-
if jax_device is None:
402+
# TODO(qihqi) figure out better ways for device propagation
403+
if not self._manually_entered and jax_device is None:
380404
# let torch handle it
381405
with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
382406
return func(*args, **kwargs)
@@ -465,17 +489,27 @@ def dispatch(self, func, types, args, kwargs):
465489
debug_accuracy(func, old_args, old_kwargs, res)
466490
return res
467491

468-
def __enter__(self):
492+
def enable_torch_modes(self):
469493
self._dispatch_mode.__enter__()
470494
self._function_mode.__enter__()
471495
self.enabled = True
472-
return self
473-
474-
def __exit__(self, *exc):
496+
497+
def disable_torch_modes(self, *exc):
498+
if not exc:
499+
exc = (None, None, None)
475500
self._function_mode.__exit__(*exc)
476501
self._dispatch_mode.__exit__(*exc)
477502
self.enabled = False
478503

504+
def __enter__(self):
505+
self.enable_torch_modes()
506+
self._manually_entered = True
507+
return self
508+
509+
def __exit__(self, *exc):
510+
self._manually_entered = False
511+
self.disable_torch_modes(*exc)
512+
479513
def _move_one_value(self, val):
480514
if isinstance(val, torch.nn.Module):
481515
with self:

0 commit comments

Comments
 (0)