@@ -207,6 +207,21 @@ def debug_accuracy(func, args, kwargs, current_output):
207
207
208
208
return True
209
209
210
def _make_debug_msg(is_dispatch, log_args, func, args, kwargs):
  """Build a one-line log message describing an intercepted torch call.

  Args:
    is_dispatch: True when logging from __torch_dispatch__, False when
      logging from __torch_function__; only affects the message title.
    log_args: when True, include a rendering of args/kwargs in the message.
    func: the intercepted torch function/operator.
    args: positional arguments of the intercepted call.
    kwargs: keyword arguments of the intercepted call; may be None.
  """

  def _display(value):
    # Show tensors/arrays by metadata only — rendering values could be
    # expensive or force materialization during tracing.
    if isinstance(value, torch.Tensor):
      return f'Tensor of {type(value)}: {value.dtype}{value.shape}'
    elif isinstance(value, jax.Array):
      return f'Jax Array of {type(value)}: {value.dtype}{value.shape}'
    else:
      return str(value)

  kwargs = kwargs or {}
  title = 'DISPATCH' if is_dispatch else 'FUNCTION'
  if log_args:
    args_msg = 'args: ' + ','.join(_display(a) for a in args)
    kwargs_msg = 'kwargs: ' + ','.join(
        f'{key}: {_display(a)}' for key, a in kwargs.items())
  else:
    args_msg = ''
    kwargs_msg = ''
  return f'{title}: {_name_of_func(func)} {args_msg} ~ {kwargs_msg}'
224
+
210
225
211
226
class XLAFunctionMode (torch .overrides .TorchFunctionMode ):
212
227
"""Context manager that dispatches torch function calls to JAX."""
@@ -219,7 +234,12 @@ def __torch_function__(self,
219
234
types ,
220
235
args = (),
221
236
kwargs = None ) -> torch .Tensor :
222
- with log_nested (self .env , f'FUNCTION: { _name_of_func (func )} ' ):
237
+ message = f'FUNCTION: { _name_of_func (func )} '
238
+ if self .env .config .debug_print_each_op_operands :
239
+ message = message + 'f'
240
+ message = _make_debug_msg (False , self .env .config .debug_print_each_op_operands ,
241
+ func , args , kwargs )
242
+ with log_nested (self .env , message ):
223
243
try :
224
244
return self .env .dispatch (func , types , args , kwargs )
225
245
except OperatorNotFound :
@@ -237,7 +257,9 @@ def __init__(self, env):
237
257
self .env = env
238
258
239
259
def __torch_dispatch__ (self , func , types , args = (), kwargs = None ):
240
- with log_nested (self .env , f'DISPATCH: { _name_of_func (func )} ' ):
260
+ message = _make_debug_msg (True , self .env .config .debug_print_each_op_operands ,
261
+ func , args , kwargs )
262
+ with log_nested (self .env , message ):
241
263
if isinstance (func , torch ._ops .OpOverloadPacket ):
242
264
with self :
243
265
return func (* args , ** kwargs )
@@ -295,6 +317,8 @@ def __init__(self, configuration=None):
295
317
self ._mesh = None
296
318
self .config = configuration or config .Configuration ()
297
319
320
+ self ._manually_entered = False
321
+ self .enabled = False
298
322
self ._jax_devices = set (['jax' , 'jax_cpu' , 'xla' ])
299
323
300
324
def get_as_jax_device (self , device : Any ):
@@ -342,10 +366,9 @@ def _to_copy(self, the_tensor, new_dtype, new_device):
342
366
if new_dtype is not None and new_dtype != arr .dtype :
343
367
arr = arr .astype (mappings .t2j_dtype (new_dtype ))
344
368
if new_device is not None :
345
- jax_device = self .get_as_jax_device (new_device )
346
- if jax_device :
347
- arr = jax .device_put (arr , jax_device )
348
- else :
369
+ # convert xla tensor to other device
370
+ # only supported is CPU
371
+ if str (new_device ).startswith ('cpu' ):
349
372
# converting to a non-jax device: let torch native handle it
350
373
torch_tensor = j2t (arr ) if isinstance (the_tensor , Tensor ) else arr
351
374
with mode_utils .no_dispatch (), torch ._C .DisableTorchFunction ():
@@ -376,7 +399,8 @@ def get_and_rotate_prng_key(self, generator: Optional[torch.Generator]=None):
376
399
def _handle_tensor_constructor (self , func , args , kwargs ):
377
400
device = kwargs .get ('device' )
378
401
jax_device = self .get_as_jax_device (device )
379
- if jax_device is None :
402
+ # TODO(qihqi) figure out better ways for device propagation
403
+ if not self ._manually_entered and jax_device is None :
380
404
# let torch handle it
381
405
with mode_utils .no_dispatch (), torch ._C .DisableTorchFunction ():
382
406
return func (* args , ** kwargs )
@@ -465,17 +489,27 @@ def dispatch(self, func, types, args, kwargs):
465
489
debug_accuracy (func , old_args , old_kwargs , res )
466
490
return res
467
491
468
- def __enter__ (self ):
492
def enable_torch_modes(self):
  """Push the torch dispatch and function interception modes.

  Enters the dispatch mode first, then the function mode, and marks the
  environment as enabled. Does not touch the manual-entry flag.
  """
  for mode in (self._dispatch_mode, self._function_mode):
    mode.__enter__()
  self.enabled = True
472
- return self
473
-
474
- def __exit__ (self , * exc ):
496
+
497
def disable_torch_modes(self, *exc):
  """Pop the torch modes, in reverse of the order they were entered.

  Args:
    *exc: optional (exc_type, exc_value, traceback) triple forwarded to
      the modes' __exit__; when omitted, (None, None, None) is used.
  """
  exc_info = exc if exc else (None, None, None)
  self._function_mode.__exit__(*exc_info)
  self._dispatch_mode.__exit__(*exc_info)
  self.enabled = False
478
503
504
def __enter__(self):
  """Enter the environment as a context manager.

  Beyond enabling the torch modes, records that the user entered
  explicitly (_manually_entered) — _handle_tensor_constructor consults
  this flag to decide whether tensor constructors without an explicit
  jax device should still be intercepted.
  """
  self.enable_torch_modes()
  self._manually_entered = True
  return self
508
+
509
def __exit__(self, *exc):
  """Leave the context manager: clear the manual-entry flag, then pop modes."""
  self._manually_entered = False
  self.disable_torch_modes(*exc)
512
+
479
513
def _move_one_value (self , val ):
480
514
if isinstance (val , torch .nn .Module ):
481
515
with self :
0 commit comments