We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 01cc840 commit 28a7552Copy full SHA for 28a7552
torchax/torchax/tensor.py
@@ -1,3 +1,4 @@
1
+import logging
2
import sys
3
import contextlib
4
from typing import Optional, Any
@@ -14,6 +15,8 @@
14
15
from torchax import config
16
from torchax.ops import mappings, ops_registry
17
18
+logger = logging.getLogger(__name__)
19
+
20
21
class OperatorNotFound(Exception):
22
pass
@@ -155,6 +158,11 @@ def device(self):
155
158
def jax_device(self):
156
159
return self._elem.device
157
160
161
+ @property
162
+ def data(self):
163
+ logger.warn("In-place to .data modifications still results a copy on TPU")
164
+ return self
165
166
def apply_jax(self, jax_function, *args, **kwargs):
167
# Call a jax function on _elem
168
res = jax_function(self._elem, *args, **kwargs)
0 commit comments