diff --git a/torchax/test/test_ops.py b/torchax/test/test_ops.py
index 9e7f0ba22f3..917474fcba1 100644
--- a/torchax/test/test_ops.py
+++ b/torchax/test/test_ops.py
@@ -184,6 +184,8 @@ def setUp(self):
     self.env = torchax.default_env()
     torchax.enable_accuracy_mode()
     #self.env.config.debug_accuracy_for_each_op = True
+    self.env.config.debug_print_each_op = True
+    self.env.config.debug_print_each_op_operands = True
     torch.manual_seed(0)
     self.old_var = self.env.config.use_torch_native_for_cpu_tensor
     self.env.config.use_torch_native_for_cpu_tensor = False
diff --git a/torchax/torchax/__init__.py b/torchax/torchax/__init__.py
index 0e6e085719d..5428ca4815d 100644
--- a/torchax/torchax/__init__.py
+++ b/torchax/torchax/__init__.py
@@ -66,13 +66,12 @@ def jax_func(states, inputs):
   return states, jax_func
 
 def enable_globally():
-  global env
-  env = default_env().__enter__()
+  env = default_env().enable_torch_modes()
   return env
 
 def disable_globally():
   global env
-  default_env().__exit__(None, None, None)
+  default_env().disable_torch_modes()
 
 @contextlib.contextmanager
 def disable_temporarily():
diff --git a/torchax/torchax/config.py b/torchax/torchax/config.py
index 351d137df57..f3fe410e46d 100644
--- a/torchax/torchax/config.py
+++ b/torchax/torchax/config.py
@@ -6,6 +6,7 @@ class Configuration:
   debug_print_each_op: bool = False
   debug_accuracy_for_each_op: bool = False
   debug_mixed_tensor: bool = False
+  debug_print_each_op_operands: bool = False
   use_int32_for_index: bool = False
 
   # Flash attention
diff --git a/torchax/torchax/tensor.py b/torchax/torchax/tensor.py
index 6ddce255c02..f6d46eb20d0 100644
--- a/torchax/torchax/tensor.py
+++ b/torchax/torchax/tensor.py
@@ -207,6 +207,21 @@ def debug_accuracy(func, args, kwargs, current_output):
 
   return True
 
+def _make_debug_msg(is_dispatch, log_args, func, args, kwargs):
+  def _display(a):
+    if isinstance(a, torch.Tensor):
+      return f'Tensor of {type(a)}: {a.dtype}{a.shape}'
+    elif isinstance(a, jax.Array):
+      return f'Jax Array of {type(a)}: {a.dtype}{a.shape}'
+    else:
+      return str(a)
+
+  kwargs = kwargs or {}
+  title = 'DISPATCH' if is_dispatch else 'FUNCTION'
+  args_msg = 'args: ' + ','.join(_display(a) for a in args) if log_args else ''
+  kwargs_msg = 'kwargs: ' + ','.join(f'{key}: {_display(a)}' for key, a in kwargs.items()) if log_args else ''
+  return f'{title}: {_name_of_func(func)} {args_msg} ~ {kwargs_msg}'
+
 class XLAFunctionMode(torch.overrides.TorchFunctionMode):
   """Context manager that dispatches torch function calls to JAX."""
 
@@ -219,7 +234,9 @@ def __torch_function__(self,
                          types,
                          args=(),
                          kwargs=None) -> torch.Tensor:
-    with log_nested(self.env, f'FUNCTION: {_name_of_func(func)}'):
+    message = _make_debug_msg(False, self.env.config.debug_print_each_op_operands,
+                              func, args, kwargs)
+    with log_nested(self.env, message):
       try:
         return self.env.dispatch(func, types, args, kwargs)
       except OperatorNotFound:
@@ -237,7 +254,9 @@ def __init__(self, env):
     self.env = env
 
   def __torch_dispatch__(self, func, types, args=(), kwargs=None):
-    with log_nested(self.env, f'DISPATCH: {_name_of_func(func)}'):
+    message = _make_debug_msg(True, self.env.config.debug_print_each_op_operands,
+                              func, args, kwargs)
+    with log_nested(self.env, message):
       if isinstance(func, torch._ops.OpOverloadPacket):
         with self:
           return func(*args, **kwargs)
@@ -295,6 +314,8 @@ def __init__(self, configuration=None):
     self._dispatch_mode = XLADispatchMode(self)
     self._mesh = None
     self.config = configuration or config.Configuration()
+    self._manually_entered = False
+    self.enabled = False
     self._jax_devices = set(['jax', 'jax_cpu', 'xla'])
 
   def get_as_jax_device(self, device: Any):
@@ -342,10 +363,9 @@ def _to_copy(self, the_tensor, new_dtype, new_device):
     if new_dtype is not None and new_dtype != arr.dtype:
       arr = arr.astype(mappings.t2j_dtype(new_dtype))
     if new_device is not None:
-      jax_device = self.get_as_jax_device(new_device)
-      if jax_device:
-        arr = jax.device_put(arr, jax_device)
-      else:
+      # converting an xla tensor to another device;
+      # the only supported target is CPU
+      if str(new_device).startswith('cpu'):
         # converting to a non-jax device: let torch native handle it
         torch_tensor = j2t(arr) if isinstance(the_tensor, Tensor) else arr
         with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
@@ -376,7 +396,8 @@ def get_and_rotate_prng_key(self, generator: Optional[torch.Generator]=None):
   def _handle_tensor_constructor(self, func, args, kwargs):
     device = kwargs.get('device')
     jax_device = self.get_as_jax_device(device)
-    if jax_device is None:
+    # TODO(qihqi) figure out better ways for device propagation
+    if not self._manually_entered and jax_device is None:
       # let torch handle it
       with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
         return func(*args, **kwargs)
@@ -465,17 +486,28 @@ def dispatch(self, func, types, args, kwargs):
         debug_accuracy(func, old_args, old_kwargs, res)
       return res
 
-  def __enter__(self):
+  def enable_torch_modes(self):
     self._dispatch_mode.__enter__()
     self._function_mode.__enter__()
     self.enabled = True
     return self
 
-  def __exit__(self, *exc):
+  def disable_torch_modes(self, *exc):
+    if not exc:
+      exc = (None, None, None)
     self._function_mode.__exit__(*exc)
     self._dispatch_mode.__exit__(*exc)
     self.enabled = False
 
+  def __enter__(self):
+    self.enable_torch_modes()
+    self._manually_entered = True
+    return self
+
+  def __exit__(self, *exc):
+    self._manually_entered = False
+    self.disable_torch_modes(*exc)
+
   def _move_one_value(self, val):
     if isinstance(val, torch.nn.Module):
       with self:
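
To make the review concrete: below is a minimal usage sketch (not part of the patch) for the new `debug_print_each_op_operands` flag, mirroring the `setUp` change in `test_ops.py`. The log lines shown in comments are approximations of what `_make_debug_msg` would produce; exact formatting may differ.

```python
import torch
import torchax

env = torchax.default_env()
# Print a log line for every intercepted op; the new flag additionally
# summarizes each operand's dtype and shape via _make_debug_msg.
env.config.debug_print_each_op = True
env.config.debug_print_each_op_operands = True

with env:
  a = torch.randn(2, 2)
  b = a + a

# Approximate output per op (one FUNCTION line, then DISPATCH lines):
#   FUNCTION: add args: Tensor of <class 'torchax.tensor.Tensor'>: torch.float32torch.Size([2, 2]),... ~
#   DISPATCH: aten::add.Tensor args: ... ~ kwargs: ...
```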
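The interplay between `_manually_entered` and `_handle_tensor_constructor` is the behavioral crux of this change: global enablement and the context manager no longer behave identically. A sketch of that difference as I read the diff (the `'jax'` device string comes from `_jax_devices`):

```python
import torch
import torchax

# enable_globally() installs the torch modes but leaves _manually_entered
# False, so a constructor without an explicit device= still falls back to
# torch native CPU (jax_device is None in _handle_tensor_constructor).
env = torchax.enable_globally()
cpu_tensor = torch.zeros(2)                # stays a native torch tensor
jax_tensor = torch.zeros(2, device='jax')  # routed to JAX
torchax.disable_globally()

# Entering the env as a context manager sets _manually_entered = True,
# so even device-less constructors produce JAX-backed tensors.
with torchax.default_env():
  t = torch.ones(3)  # no device= needed here
```

If that reading is right, keeping `return self` in `enable_torch_modes()` also matters, since `enable_globally()` assigns and returns its result.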