Skip to content

Commit

Permalink
support .data
Browse files Browse the repository at this point in the history
  • Loading branch information
qihqi committed Feb 6, 2025
1 parent 01cc840 commit 28a7552
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions torchax/torchax/tensor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import sys
import contextlib
from typing import Optional, Any
Expand All @@ -14,6 +15,8 @@
from torchax import config
from torchax.ops import mappings, ops_registry

logger = logging.getLogger(__name__)


class OperatorNotFound(Exception):
pass
Expand Down Expand Up @@ -155,6 +158,11 @@ def device(self):
def jax_device(self):
return self._elem.device

@property
def data(self):
logger.warn("In-place to .data modifications still results a copy on TPU")
return self

def apply_jax(self, jax_function, *args, **kwargs):
# Call a jax function on _elem
res = jax_function(self._elem, *args, **kwargs)
Expand Down

0 comments on commit 28a7552

Please sign in to comment.