Commit 8b45e59

Misc fixes (#8679)
- better handle meta device
- handle the .data attribute on tensors
- linear should check bias with is not None
1 parent 45109d1 commit 8b45e59

5 files changed (+86 −23 lines)

torchax/test/test_functions.py

Lines changed: 12 additions & 1 deletion
@@ -70,7 +70,6 @@ def test_rnn(self):
     model = SeqModel()
     x = torch.randn((2, 100, 20))
     res = model(x)
-    self.env.config.debug_print_each_op = True
     with self.env:
       model.to('jax')
       x = x.to('jax')
@@ -79,6 +78,18 @@ def test_rnn(self):
 
     self.assertEqual(res.shape, res2.shape)
 
+  def test_rms_norm(self):
+    model = torch.nn.RMSNorm((100, 20))
+    x = torch.randn((2, 100, 20))
+    res = model(x)
+
+    with self.env:
+      model.to('jax')
+      x = x.to('jax')
+      res2 = model(x)
+      self.assertTrue(
+          torch.allclose(res, torchax.tensor.j2t(res2.jax())))
+
 
 
 
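The new test follows the same round-trip pattern as test_rnn: compute an eager reference, move the model and input to the 'jax' device inside the environment, then compare. A standalone sketch of that pattern (assuming a working torchax install and a PyTorch recent enough to have torch.nn.RMSNorm):

import torch
import torchax

env = torchax.default_env()
model = torch.nn.RMSNorm((100, 20))
x = torch.randn((2, 100, 20))
res = model(x)  # eager CPU reference

with env:
  model.to('jax')            # parameters become jax-backed tensors
  res2 = model(x.to('jax'))

# j2t copies the jax array back into a native torch tensor for comparison
assert torch.allclose(res, torchax.tensor.j2t(res2.jax()))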

torchax/test/test_ops.py

Lines changed: 2 additions & 0 deletions
@@ -184,6 +184,8 @@ def setUp(self):
     self.env = torchax.default_env()
     torchax.enable_accuracy_mode()
     #self.env.config.debug_accuracy_for_each_op = True
+    self.env.config.debug_print_each_op = True
+    self.env.config.debug_print_each_op_operands = True
     torch.manual_seed(0)
     self.old_var = self.env.config.use_torch_native_for_cpu_tensor
     self.env.config.use_torch_native_for_cpu_tensor = False
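With both flags on, the ops suite logs every dispatched op, and the new debug_print_each_op_operands flag additionally prints each operand's dtype and shape (see _make_debug_msg in tensor.py below). A minimal sketch of the same toggles outside the test harness:

import torch
import torchax

env = torchax.default_env()
env.config.debug_print_each_op = True           # log each FUNCTION/DISPATCH entry
env.config.debug_print_each_op_operands = True  # also summarize operand dtypes/shapes

with env:
  a = torch.randn(2, 2).to('jax')
  b = a + a  # the add is logged along with its operands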

torchax/torchax/__init__.py

Lines changed: 2 additions & 3 deletions
@@ -66,13 +66,12 @@ def jax_func(states, inputs):
   return states, jax_func
 
 def enable_globally():
-  global env
-  env = default_env().__enter__()
+  env = default_env().enable_torch_modes()
   return env
 
 def disable_globally():
   global env
-  default_env().__exit__(None, None, None)
+  default_env().disable_torch_modes()
 
 @contextlib.contextmanager
 def disable_temporarily():
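enable_globally and disable_globally now call the named enable_torch_modes/disable_torch_modes methods instead of reusing the context-manager protocol, which frees __enter__/__exit__ to track manual entry separately (see tensor.py below). Usage is unchanged; a short sketch:

import torch
import torchax

torchax.enable_globally()            # install the function/dispatch modes
x = torch.randn(2, 2, device='jax')  # constructors with device='jax' route to JAX
torchax.disable_globally()           # remove the modes again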

torchax/torchax/config.py

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ class Configuration:
   debug_print_each_op: bool = False
   debug_accuracy_for_each_op: bool = False
   debug_mixed_tensor: bool = False
+  debug_print_each_op_operands: bool = False
   use_int32_for_index: bool = False
 
   # Flash attention
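Because Configuration is a plain dataclass, the new flag defaults to off and can be set at construction time as well as on a live env. A sketch, assuming tensor.Environment (whose __init__ appears below) is constructed directly:

from torchax import config, tensor

cfg = config.Configuration(
    debug_print_each_op=True,
    debug_print_each_op_operands=True,  # new flag, defaults to False
)
env = tensor.Environment(cfg)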

torchax/torchax/tensor.py

Lines changed: 69 additions & 19 deletions
@@ -1,3 +1,4 @@
+import logging
 import sys
 import contextlib
 from typing import Optional, Any
@@ -14,6 +15,8 @@
 from torchax import config
 from torchax.ops import mappings, ops_registry
 
+logger = logging.getLogger(__name__)
+
 
 class OperatorNotFound(Exception):
   pass
@@ -155,6 +158,16 @@ def device(self):
   def jax_device(self):
     return self._elem.device
 
+  @property
+  def data(self):
+    logger.warn("In-place to .data modifications still results a copy on TPU")
+    return self
+
+  @data.setter
+  def data(self, other):
+    if isinstance(other, Tensor):
+      self._elem = other._elem
+
   def apply_jax(self, jax_function, *args, **kwargs):
     # Call a jax function on _elem
     res = jax_function(self._elem, *args, **kwargs)
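The new property mirrors stock PyTorch, where t.data = other swaps the underlying storage; here the setter simply rebinds _elem to the other tensor's jax array, and the getter warns that in-place edits through .data still copy on TPU. A sketch of the intended use:

import torch
import torchax

env = torchax.default_env()
with env:
  a = torch.zeros(2, 2).to('jax')
  b = torch.ones(2, 2).to('jax')
  a.data = b  # the setter rebinds a._elem to b's jax array; no torch-level copy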
@@ -199,6 +212,21 @@ def debug_accuracy(func, args, kwargs, current_output):
 
   return True
 
+def _make_debug_msg(is_dispatch, log_args, func, args, kwargs):
+  def _display(a):
+    if isinstance(a, torch.Tensor):
+      return f'Tensor of {type(a)}: {a.dtype}{a.shape}'
+    elif isinstance(a, jax.Array):
+      return f'Jax Array of {type(a)}: {a.dtype}{a.shape}'
+    else:
+      return str(a)
+
+  kwargs = kwargs or {}
+  title = 'DISPATCH' if is_dispatch else 'FUNCTION'
+  args_msg = 'args: ' + ','.join(_display(a) for a in args) if log_args else ''
+  kwargs_msg = 'kwargs: ' + ','.join(f'{key}: {_display(a)}' for key, a in kwargs.items()) if log_args else ''
+  return f'{title}: {_name_of_func(func)} {args_msg} ~ {kwargs_msg}'
+
 
 class XLAFunctionMode(torch.overrides.TorchFunctionMode):
   """Context manager that dispatches torch function calls to JAX."""
@@ -211,7 +239,12 @@ def __torch_function__(self,
                          types,
                          args=(),
                          kwargs=None) -> torch.Tensor:
-    with log_nested(self.env, f'FUNCTION: {_name_of_func(func)}'):
+    message = f'FUNCTION: {_name_of_func(func)}'
+    if self.env.config.debug_print_each_op_operands:
+      message = message + 'f'
+    message = _make_debug_msg(False, self.env.config.debug_print_each_op_operands,
+                              func, args, kwargs)
+    with log_nested(self.env, message):
       try:
         return self.env.dispatch(func, types, args, kwargs)
       except OperatorNotFound:
@@ -229,7 +262,9 @@ def __init__(self, env):
     self.env = env
 
   def __torch_dispatch__(self, func, types, args=(), kwargs=None):
-    with log_nested(self.env, f'DISPATCH: {_name_of_func(func)}'):
+    message = _make_debug_msg(True, self.env.config.debug_print_each_op_operands,
+                              func, args, kwargs)
+    with log_nested(self.env, message):
       if isinstance(func, torch._ops.OpOverloadPacket):
         with self:
           return func(*args, **kwargs)
@@ -287,6 +322,8 @@ def __init__(self, configuration=None):
     self._mesh = None
     self.config = configuration or config.Configuration()
 
+    self._manually_entered = False
+    self.enabled = False
     self._jax_devices = set(['jax', 'jax_cpu', 'xla'])
 
   def get_as_jax_device(self, device: Any):
@@ -295,16 +332,19 @@ def get_as_jax_device(self, device: Any):
 
     if isinstance(device, torch.device):
       device = str(device)
-    if (self.config.use_torch_native_for_cpu_tensor and
-        not device.startswith('jax') and not device.startswith('cuda')):
-      return None
 
-    if not self.config.treat_cuda_as_jax_device and device.startswith('cuda'):
-      return None
-
-    if device == 'cpu':
+    if (not self.config.use_torch_native_for_cpu_tensor and
+        device.startswith('cpu')):
       return jax.devices('cpu')[0]
-    return jax.local_devices()[0]
+
+    if self.config.treat_cuda_as_jax_device and device.startswith('cuda'):
+      return jax.local_devices()[0]
+
+    if device.startswith('jax'):
+      return jax.local_devices()[0]
+
+    return None  # fallback to torch
+
 
 
   def load_ops(self):
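The rewrite turns the method into an explicit whitelist: cpu maps to the JAX CPU device unless torch-native CPU tensors are enabled, cuda maps to the first local device only when treat_cuda_as_jax_device is set, 'jax' always maps, and anything else now falls back to native torch. A hedged sketch of the resulting behavior:

import torchax

env = torchax.default_env()
env.config.use_torch_native_for_cpu_tensor = False
env.config.treat_cuda_as_jax_device = True

env.get_as_jax_device('cpu')    # jax CPU device
env.get_as_jax_device('cuda')   # first local jax device (e.g. a TPU)
env.get_as_jax_device('jax')    # first local jax device
env.get_as_jax_device('meta')   # None: handled by native torch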
@@ -331,10 +371,9 @@ def _to_copy(self, the_tensor, new_dtype, new_device):
     if new_dtype is not None and new_dtype != arr.dtype:
       arr = arr.astype(mappings.t2j_dtype(new_dtype))
     if new_device is not None:
-      jax_device = self.get_as_jax_device(new_device)
-      if jax_device:
-        arr = jax.device_put(arr, jax_device)
-      else:
+      # convert xla tensor to other device
+      # only supported is CPU
+      if str(new_device).startswith('cpu'):
         # converting to a non-jax device: let torch native handle it
         torch_tensor = j2t(arr) if isinstance(the_tensor, Tensor) else arr
         with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
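_to_copy no longer calls jax.device_put for jax-to-jax moves; the only device transition it special-cases now is converting back to CPU, which round-trips through j2t into a native torch tensor. A sketch:

import torch
import torchax

env = torchax.default_env()
with env:
  a = torch.randn(2, 2).to('jax')
  cpu_t = a.to('cpu')  # j2t path: back to a native torch tensor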
@@ -365,7 +404,8 @@ def get_and_rotate_prng_key(self, generator: Optional[torch.Generator]=None):
   def _handle_tensor_constructor(self, func, args, kwargs):
     device = kwargs.get('device')
     jax_device = self.get_as_jax_device(device)
-    if jax_device is None:
+    # TODO(qihqi) figure out better ways for device propagation
+    if not self._manually_entered and jax_device is None:
      # let torch handle it
       with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
         return func(*args, **kwargs)
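Combined with the _manually_entered flag set in __enter__ (below), tensor constructors called inside a with env: block no longer fall back to native torch even when the requested device does not resolve to a JAX device; as the TODO notes, device propagation is still being worked out. A sketch, assuming default config:

import torch
import torchax

env = torchax.default_env()
with env:                # __enter__ sets _manually_entered = True
  a = torch.zeros(2, 2)  # no device kwarg, but the torch fallback is skipped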
@@ -454,17 +494,27 @@ def dispatch(self, func, types, args, kwargs):
       debug_accuracy(func, old_args, old_kwargs, res)
     return res
 
-  def __enter__(self):
+  def enable_torch_modes(self):
     self._dispatch_mode.__enter__()
     self._function_mode.__enter__()
     self.enabled = True
-    return self
-
-  def __exit__(self, *exc):
+
+  def disable_torch_modes(self, *exc):
+    if not exc:
+      exc = (None, None, None)
     self._function_mode.__exit__(*exc)
     self._dispatch_mode.__exit__(*exc)
     self.enabled = False
 
+  def __enter__(self):
+    self.enable_torch_modes()
+    self._manually_entered = True
+    return self
+
+  def __exit__(self, *exc):
+    self._manually_entered = False
+    self.disable_torch_modes(*exc)
+
   def _move_one_value(self, val):
     if isinstance(val, torch.nn.Module):
       with self:
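__enter__ and __exit__ are now thin wrappers over the named methods, adding only the _manually_entered bookkeeping, so global enablement and context-manager use share one code path. A sketch of both paths:

import torchax

env = torchax.default_env()

with env:                  # context-manager path: tracks manual entry
  assert env.enabled

env.enable_torch_modes()   # global path: modes on, no manual-entry flag
assert env.enabled
env.disable_torch_modes()
assert not env.enabled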
