Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 28a7552

Browse files
committedFeb 6, 2025·
support .data
1 parent 01cc840 commit 28a7552

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed
 

‎torchax/torchax/tensor.py

+8
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
import sys
23
import contextlib
34
from typing import Optional, Any
@@ -14,6 +15,8 @@
1415
from torchax import config
1516
from torchax.ops import mappings, ops_registry
1617

18+
logger = logging.getLogger(__name__)
19+
1720

1821
class OperatorNotFound(Exception):
1922
pass
@@ -155,6 +158,11 @@ def device(self):
155158
def jax_device(self):
156159
return self._elem.device
157160

161+
@property
162+
def data(self):
163+
logger.warn("In-place to .data modifications still results a copy on TPU")
164+
return self
165+
158166
def apply_jax(self, jax_function, *args, **kwargs):
159167
# Call a jax function on _elem
160168
res = jax_function(self._elem, *args, **kwargs)

0 commit comments

Comments
 (0)
Please sign in to comment.