From 28a75520d080fc0e256434b8d02b928b454a4c3d Mon Sep 17 00:00:00 2001 From: Han Qi Date: Thu, 6 Feb 2025 01:49:56 +0000 Subject: [PATCH] support .data --- torchax/torchax/tensor.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/torchax/torchax/tensor.py b/torchax/torchax/tensor.py index faf9f45010a..62b92ef2350 100644 --- a/torchax/torchax/tensor.py +++ b/torchax/torchax/tensor.py @@ -1,3 +1,4 @@ +import logging import sys import contextlib from typing import Optional, Any @@ -14,6 +15,8 @@ from torchax import config from torchax.ops import mappings, ops_registry +logger = logging.getLogger(__name__) + class OperatorNotFound(Exception): pass @@ -155,6 +158,11 @@ def device(self): def jax_device(self): return self._elem.device + @property + def data(self): + logger.warn("In-place to .data modifications still results a copy on TPU") + return self + def apply_jax(self, jax_function, *args, **kwargs): # Call a jax function on _elem res = jax_function(self._elem, *args, **kwargs)