Misc fixes (#8679)
Better handle the meta device
Handle assignment to the .data attribute
Linear should check bias with "is not None"
qihqi authored Feb 6, 2025
1 parent 45109d1 commit 8b45e59
Showing 5 changed files with 86 additions and 23 deletions.
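The "Linear should check bias with is not None" item in the description has no hunk shown below; the following is a minimal illustrative sketch of the intended pattern (hypothetical helper, not the actual torchax code):

import torch

def linear_forward(x: torch.Tensor, weight: torch.Tensor, bias=None) -> torch.Tensor:
  # Check the bias with `is not None` rather than truthiness: evaluating a
  # multi-element tensor as a bool raises, and an empty bias would be skipped.
  out = x @ weight.T
  if bias is not None:
    out = out + bias
  return out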
13 changes: 12 additions & 1 deletion torchax/test/test_functions.py
@@ -70,7 +70,6 @@ def test_rnn(self):
model = SeqModel()
x = torch.randn((2, 100, 20))
res = model(x)
self.env.config.debug_print_each_op = True
with self.env:
model.to('jax')
x = x.to('jax')
@@ -79,6 +78,18 @@ def test_rnn(self):

self.assertEqual(res.shape, res2.shape)

def test_rms_norm(self):
model = torch.nn.RMSNorm((100, 20))
x = torch.randn((2, 100, 20))
res = model(x)

with self.env:
model.to('jax')
x = x.to('jax')
res2 = model(x)
self.assertTrue(
torch.allclose(res, torchax.tensor.j2t(res2.jax())))




2 changes: 2 additions & 0 deletions torchax/test/test_ops.py
@@ -184,6 +184,8 @@ def setUp(self):
self.env = torchax.default_env()
torchax.enable_accuracy_mode()
#self.env.config.debug_accuracy_for_each_op = True
self.env.config.debug_print_each_op = True
self.env.config.debug_print_each_op_operands = True
torch.manual_seed(0)
self.old_var = self.env.config.use_torch_native_for_cpu_tensor
self.env.config.use_torch_native_for_cpu_tensor = False
5 changes: 2 additions & 3 deletions torchax/torchax/__init__.py
@@ -66,13 +66,12 @@ def jax_func(states, inputs):
return states, jax_func

def enable_globally():
global env
env = default_env().__enter__()
env = default_env().enable_torch_modes()
return env

def disable_globally():
global env
default_env().__exit__(None, None, None)
default_env().disable_torch_modes()

@contextlib.contextmanager
def disable_temporarily():
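A minimal usage sketch of the refactored global toggle above, using only the enable_globally/disable_globally API from this diff and the .to('jax') pattern from the tests in this commit:

import torch
import torchax

torchax.enable_globally()                # installs the modes via enable_torch_modes()
model = torch.nn.Linear(4, 4).to('jax')  # weights now live as jax arrays
x = torch.randn(2, 4).to('jax')
y = model(x)                             # dispatched to JAX
torchax.disable_globally()               # tears the modes down via disable_torch_modes()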
1 change: 1 addition & 0 deletions torchax/torchax/config.py
@@ -6,6 +6,7 @@ class Configuration:
debug_print_each_op: bool = False
debug_accuracy_for_each_op: bool = False
debug_mixed_tensor: bool = False
debug_print_each_op_operands: bool = False
use_int32_for_index: bool = False

# Flash attention
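A short sketch of turning on the new operand-logging flag, mirroring the test_ops.py setUp change above; it uses only the Configuration fields shown in this commit:

import torchax

env = torchax.default_env()
env.config.debug_print_each_op = True            # log every intercepted op
env.config.debug_print_each_op_operands = True   # also log dtype and shape of each operand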
88 changes: 69 additions & 19 deletions torchax/torchax/tensor.py
@@ -1,3 +1,4 @@
import logging
import sys
import contextlib
from typing import Optional, Any
@@ -14,6 +15,8 @@
from torchax import config
from torchax.ops import mappings, ops_registry

logger = logging.getLogger(__name__)


class OperatorNotFound(Exception):
pass
@@ -155,6 +158,16 @@ def device(self):
def jax_device(self):
return self._elem.device

@property
def data(self):
logger.warning("In-place modifications to .data still result in a copy on TPU")
return self

@data.setter
def data(self, other):
if isinstance(other, Tensor):
self._elem = other._elem
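
# Illustrative sketch, not part of this commit: with the .data property and
# setter above, optimizer-style rebinding through .data now works for tensors
# living on 'jax' (assuming `env` comes from torchax.default_env()).
def _example_data_rebind(env):
  with env:
    a = torch.zeros(4).to('jax')
    b = torch.ones(4).to('jax')
    a.data = b        # rebinds a._elem to b's jax array via the setter above
    return a.data     # returns a itself, logging the warning above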

def apply_jax(self, jax_function, *args, **kwargs):
# Call a jax function on _elem
res = jax_function(self._elem, *args, **kwargs)
@@ -199,6 +212,21 @@ def debug_accuracy(func, args, kwargs, current_output):

return True

def _make_debug_msg(is_dispatch, log_args, func, args, kwargs):
def _display(a):
if isinstance(a, torch.Tensor):
return f'Tensor of {type(a)}: {a.dtype}{a.shape}'
elif isinstance(a, jax.Array):
return f'Jax Array of {type(a)}: {a.dtype}{a.shape}'
else:
return str(a)

kwargs = kwargs or {}
title = 'DISPATCH' if is_dispatch else 'FUNCTION'
args_msg = 'args: ' + ','.join(_display(a) for a in args) if log_args else ''
kwargs_msg = 'kwargs: ' + ','.join(f'{key}: {_display(a)}' for key, a in kwargs.items()) if log_args else ''
return f'{title}: {_name_of_func(func)} {args_msg} ~ {kwargs_msg}'
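
# Illustrative sketch, not part of this commit: calling the helper above
# directly shows (roughly) the line that gets logged when operand logging is on.
def _example_debug_msg():
  msg = _make_debug_msg(False, True, torch.add, (torch.zeros(2, 2),), None)
  # roughly: "FUNCTION: add args: Tensor of <class 'torch.Tensor'>: torch.float32torch.Size([2, 2]) ~ kwargs: "
  return msg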


class XLAFunctionMode(torch.overrides.TorchFunctionMode):
"""Context manager that dispatches torch function calls to JAX."""
@@ -211,7 +239,12 @@ def __torch_function__(self,
types,
args=(),
kwargs=None) -> torch.Tensor:
with log_nested(self.env, f'FUNCTION: {_name_of_func(func)}'):
message = _make_debug_msg(False, self.env.config.debug_print_each_op_operands,
func, args, kwargs)
with log_nested(self.env, message):
try:
return self.env.dispatch(func, types, args, kwargs)
except OperatorNotFound:
@@ -229,7 +262,9 @@ def __init__(self, env):
self.env = env

def __torch_dispatch__(self, func, types, args=(), kwargs=None):
with log_nested(self.env, f'DISPATCH: {_name_of_func(func)}'):
message = _make_debug_msg(True, self.env.config.debug_print_each_op_operands,
func, args, kwargs)
with log_nested(self.env, message):
if isinstance(func, torch._ops.OpOverloadPacket):
with self:
return func(*args, **kwargs)
@@ -287,6 +322,8 @@ def __init__(self, configuration=None):
self._mesh = None
self.config = configuration or config.Configuration()

self._manually_entered = False
self.enabled = False
self._jax_devices = set(['jax', 'jax_cpu', 'xla'])

def get_as_jax_device(self, device: Any):
@@ -295,16 +332,19 @@ def get_as_jax_device(self, device: Any):

if isinstance(device, torch.device):
device = str(device)
if (self.config.use_torch_native_for_cpu_tensor and
not device.startswith('jax') and not device.startswith('cuda')):
return None

if not self.config.treat_cuda_as_jax_device and device.startswith('cuda'):
return None

if device == 'cpu':
if (not self.config.use_torch_native_for_cpu_tensor and
device.startswith('cpu')):
return jax.devices('cpu')[0]
return jax.local_devices()[0]

if self.config.treat_cuda_as_jax_device and device.startswith('cuda'):
return jax.local_devices()[0]

if device.startswith('jax'):
return jax.local_devices()[0]

return None # fallback to torch
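
# Illustrative sketch, not part of this commit: how the resolution rules above
# are expected to behave for the common device strings.
def _example_device_resolution(env):
  # 'jax' always resolves to a local jax device.
  assert env.get_as_jax_device('jax') is not None
  # 'cuda' resolves to a jax device only when treat_cuda_as_jax_device is set.
  if not env.config.treat_cuda_as_jax_device:
    assert env.get_as_jax_device('cuda') is None
  # 'cpu' falls back to torch (None) when use_torch_native_for_cpu_tensor is set.
  if env.config.use_torch_native_for_cpu_tensor:
    assert env.get_as_jax_device('cpu') is None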



def load_ops(self):
Expand All @@ -331,10 +371,9 @@ def _to_copy(self, the_tensor, new_dtype, new_device):
if new_dtype is not None and new_dtype != arr.dtype:
arr = arr.astype(mappings.t2j_dtype(new_dtype))
if new_device is not None:
jax_device = self.get_as_jax_device(new_device)
if jax_device:
arr = jax.device_put(arr, jax_device)
else:
# convert the jax-backed tensor to another device
# only CPU is supported
if str(new_device).startswith('cpu'):
# converting to a non-jax device: let torch native handle it
torch_tensor = j2t(arr) if isinstance(the_tensor, Tensor) else arr
with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
@@ -365,7 +404,8 @@ def get_and_rotate_prng_key(self, generator: Optional[torch.Generator]=None):
def _handle_tensor_constructor(self, func, args, kwargs):
device = kwargs.get('device')
jax_device = self.get_as_jax_device(device)
if jax_device is None:
# TODO(qihqi) figure out better ways for device propagation
if not self._manually_entered and jax_device is None:
# let torch handle it
with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
return func(*args, **kwargs)
@@ -454,17 +494,27 @@ def dispatch(self, func, types, args, kwargs):
debug_accuracy(func, old_args, old_kwargs, res)
return res

def __enter__(self):
def enable_torch_modes(self):
self._dispatch_mode.__enter__()
self._function_mode.__enter__()
self.enabled = True
return self

def __exit__(self, *exc):

def disable_torch_modes(self, *exc):
if not exc:
exc = (None, None, None)
self._function_mode.__exit__(*exc)
self._dispatch_mode.__exit__(*exc)
self.enabled = False

def __enter__(self):
self.enable_torch_modes()
self._manually_entered = True
return self

def __exit__(self, *exc):
self._manually_entered = False
self.disable_torch_modes(*exc)
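
# Illustrative sketch, not part of this commit: the two ways of turning the
# modes on now differ only in _manually_entered, which _handle_tensor_constructor
# above consults before letting torch construct tensors that have no jax device.
def _example_enable_styles(env):
  env.enable_torch_modes()    # global style: modes on, _manually_entered stays False
  env.disable_torch_modes()
  with env:                   # context-manager style: also sets _manually_entered
    pass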

def _move_one_value(self, val):
if isinstance(val, torch.nn.Module):
with self:
