1
+ import logging
1
2
import sys
2
3
import contextlib
3
4
from typing import Optional , Any
14
15
from torchax import config
15
16
from torchax .ops import mappings , ops_registry
16
17
18
+ logger = logging .getLogger (__name__ )
19
+
17
20
18
21
class OperatorNotFound(Exception):
  """Raised when an intercepted torch operator has no JAX implementation."""
@@ -155,6 +158,16 @@ def device(self):
155
158
def jax_device(self):
  """Return the device of the backing jax.Array (``self._elem``)."""
  backing = self._elem
  return backing.device
157
160
161
@property
def data(self):
  """Emulate ``torch.Tensor.data`` by returning the tensor itself.

  Unlike eager torch, writing through ``.data`` on TPU still produces
  a copy of the underlying buffer, hence the warning.
  """
  # ``Logger.warn`` is deprecated; ``Logger.warning`` is the supported API.
  logger.warning("In-place to .data modifications still results in a copy on TPU")
  return self

@data.setter
def data(self, other):
  # Adopt the other tensor's backing jax.Array. Non-Tensor values are
  # silently ignored, matching the original behavior.
  if isinstance(other, Tensor):
    self._elem = other._elem
170
+
158
171
def apply_jax (self , jax_function , * args , ** kwargs ):
159
172
# Call a jax function on _elem
160
173
res = jax_function (self ._elem , * args , ** kwargs )
@@ -199,6 +212,21 @@ def debug_accuracy(func, args, kwargs, current_output):
199
212
200
213
return True
201
214
215
def _make_debug_msg(is_dispatch, log_args, func, args, kwargs):
  """Format a one-line log message for an intercepted torch op.

  Args:
    is_dispatch: True when called from a dispatch-mode hook, False for a
      function-mode hook (controls the DISPATCH/FUNCTION prefix).
    log_args: when True, append a summary of every operand.
    func: the torch op or function being intercepted.
    args: positional operands.
    kwargs: keyword operands; may be None.
  """

  def _display(operand):
    # One-line summary of a single operand.
    if isinstance(operand, torch.Tensor):
      return f'Tensor of {type(operand)}: {operand.dtype}{operand.shape}'
    if isinstance(operand, jax.Array):
      return f'Jax Array of {type(operand)}: {operand.dtype}{operand.shape}'
    return str(operand)

  kwargs = kwargs or {}
  title = 'DISPATCH' if is_dispatch else 'FUNCTION'
  args_msg = ''
  kwargs_msg = ''
  if log_args:
    args_msg = 'args: ' + ','.join(_display(a) for a in args)
    kwargs_msg = 'kwargs: ' + ','.join(
        f'{key}: {_display(a)}' for key, a in kwargs.items())
  return f'{title}: {_name_of_func(func)} {args_msg} ~ {kwargs_msg}'
229
+
202
230
203
231
class XLAFunctionMode (torch .overrides .TorchFunctionMode ):
204
232
"""Context manager that dispatches torch function calls to JAX."""
@@ -211,7 +239,12 @@ def __torch_function__(self,
211
239
types ,
212
240
args = (),
213
241
kwargs = None ) -> torch .Tensor :
214
- with log_nested (self .env , f'FUNCTION: { _name_of_func (func )} ' ):
242
+ message = f'FUNCTION: { _name_of_func (func )} '
243
+ if self .env .config .debug_print_each_op_operands :
244
+ message = message + 'f'
245
+ message = _make_debug_msg (False , self .env .config .debug_print_each_op_operands ,
246
+ func , args , kwargs )
247
+ with log_nested (self .env , message ):
215
248
try :
216
249
return self .env .dispatch (func , types , args , kwargs )
217
250
except OperatorNotFound :
@@ -229,7 +262,9 @@ def __init__(self, env):
229
262
self .env = env
230
263
231
264
def __torch_dispatch__ (self , func , types , args = (), kwargs = None ):
232
- with log_nested (self .env , f'DISPATCH: { _name_of_func (func )} ' ):
265
+ message = _make_debug_msg (True , self .env .config .debug_print_each_op_operands ,
266
+ func , args , kwargs )
267
+ with log_nested (self .env , message ):
233
268
if isinstance (func , torch ._ops .OpOverloadPacket ):
234
269
with self :
235
270
return func (* args , ** kwargs )
@@ -287,6 +322,8 @@ def __init__(self, configuration=None):
287
322
self ._mesh = None
288
323
self .config = configuration or config .Configuration ()
289
324
325
+ self ._manually_entered = False
326
+ self .enabled = False
290
327
self ._jax_devices = set (['jax' , 'jax_cpu' , 'xla' ])
291
328
292
329
def get_as_jax_device (self , device : Any ):
@@ -295,16 +332,19 @@ def get_as_jax_device(self, device: Any):
295
332
296
333
if isinstance (device , torch .device ):
297
334
device = str (device )
298
- if (self .config .use_torch_native_for_cpu_tensor and
299
- not device .startswith ('jax' ) and not device .startswith ('cuda' )):
300
- return None
301
335
302
- if not self .config .treat_cuda_as_jax_device and device .startswith ('cuda' ):
303
- return None
304
-
305
- if device == 'cpu' :
336
+ if (not self .config .use_torch_native_for_cpu_tensor and
337
+ device .startswith ('cpu' )):
306
338
return jax .devices ('cpu' )[0 ]
307
- return jax .local_devices ()[0 ]
339
+
340
+ if self .config .treat_cuda_as_jax_device and device .startswith ('cuda' ):
341
+ return jax .local_devices ()[0 ]
342
+
343
+ if device .startswith ('jax' ):
344
+ return jax .local_devices ()[0 ]
345
+
346
+ return None # fallback to torch
347
+
308
348
309
349
310
350
def load_ops (self ):
@@ -331,10 +371,9 @@ def _to_copy(self, the_tensor, new_dtype, new_device):
331
371
if new_dtype is not None and new_dtype != arr .dtype :
332
372
arr = arr .astype (mappings .t2j_dtype (new_dtype ))
333
373
if new_device is not None :
334
- jax_device = self .get_as_jax_device (new_device )
335
- if jax_device :
336
- arr = jax .device_put (arr , jax_device )
337
- else :
374
+ # convert xla tensor to other device
375
+ # only supported is CPU
376
+ if str (new_device ).startswith ('cpu' ):
338
377
# converting to a non-jax device: let torch native handle it
339
378
torch_tensor = j2t (arr ) if isinstance (the_tensor , Tensor ) else arr
340
379
with mode_utils .no_dispatch (), torch ._C .DisableTorchFunction ():
@@ -365,7 +404,8 @@ def get_and_rotate_prng_key(self, generator: Optional[torch.Generator]=None):
365
404
def _handle_tensor_constructor (self , func , args , kwargs ):
366
405
device = kwargs .get ('device' )
367
406
jax_device = self .get_as_jax_device (device )
368
- if jax_device is None :
407
+ # TODO(qihqi) figure out better ways for device propagation
408
+ if not self ._manually_entered and jax_device is None :
369
409
# let torch handle it
370
410
with mode_utils .no_dispatch (), torch ._C .DisableTorchFunction ():
371
411
return func (* args , ** kwargs )
@@ -454,17 +494,27 @@ def dispatch(self, func, types, args, kwargs):
454
494
debug_accuracy (func , old_args , old_kwargs , res )
455
495
return res
456
496
457
- def __enter__ (self ):
497
def enable_torch_modes(self):
  """Push the dispatch- and function-interception modes onto torch's stacks.

  Note: dispatch mode is entered first; disable_torch_modes exits in the
  reverse order, keeping the two mode stacks properly nested.
  """
  self._dispatch_mode.__enter__()
  self._function_mode.__enter__()
  self.enabled = True
461
- return self
462
-
463
- def __exit__ (self , * exc ):
501
+
502
def disable_torch_modes(self, *exc):
  """Pop the interception modes from torch's stacks.

  Args:
    *exc: optional ``(exc_type, exc_value, traceback)`` triple; defaults
      to ``(None, None, None)`` when called outside an exception context.
  """
  exc_info = exc if exc else (None, None, None)
  # Exit in reverse order of enable_torch_modes.
  self._function_mode.__exit__(*exc_info)
  self._dispatch_mode.__exit__(*exc_info)
  self.enabled = False
467
508
509
def __enter__(self):
  """Enter as a context manager: enable the torch modes and record that
  the environment was entered manually (vs. implicit device routing)."""
  self.enable_torch_modes()
  self._manually_entered = True
  return self
513
+
514
def __exit__(self, *exc):
  """Leave the context: clear the manual-entry flag, then pop the torch
  modes, forwarding any in-flight exception info to their exits."""
  self._manually_entered = False
  self.disable_torch_modes(*exc)
517
+
468
518
def _move_one_value (self , val ):
469
519
if isinstance (val , torch .nn .Module ):
470
520
with self :
0 commit comments