update modes
Distinguish when someone enables the env via the context manager vs.
`enable_globally`.

The reason: if a user calls `torch.ones`, we want to respect the
requested device, especially `meta`. However, if the op is called as
part of a lowering, then we need to replace the device with the jax
device.

Currently there is no way for C++ decomposed ops to know whether a
tensor is on the jax device.
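
A minimal sketch of the intended difference in behavior (illustration only, not part of this commit; it assumes the torchax API shown in the diff below):

    import torch
    import torchax

    # enable_globally() now only installs the torch modes, so an explicit
    # device request from the user (e.g. meta) is respected.
    env = torchax.enable_globally()
    m = torch.ones(2, 2, device='meta')   # stays a meta tensor
    torchax.disable_globally()

    # Entering the env as a context manager (as a lowering would) additionally
    # sets _manually_entered, so constructors go to the jax device instead.
    with torchax.default_env():
      j = torch.ones(2, 2)                # placed on the jax device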
qihqi committed Feb 6, 2025
1 parent f57b3e3 commit d5bf687
Showing 4 changed files with 50 additions and 14 deletions.
2 changes: 2 additions & 0 deletions torchax/test/test_ops.py
@@ -184,6 +184,8 @@ def setUp(self):
    self.env = torchax.default_env()
    torchax.enable_accuracy_mode()
    #self.env.config.debug_accuracy_for_each_op = True
    self.env.config.debug_print_each_op = True
    self.env.config.debug_print_each_op_operands = True
    torch.manual_seed(0)
    self.old_var = self.env.config.use_torch_native_for_cpu_tensor
    self.env.config.use_torch_native_for_cpu_tensor = False
5 changes: 2 additions & 3 deletions torchax/torchax/__init__.py
@@ -66,13 +66,12 @@ def jax_func(states, inputs):
  return states, jax_func

def enable_globally():
  global env
  env = default_env().__enter__()
  env = default_env().enable_torch_modes()
  return env

def disable_globally():
  global env
  default_env().__exit__(None, None, None)
  default_env().disable_torch_modes()

@contextlib.contextmanager
def disable_temporarily():
1 change: 1 addition & 0 deletions torchax/torchax/config.py
@@ -6,6 +6,7 @@ class Configuration:
  debug_print_each_op: bool = False
  debug_accuracy_for_each_op: bool = False
  debug_mixed_tensor: bool = False
  debug_print_each_op_operands: bool = False
  use_int32_for_index: bool = False

  # Flash attention
56 changes: 45 additions & 11 deletions torchax/torchax/tensor.py
@@ -207,6 +207,21 @@ def debug_accuracy(func, args, kwargs, current_output):

  return True

def _make_debug_msg(is_dispatch, log_args, func, args, kwargs):
  def _display(a):
    if isinstance(a, torch.Tensor):
      return f'Tensor of {type(a)}: {a.dtype}{a.shape}'
    elif isinstance(a, jax.Array):
      return f'Jax Array of {type(a)}: {a.dtype}{a.shape}'
    else:
      return str(a)

  kwargs = kwargs or {}
  title = 'DISPATCH' if is_dispatch else 'FUNCTION'
  args_msg = 'args: ' + ','.join(_display(a) for a in args) if log_args else ''
  kwargs_msg = 'kwargs: ' + ','.join(f'{key}: {_display(a)}' for key, a in kwargs.items()) if log_args else ''
  return f'{title}: {_name_of_func(func)} {args_msg} ~ {kwargs_msg}'


class XLAFunctionMode(torch.overrides.TorchFunctionMode):
  """Context manager that dispatches torch function calls to JAX."""
@@ -219,7 +234,12 @@ def __torch_function__(self,
                         types,
                         args=(),
                         kwargs=None) -> torch.Tensor:
    with log_nested(self.env, f'FUNCTION: {_name_of_func(func)}'):
    message = _make_debug_msg(False, self.env.config.debug_print_each_op_operands,
                              func, args, kwargs)
    with log_nested(self.env, message):
      try:
        return self.env.dispatch(func, types, args, kwargs)
      except OperatorNotFound:
@@ -237,7 +257,9 @@ def __init__(self, env):
    self.env = env

  def __torch_dispatch__(self, func, types, args=(), kwargs=None):
    with log_nested(self.env, f'DISPATCH: {_name_of_func(func)}'):
    message = _make_debug_msg(True, self.env.config.debug_print_each_op_operands,
                              func, args, kwargs)
    with log_nested(self.env, message):
      if isinstance(func, torch._ops.OpOverloadPacket):
        with self:
          return func(*args, **kwargs)
@@ -295,6 +317,8 @@ def __init__(self, configuration=None):
    self._mesh = None
    self.config = configuration or config.Configuration()

    self._manually_entered = False
    self.enabled = False
    self._jax_devices = set(['jax', 'jax_cpu', 'xla'])

  def get_as_jax_device(self, device: Any):
@@ -342,10 +366,9 @@ def _to_copy(self, the_tensor, new_dtype, new_device):
    if new_dtype is not None and new_dtype != arr.dtype:
      arr = arr.astype(mappings.t2j_dtype(new_dtype))
    if new_device is not None:
      jax_device = self.get_as_jax_device(new_device)
      if jax_device:
        arr = jax.device_put(arr, jax_device)
      else:
        # convert xla tensor to other device
        # only supported is CPU
      if str(new_device).startswith('cpu'):
        # converting to a non-jax device: let torch native handle it
        torch_tensor = j2t(arr) if isinstance(the_tensor, Tensor) else arr
        with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
@@ -376,7 +399,8 @@ def get_and_rotate_prng_key(self, generator: Optional[torch.Generator]=None):
  def _handle_tensor_constructor(self, func, args, kwargs):
    device = kwargs.get('device')
    jax_device = self.get_as_jax_device(device)
    if jax_device is None:
    # TODO(qihqi) figure out better ways for device propagation
    if not self._manually_entered and jax_device is None:
      # let torch handle it
      with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
        return func(*args, **kwargs)
@@ -465,17 +489,27 @@ def dispatch(self, func, types, args, kwargs):
      debug_accuracy(func, old_args, old_kwargs, res)
    return res

  def __enter__(self):
  def enable_torch_modes(self):
    self._dispatch_mode.__enter__()
    self._function_mode.__enter__()
    self.enabled = True
    return self

  def __exit__(self, *exc):

  def disable_torch_modes(self, *exc):
    if not exc:
      exc = (None, None, None)
    self._function_mode.__exit__(*exc)
    self._dispatch_mode.__exit__(*exc)
    self.enabled = False

  def __enter__(self):
    self.enable_torch_modes()
    self._manually_entered = True
    return self

  def __exit__(self, *exc):
    self._manually_entered = False
    self.disable_torch_modes(*exc)

  def _move_one_value(self, val):
    if isinstance(val, torch.nn.Module):
      with self:
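
A minimal usage sketch for the new debug flag (illustration only, not part of this commit; the exact log text is produced by `_make_debug_msg` above):

    import torch
    import torchax

    env = torchax.default_env()
    env.config.debug_print_each_op = True
    env.config.debug_print_each_op_operands = True   # flag added in this commit

    with env:
      a = torch.ones(2, 2)
      b = a + a
      # Each intercepted call is now logged together with its operands, roughly:
      #   FUNCTION: <op name> args: Tensor of <class ...>: torch.float32 torch.Size([2, 2]),... ~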
