diff --git a/torchax/torchax/tensor.py b/torchax/torchax/tensor.py index 3b34b506d3a..6ddce255c02 100644 --- a/torchax/torchax/tensor.py +++ b/torchax/torchax/tensor.py @@ -1,3 +1,4 @@ +import logging import sys import contextlib from typing import Optional, Any @@ -14,6 +15,8 @@ from torchax import config from torchax.ops import mappings, ops_registry +logger = logging.getLogger(__name__) + class OperatorNotFound(Exception): pass @@ -155,6 +158,11 @@ def device(self): def jax_device(self): return self._elem.device + @property + def data(self): + logger.warn("In-place to .data modifications still results a copy on TPU") + return self + def apply_jax(self, jax_function, *args, **kwargs): # Call a jax function on _elem res = jax_function(self._elem, *args, **kwargs)