Rename XLATensor2 to Tensor
qihqi committed Jan 22, 2025
1 parent fbd2ac2 commit c3d8a34
Showing 18 changed files with 68 additions and 68 deletions.
2 changes: 1 addition & 1 deletion torchax/README.md
@@ -128,7 +128,7 @@ with env:
inputs = torch.randn(3, 3, 28, 28)
m = MyModel()
res = m(inputs)
print(type(res)) # outputs XLATensor2
print(type(res)) # outputs Tensor
```

You can also enable the environment globally with
16 changes: 8 additions & 8 deletions torchax/docs/how_it_works.md
@@ -4,15 +4,15 @@ How it works

## Tensor subclass and eager mode

The class `XLATensor2` is a `torch.Tensor` subclass
The class `Tensor` is a `torch.Tensor` subclass
that overrides `__torch_dispatch__`.

It roughly looks like this (with some details removed):

The complete class impl is at [tensor.py](../torchax/tensor.py).

```python
class XLATensor2(torch.Tensor):
class Tensor(torch.Tensor):

@staticmethod
def __new__(cls, elem):
@@ -33,21 +33,21 @@ class XLATensor2(torch.Tensor):
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
# here assumes ALL tensors in args / kwargs are
# instances of XLATensor2
# instances of Tensor
args, kwargs = unwrap((args, kwargs))
jax_func = some_registry[func]
res = jax_func(*args, **kwargs)
return wrap(res)

def wrap(tree):
# wrap jax.Array with XLATensor2
# wrap jax.Array with Tensor
return pytree.tree_map_only(
jax.Array, XLATensor2, tree)
jax.Array, Tensor, tree)

def unwrap(tree):
# get jax.Array out ofXLATensor2
# get jax.Array out of Tensor
return pytree.tree_map_only(
XLATensor2, lambda x: x._elem, tree)
Tensor, lambda x: x._elem, tree)
```
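
For a sense of how this plays out from the caller's side, here is a minimal sketch (assuming the default environment from `torchax.default_env()` and a registry entry for matmul):

```python
import jax.numpy as jnp
import torchax
from torchax import tensor

env = torchax.default_env()
a = tensor.Tensor(jnp.ones((2, 3)), env)  # wrap a jax.Array in the torch.Tensor subclass
b = tensor.Tensor(jnp.ones((3, 4)), env)

c = a @ b             # __torch_dispatch__ unwraps, runs the jax implementation, wraps back
print(type(c))        # Tensor
print(c._elem.shape)  # (2, 4) -- the underlying jax.Array
```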

In other words, assuming that we have a function
@@ -120,7 +120,7 @@ def backend(fxgraph):
The inner function `tojit` is a function that takes and returns
`jax.Array`'s. So it's suitable to be jitted with `jax.jit`.

`f` is returned callable that takes `XLATensor2`; so can interop with
`f` is the returned callable that takes `Tensor`, so it can interop with
other torch code.
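
A minimal sketch of that wrapping pattern (hypothetical helper name `jit_wrap`; assumes a single `jax.Array` result and the `Tensor` / `default_env` API used elsewhere in torchax):

```python
import jax
import torchax
from torchax import tensor

def jit_wrap(tojit):
    # tojit: jax.Array arguments -> jax.Array result, so it can be jitted directly
    jitted = jax.jit(tojit)

    def f(*torch_tensors):
        jax_args = [t._elem for t in torch_tensors]       # unwrap Tensor -> jax.Array
        res = jitted(*jax_args)                           # run the jitted jax function
        return tensor.Tensor(res, torchax.default_env())  # wrap the result back into Tensor

    return f
```

Only `jax.Array`s cross the `jax.jit` boundary here; wrapping and unwrapping `Tensor` stays outside the jitted function.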

## nn.Modules and state management
4 changes: 2 additions & 2 deletions torchax/docs/torch_xla2_dynamo.md
@@ -33,11 +33,11 @@ not incur any data copies in this process.
Consider the following pseudocode:

```python
class XLATensor2:
class Tensor:
_data: jax.Array
def __torch_dispatch__(...):
# do stuff with _data, get new data
return XLATensor2(new_data)
return Tensor(new_data)

def dynamo_backend(fx, sample):
compiled = compile fx into graph that manipulate jax.Array.
16 changes: 8 additions & 8 deletions torchax/docs/understand_jax_jit/torch_module.py
@@ -48,20 +48,20 @@ def forward(self, X):
print('---- example 3 -----')
# where is the jax jit?

# m2 is a callable that takes in XLATensor2 and returns XLATensor2
# m2: (XLATensor2 -> XLATensor2)
# m2 is a callable that takes in Tensor and returns Tensor
# m2: (Tensor -> Tensor)

# suppose t2j (XLATensor2 -> jax.Array) "unwraps the XLATensor"
# suppose j2t (jax.Array -> XLATensor2) "wraps the XLATensor"
# suppose t2j (Tensor -> jax.Array) "unwraps the XLATensor"
# suppose j2t (jax.Array -> Tensor) "wraps the XLATensor"
from torchax import tensor
import jax

def t2j(torch_tensor: tensor.XLATensor2) -> jax.Array:
def t2j(torch_tensor: tensor.Tensor) -> jax.Array:
return torch_tensor._elem


def j2t(jax_array: jax.Array) -> tensor.XLATensor2:
return tensor.XLATensor2(jax_array, env)
def j2t(jax_array: jax.Array) -> tensor.Tensor:
return tensor.Tensor(jax_array, env)

# # further notice t2j(j2t(x)) == x; j2t(t2j(x)) == x

@@ -75,7 +75,7 @@ def jax_m(X: jax.Array):
jax_x = jnp.ones((10, 1000))
print(jax_m(jax_x))

## Let f: XLATensor2 -> XLATensor2
## Let f: Tensor -> Tensor
## There is a function g: jax.Array -> jax.Array;
## g = x |-> t2j (f (j2t(x))). OR,
## g = t2j . f . j2t (. denotes function composition)
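# A hypothetical sketch of that composition, using the t2j / j2t helpers above:
def as_jax_fn(f):
  # f: Tensor -> Tensor; the returned g: jax.Array -> jax.Array
  def g(x: jax.Array) -> jax.Array:
    return t2j(f(j2t(x)))
  return g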
2 changes: 1 addition & 1 deletion torchax/examples/_diffusion.py
@@ -26,7 +26,7 @@ def __init__(self, model):


def _maybe_move_tensor(self, tensor):
if isinstance(tensor, torch.Tensor) and not isinstance(tensor, torchax.tensor.XLATensor2):
if isinstance(tensor, torch.Tensor) and not isinstance(tensor, torchax.tensor.Tensor):
return torchax.tensor.move_to_device(tensor)
return tensor

2 changes: 1 addition & 1 deletion torchax/examples/mnist_tpu.ipynb
@@ -304,7 +304,7 @@
{
"data": {
"text/plain": [
"XLATensor2(<class 'jaxlib.xla_extension.ArrayImpl'> [[ 0.03249096 0.01343462 -0.022144 ... 0.00668433 0.00833362\n",
"Tensor(<class 'jaxlib.xla_extension.ArrayImpl'> [[ 0.03249096 0.01343462 -0.022144 ... 0.00668433 0.00833362\n",
" 0.00225713]\n",
" [ 0.02272127 0.02205281 0.00828168 ... -0.02310903 0.02183958\n",
" 0.01084254]\n",
4 changes: 2 additions & 2 deletions torchax/examples/train_gpt/train_ddp.py
@@ -29,8 +29,8 @@
def _checkpoint(jax_model, path: pathlib.Path):
torch.save(
torch_pytree.tree_map_only(
torchax.tensor.XLATensor2,
torchax.tensor.XLATensor2.torch,
torchax.tensor.Tensor,
torchax.tensor.Tensor.torch,
jax_model.state_dict(),
),
path,
2 changes: 1 addition & 1 deletion torchax/test/test_base.py
@@ -40,7 +40,7 @@ def run_function_and_compare(testcase,
args2, kwargs2 = pytree.tree_map_only(torch.Tensor, tensor.move_to_device,
(args, kwargs))
res2 = func(*args2, **kwargs2)
res2 = pytree.tree_map_only(tensor.XLATensor2, lambda t: t.torch(), res2)
res2 = pytree.tree_map_only(tensor.Tensor, lambda t: t.torch(), res2)
with testcase.subTest("torchax_diff:" + str(atol)):
if ignore_indices and isinstance(res, tuple) and len(res) == 2:
diff_output(
26 changes: 13 additions & 13 deletions torchax/test/test_context.py
@@ -20,9 +20,9 @@ def tearDown(self):
def test_mode_context_manager(self):
with xla_env:
x = torch.full((3, 3), -1)
self.assertIsInstance(x, tensor.XLATensor2)
self.assertIsInstance(x, tensor.Tensor)
y = x.abs()
self.assertIsInstance(y, tensor.XLATensor2)
self.assertIsInstance(y, tensor.Tensor)

@staticmethod
@xla_env
@@ -34,30 +34,30 @@ def _test_mode_decorator():

def test_mode_decorator(self):
x, y = self._test_mode_decorator()
self.assertIsInstance(x, tensor.XLATensor2)
self.assertIsInstance(y, tensor.XLATensor2)
self.assertIsInstance(x, tensor.Tensor)
self.assertIsInstance(y, tensor.Tensor)

def test_same_manual_seed(self):
with xla_env:
torch.manual_seed(1234)
x = torch.randn((3, 3))
self.assertIsInstance(x, tensor.XLATensor2)
self.assertIsInstance(x, tensor.Tensor)

torch.manual_seed(1234)
y = torch.randn((3, 3))
self.assertIsInstance(y, tensor.XLATensor2)
self.assertIsInstance(y, tensor.Tensor)

self.assertTrue(torch.equal(torchax.tensor.j2t(x._elem), torchax.tensor.j2t(y._elem)))

def test_different_manual_seed(self):
with xla_env:
torch.manual_seed(1234)
x = torch.randn((3, 3))
self.assertIsInstance(x, tensor.XLATensor2)
self.assertIsInstance(x, tensor.Tensor)

torch.manual_seed(12345)
y = torch.randn((3, 3))
self.assertIsInstance(y, tensor.XLATensor2)
self.assertIsInstance(y, tensor.Tensor)

self.assertFalse(torch.equal(torchax.tensor.j2t(x._elem), torchax.tensor.j2t(y._elem)))

@@ -69,7 +69,7 @@ def random_op():
return x @ y

random_jit = torchax.interop.jax_jit(random_op)
self.assertIsInstance(random_jit(), tensor.XLATensor2)
self.assertIsInstance(random_jit(), tensor.Tensor)

# Result always expected to be the same for a jitted function because seeds
# are baked in
@@ -101,13 +101,13 @@ def __init__(self):
# Test context manager.
with xla_env:
m = M()
self.assertIsInstance(m.c, tensor.XLATensor2)
self.assertIsInstance(m.c2, tensor.XLATensor2)
self.assertIsInstance(m.c, tensor.Tensor)
self.assertIsInstance(m.c2, tensor.Tensor)
# Test `to_xla`.
m = M()
m = xla_env.to_xla(m)
self.assertIsInstance(m.c, tensor.XLATensor2)
self.assertIsInstance(m.c2, tensor.XLATensor2)
self.assertIsInstance(m.c, tensor.Tensor)
self.assertIsInstance(m.c2, tensor.Tensor)


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion torchax/test/test_core_aten_ops.py
@@ -39,7 +39,7 @@ def run_export_and_compare(testcase,
args2, kwargs2 = testcase.env.to_xla((args, kwargs))
with testcase.env:
res2 = func(*args2, **kwargs2)
res2 = pytree.tree_map_only(tensor.XLATensor2, lambda t: t.torch(), res2)
res2 = pytree.tree_map_only(tensor.Tensor, lambda t: t.torch(), res2)
# import pdb; pdb.set_trace()
with testcase.subTest("torchax_diff:" + str(atol)):
if ignore_indices and isinstance(res, tuple) and len(res) == 2:
2 changes: 1 addition & 1 deletion torchax/test/test_functions.py
@@ -26,7 +26,7 @@ def test_tensor_constructor(self, func: Callable[[], torch.Tensor]):

with self.env:
actual = func()
self.assertIsInstance(actual, torchax.tensor.XLATensor2)
self.assertIsInstance(actual, torchax.tensor.Tensor)

torch.testing.assert_close(torchax.tensor.j2t(actual._elem), expected)

2 changes: 1 addition & 1 deletion torchax/test/test_ops.py
@@ -144,7 +144,7 @@ def run_export_and_compare(testcase,
sample_input.input, sample_input.args, sample_input.kwargs))
with testcase.env:
res2 = func(input2, *args2, **kwargs2)
res2 = pytree.tree_map_only(tensor.XLATensor2, lambda t: t.torch(), res2)
res2 = pytree.tree_map_only(tensor.Tensor, lambda t: t.torch(), res2)
with testcase.subTest("torchax_diff:" + str(atol)):
if ignore_indices and isinstance(res, tuple) and len(res) == 2:
diff_output(
4 changes: 2 additions & 2 deletions torchax/test_dist/test_distributed.py
@@ -42,7 +42,7 @@ def process_group():
def test_all_gather_tensor(multi_cpu, process_group):
device_count = multi_cpu

def f(index: torchax.tensor.XLATensor2):
def f(index: torchax.tensor.Tensor):
with torchax.default_env():
output = torch.zeros_like(index).expand(device_count)
dist.all_gather_into_tensor(output, index)
@@ -58,7 +58,7 @@ def test_all_gather_tensor_func(multi_cpu, process_group):
device_count = multi_cpu
group_ranks = process_group

def f(index: torchax.tensor.XLATensor2):
def f(index: torchax.tensor.Tensor):
return torch.distributed._functional_collectives.all_gather_tensor(
index, 0, group_ranks
)
8 changes: 4 additions & 4 deletions torchax/torchax/distributed.py
@@ -63,8 +63,8 @@ def _allgather_base(
input: torch.Tensor,
opts=...,
) -> dist.Work:
assert isinstance(input, torchax.tensor.XLATensor2)
assert isinstance(output, torchax.tensor.XLATensor2)
assert isinstance(input, torchax.tensor.Tensor)
assert isinstance(output, torchax.tensor.Tensor)
torch.distributed._functional_collectives.all_gather_tensor_inplace(
output, input, group=self
)
@@ -76,7 +76,7 @@ def allreduce(
opts: dist.AllreduceOptions = ...,
) -> dist.Work:
assert len(tensors) == 1
assert isinstance(tensors[0], torchax.tensor.XLATensor2)
assert isinstance(tensors[0], torchax.tensor.Tensor)
torch.distributed._functional_collectives.all_reduce_inplace(
tensors[0],
torch.distributed._functional_collectives.REDUCE_OP_TO_STR[
@@ -93,7 +93,7 @@ def broadcast(
opts: dist.BroadcastOptions = ...,
) -> dist.Work:
assert len(tensors) == 1
assert isinstance(tensors[0], torchax.tensor.XLATensor2)
assert isinstance(tensors[0], torchax.tensor.Tensor)
tensors[0].copy_(
torch.distributed._functional_collectives.broadcast(
tensors[0], opts.rootRank, group=self
4 changes: 2 additions & 2 deletions torchax/torchax/interop.py
@@ -136,7 +136,7 @@ def _torch_view(t: JaxValue) -> TorchValue:
# view it as-if it's a torch land object
if isinstance(t, jax.Array):
# TODO
return tensor.XLATensor2(t, torchax.default_env())
return tensor.Tensor(t, torchax.default_env())
if isinstance(t, type(jnp.int32)):
return tensor.t2j_type(t)
if callable(t): # t is a JaxCallable
@@ -151,7 +151,7 @@ def _jax_view(t: TorchValue) -> JaxValue:
# t is an object from torch land
# view it as-if it's a jax land object
if isinstance(t, torch.Tensor):
assert isinstance(t, tensor.XLATensor2), type(t)
assert isinstance(t, tensor.Tensor), type(t)
return t.jax()
if isinstance(t, type(torch.int32)):
return tensor.t2j_dtype(t)
2 changes: 1 addition & 1 deletion torchax/torchax/ops/jaten.py
@@ -17,7 +17,7 @@
from torchax.ops import jax_reimplement

# Keys are OpOverload, value is a callable that takes
# XLATensor2
# Tensor
all_ops = {}

# list all Aten ops from pytorch that does mutation
2 changes: 1 addition & 1 deletion torchax/torchax/ops/jtorch.py
@@ -29,7 +29,7 @@ def _as_tensor(data, dtype=None, device=None, env=None):
jax_res = jnp.asarray(data)
else:
jax_res = _tensor(data, dtype=dtype)
return torchax.tensor.XLATensor2(jax_res, env)
return torchax.tensor.Tensor(jax_res, env)


@register_function(torch.tensor)