
Commit c3d8a34: Rename XLATensor2 to Tensor

1 parent fbd2ac2

18 files changed: +68 −68 lines

torchax/README.md (1 addition, 1 deletion)

````diff
@@ -128,7 +128,7 @@ with env:
   inputs = torch.randn(3, 3, 28, 28)
   m = MyModel()
   res = m(inputs)
-  print(type(res)) # outputs XLATensor2
+  print(type(res)) # outputs Tensor
 ```

 You can also enable the environment globally with
````
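For readers skimming the diff, a minimal self-contained sketch of the pattern this hunk documents, after the rename. It assumes `torchax.default_env()` as the `env` that the README's `with env:` refers to, and `MyModel` here is a toy stand-in:

```python
import torch
import torchax

class MyModel(torch.nn.Module):
  """Toy stand-in for the README's model."""

  def __init__(self):
    super().__init__()
    self.conv = torch.nn.Conv2d(3, 8, 3)

  def forward(self, x):
    return self.conv(x)

env = torchax.default_env()  # assumption: the env object used by `with env:`
with env:
  inputs = torch.randn(3, 3, 28, 28)
  m = MyModel()
  res = m(inputs)
  print(type(res))  # after this commit: torchax.tensor.Tensor
```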

torchax/docs/how_it_works.md (8 additions, 8 deletions)

````diff
@@ -4,15 +4,15 @@ How it works

 ## Tensor subclass and eager mode

-The class `XLATensor2` is a `torch.Tensor` subclass
+The class `Tensor` is a `torch.Tensor` subclass
 that overrides `__torch_dispatch__`.

 It roughly looks like this (with some details removed):

 The complete class impl is at [tensor.py](../torchax/tensor.py).

 ```python
-class XLATensor2(torch.Tensor):
+class Tensor(torch.Tensor):

   @staticmethod
   def __new__(cls, elem):
@@ -33,21 +33,21 @@ class XLATensor2(torch.Tensor):
   @classmethod
   def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
     # here assumes ALL tensors in args / kwargs are
-    # instances of XLATensor2
+    # instances of Tensor
     args, kwargs = unwrap((args, kwargs))
     jax_func = some_registry[func]
     res = jax_func(*args, **kwargs)
     return wrap(res)

 def wrap(tree):
-  # wrap jax.Array with XLATensor2
+  # wrap jax.Array with Tensor
   return pytree.tree_map_only(
-      jax.Array, XLATensor2, tree)
+      jax.Array, Tensor, tree)

 def unwrap(tree):
-  # get jax.Array out ofXLATensor2
+  # get jax.Array out of Tensor
   return pytree.tree_map_only(
-      XLATensor2, lambda x: x._elem, tree)
+      Tensor, lambda x: x._elem, tree)
 ```

 In other words, assuming that we have a function
@@ -120,7 +120,7 @@ def backend(fxgraph):
 The inner function `tojit` is a function that takes and returns
 `jax.Array`'s. So it's suitable to be jitted with `jax.jit`.

-`f` is returned callable that takes `XLATensor2`; so can interop with
+`f` is the returned callable that takes `Tensor`, so it can interop with
 other torch codes.

 ## nn.Modules and state management
````
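This hunk is the core mechanism of the library: hold a `jax.Array` inside a `torch.Tensor` subclass, unwrap at every dispatched op, call the JAX implementation, and re-wrap the result. Below is a minimal runnable sketch of that mechanism; `DemoTensor` and the one-entry `some_registry` are illustrative stand-ins for the real class and lowering table in `tensor.py`, and `_make_wrapper_subclass` is the private PyTorch helper that wrapper subclasses conventionally use:

```python
import jax
import jax.numpy as jnp
import torch
from torch.utils import _pytree as pytree

# One-op registry standing in for torchax's full lowering table.
some_registry = {torch.ops.aten.add.Tensor: jnp.add}

class DemoTensor(torch.Tensor):

  @staticmethod
  def __new__(cls, elem):
    # Meta-device shell: shape/dtype for PyTorch, real data in `_elem`.
    return torch.Tensor._make_wrapper_subclass(
        cls, elem.shape, dtype=torch.float32, device='meta')

  def __init__(self, elem):
    self._elem = elem

  @classmethod
  def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
    args, kwargs = unwrap((args, kwargs or {}))
    res = some_registry[func](*args, **kwargs)  # run the JAX op
    return wrap(res)

def wrap(tree):
  return pytree.tree_map_only(jax.Array, DemoTensor, tree)

def unwrap(tree):
  return pytree.tree_map_only(DemoTensor, lambda x: x._elem, tree)

a = DemoTensor(jnp.ones((2, 2)))
b = DemoTensor(jnp.ones((2, 2)))
print((a + b)._elem)  # jax.Array of 2s, computed by jnp.add
```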

torchax/docs/torch_xla2_dynamo.md (2 additions, 2 deletions)

````diff
@@ -33,11 +33,11 @@ not incur any data copies in this process.
 Consider this following pseudocode:

 ```python
-class XLATensor2:
+class Tensor:
   _data: jax.Array
   def __torch_dispatch__(...):
     # do stuff with _data, get new data
-    return XLATensor2(new_data)
+    return Tensor(new_data)

 def dynamo_backend(fx, sample):
   compiled = compile fx into graph that manipulate jax.Array.
````
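The pseudocode's key property is that unwrapping `_data` and re-wrapping the result costs no copies: the same device buffer flows through. A hedged sketch of that unwrap/jit/re-wrap shape follows; `WrappedArray` and `demo_dynamo_backend` are hypothetical names, and the FX-to-JAX lowering is taken as an already-given function rather than implemented:

```python
import jax
import jax.numpy as jnp

class WrappedArray:
  """Hypothetical stand-in for torchax.tensor.Tensor in this sketch."""

  def __init__(self, data):
    self._data = data

def demo_dynamo_backend(jax_fn):
  # In the real backend `jax_fn` would come from compiling the FX graph
  # into jax.Array manipulations; here it is taken as given.
  jitted = jax.jit(jax_fn)

  def run(*wrapped):
    arrays = [w._data for w in wrapped]  # unwrap: same buffers, no copy
    out = jitted(*arrays)
    return WrappedArray(out)             # re-wrap the device array

  return run

f = demo_dynamo_backend(lambda a, b: a @ b + 1.0)
x = WrappedArray(jnp.ones((2, 3)))
y = WrappedArray(jnp.ones((3, 2)))
print(f(x, y)._data)  # 2x2 matrix of 4s
```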

torchax/docs/understand_jax_jit/torch_module.py (8 additions, 8 deletions)

```diff
@@ -48,20 +48,20 @@ def forward(self, X):
 print('---- example 3 -----')
 # where is the jax jit?

-# m2 is a callable that takes in XLATensor2 and returns XLATensor2
-# m2: (XLATensor2 -> XLATensor2)
+# m2 is a callable that takes in Tensor and returns Tensor
+# m2: (Tensor -> Tensor)

-# suppose t2j (XLATensor2 -> jax.Array) "unwraps the XLATensor"
-# suppose j2t (jax.Array -> XLATensor2) "wraps the XLATensor"
+# suppose t2j (Tensor -> jax.Array) "unwraps the XLATensor"
+# suppose j2t (jax.Array -> Tensor) "wraps the XLATensor"
 from torchax import tensor
 import jax

-def t2j(torch_tensor: tensor.XLATensor2) -> jax.Array:
+def t2j(torch_tensor: tensor.Tensor) -> jax.Array:
   return torch_tensor._elem


-def j2t(jax_array: jax.Array) -> tensor.XLATensor2:
-  return tensor.XLATensor2(jax_array, env)
+def j2t(jax_array: jax.Array) -> tensor.Tensor:
+  return tensor.Tensor(jax_array, env)

 # # further notice t2j(j2t(x)) == x; j2t(t2j(x)) == x

@@ -75,7 +75,7 @@ def jax_m(X: jax.Array):
 jax_x = jnp.ones((10, 1000))
 print(jax_m(jax_x))

-## Let f: XLATensor2 -> XLATensor2
+## Let f: Tensor -> Tensor
 ## There is a function g: jax.Array -> jax.Array;
 ## g = x |-> j2t (f (t2j(x))). OR,
 ## g = j2t . f . t2j (. denotes function composition)
```
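One nit on the trailing comments: given `t2j : Tensor -> jax.Array` and `j2t : jax.Array -> Tensor`, the function induced on arrays is `g = t2j . f . j2t` (wrap, apply `f`, unwrap), so the composition as written reads reversed. A runnable sketch of the idea, where `Wrapped` is an illustrative stand-in rather than the torchax class:

```python
import jax
import jax.numpy as jnp

class Wrapped:
  """Illustrative stand-in for torchax's Tensor wrapper."""

  def __init__(self, elem):
    self._elem = elem

def t2j(w: Wrapped) -> jax.Array:  # unwrap
  return w._elem

def j2t(a: jax.Array) -> Wrapped:  # wrap
  return Wrapped(a)

def f(w: Wrapped) -> Wrapped:      # f works on wrapped tensors...
  return Wrapped(t2j(w) * 2.0)

# ...so g works on jax.Arrays and can be jitted directly.
g = jax.jit(lambda x: t2j(f(j2t(x))))

x = jnp.arange(4.0)
print(g(x))              # [0. 2. 4. 6.]
assert t2j(j2t(x)) is x  # the round trip is the identity
```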

torchax/examples/_diffusion.py (1 addition, 1 deletion)

```diff
@@ -26,7 +26,7 @@ def __init__(self, model):


   def _maybe_move_tensor(self, tensor):
-    if isinstance(tensor, torch.Tensor) and not isinstance(tensor, torchax.tensor.XLATensor2):
+    if isinstance(tensor, torch.Tensor) and not isinstance(tensor, torchax.tensor.Tensor):
       return torchax.tensor.move_to_device(tensor)
     return tensor

```
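The guard in `_maybe_move_tensor` is the usual "move only plain torch tensors" idiom: a plain `torch.Tensor` gets moved to the JAX device, while an already-wrapped `torchax.tensor.Tensor` (and any non-tensor value) passes through untouched. A sketch of applying the same guard across a whole pytree of mixed inputs; `move_all` is a hypothetical helper, while `move_to_device` is the function this file already calls:

```python
import torch
from torch.utils import _pytree as pytree
import torchax.tensor

def move_all(tree):
  # Move plain torch.Tensor leaves; leave torchax Tensors and
  # non-tensor leaves (ints, strings, ...) untouched.
  def maybe_move(t):
    if isinstance(t, torch.Tensor) and not isinstance(t, torchax.tensor.Tensor):
      return torchax.tensor.move_to_device(t)
    return t

  return pytree.tree_map(maybe_move, tree)

batch = {'pixel_values': torch.randn(1, 3, 64, 64), 'step': 10}
batch = move_all(batch)  # only the tensor leaf is converted
```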

torchax/examples/mnist_tpu.ipynb (1 addition, 1 deletion)

```diff
@@ -304,7 +304,7 @@
    {
     "data": {
      "text/plain": [
-      "XLATensor2(<class 'jaxlib.xla_extension.ArrayImpl'> [[ 0.03249096 0.01343462 -0.022144 ... 0.00668433 0.00833362\n",
+      "Tensor(<class 'jaxlib.xla_extension.ArrayImpl'> [[ 0.03249096 0.01343462 -0.022144 ... 0.00668433 0.00833362\n",
       " 0.00225713]\n",
       " [ 0.02272127 0.02205281 0.00828168 ... -0.02310903 0.02183958\n",
       " 0.01084254]\n",
```

torchax/examples/train_gpt/train_ddp.py (2 additions, 2 deletions)

```diff
@@ -29,8 +29,8 @@
 def _checkpoint(jax_model, path: pathlib.Path):
   torch.save(
       torch_pytree.tree_map_only(
-          torchax.tensor.XLATensor2,
-          torchax.tensor.XLATensor2.torch,
+          torchax.tensor.Tensor,
+          torchax.tensor.Tensor.torch,
           jax_model.state_dict(),
       ),
       path,
```
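The save path works because `Tensor.torch` (applied leaf-wise via `tree_map_only`) converts each wrapped device array back into a plain `torch.Tensor` that `torch.save` can pickle. A hedged sketch of the symmetric load path; `_restore` is a hypothetical helper, and the assumption is that `move_to_device` (used elsewhere in this commit) performs the reverse conversion:

```python
import pathlib

import torch
from torch.utils import _pytree as torch_pytree
import torchax.tensor

def _restore(jax_model, path: pathlib.Path):
  # torch.load yields plain torch.Tensors; move each one back onto the
  # JAX device before handing the state dict to the model. Sketch only.
  state = torch.load(path)
  state = torch_pytree.tree_map_only(
      torch.Tensor, torchax.tensor.move_to_device, state)
  jax_model.load_state_dict(state)
```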

torchax/test/test_base.py (1 addition, 1 deletion)

```diff
@@ -40,7 +40,7 @@ def run_function_and_compare(testcase,
   args2, kwargs2 = pytree.tree_map_only(torch.Tensor, tensor.move_to_device,
                                         (args, kwargs))
   res2 = func(*args2, **kwargs2)
-  res2 = pytree.tree_map_only(tensor.XLATensor2, lambda t: t.torch(), res2)
+  res2 = pytree.tree_map_only(tensor.Tensor, lambda t: t.torch(), res2)
   with testcase.subTest("torchax_diff:" + str(atol)):
     if ignore_indices and isinstance(res, tuple) and len(res) == 2:
       diff_output(
```

torchax/test/test_context.py (13 additions, 13 deletions)

```diff
@@ -20,9 +20,9 @@ def tearDown(self):
   def test_mode_context_manager(self):
     with xla_env:
       x = torch.full((3, 3), -1)
-      self.assertIsInstance(x, tensor.XLATensor2)
+      self.assertIsInstance(x, tensor.Tensor)
       y = x.abs()
-      self.assertIsInstance(y, tensor.XLATensor2)
+      self.assertIsInstance(y, tensor.Tensor)

   @staticmethod
   @xla_env
@@ -34,30 +34,30 @@ def _test_mode_decorator():

   def test_mode_decorator(self):
     x, y = self._test_mode_decorator()
-    self.assertIsInstance(x, tensor.XLATensor2)
-    self.assertIsInstance(y, tensor.XLATensor2)
+    self.assertIsInstance(x, tensor.Tensor)
+    self.assertIsInstance(y, tensor.Tensor)

   def test_same_manual_seed(self):
     with xla_env:
       torch.manual_seed(1234)
       x = torch.randn((3, 3))
-      self.assertIsInstance(x, tensor.XLATensor2)
+      self.assertIsInstance(x, tensor.Tensor)

       torch.manual_seed(1234)
       y = torch.randn((3, 3))
-      self.assertIsInstance(y, tensor.XLATensor2)
+      self.assertIsInstance(y, tensor.Tensor)

       self.assertTrue(torch.equal(torchax.tensor.j2t(x._elem), torchax.tensor.j2t(y._elem)))

   def test_different_manual_seed(self):
     with xla_env:
       torch.manual_seed(1234)
       x = torch.randn((3, 3))
-      self.assertIsInstance(x, tensor.XLATensor2)
+      self.assertIsInstance(x, tensor.Tensor)

       torch.manual_seed(12345)
       y = torch.randn((3, 3))
-      self.assertIsInstance(y, tensor.XLATensor2)
+      self.assertIsInstance(y, tensor.Tensor)

       self.assertFalse(torch.equal(torchax.tensor.j2t(x._elem), torchax.tensor.j2t(y._elem)))

@@ -69,7 +69,7 @@ def random_op():
       return x @ y

     random_jit = torchax.interop.jax_jit(random_op)
-    self.assertIsInstance(random_jit(), tensor.XLATensor2)
+    self.assertIsInstance(random_jit(), tensor.Tensor)

     # Result always expected to be the same for a jitted function because seeds
     # are baked in
@@ -101,13 +101,13 @@ def __init__(self):
     # Test context manager.
     with xla_env:
       m = M()
-      self.assertIsInstance(m.c, tensor.XLATensor2)
-      self.assertIsInstance(m.c2, tensor.XLATensor2)
+      self.assertIsInstance(m.c, tensor.Tensor)
+      self.assertIsInstance(m.c2, tensor.Tensor)
     # Test `to_xla`.
     m = M()
     m = xla_env.to_xla(m)
-    self.assertIsInstance(m.c, tensor.XLATensor2)
-    self.assertIsInstance(m.c2, tensor.XLATensor2)
+    self.assertIsInstance(m.c, tensor.Tensor)
+    self.assertIsInstance(m.c2, tensor.Tensor)


 if __name__ == "__main__":
```
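The comment in the third hunk ("seeds are baked in") is worth unpacking: `torchax.interop.jax_jit` traces `random_op` once, so whatever randomness is drawn at trace time becomes a constant in the compiled graph, and every later call returns identical values. The same effect illustrated in plain JAX (not the torchax wrapper), using host-side randomness that only runs during tracing:

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def random_op():
  # np.random executes only while tracing; its value is baked into the
  # compiled graph as a constant, so every call returns the same matrix.
  return jnp.asarray(np.random.randn(3, 3))

a, b = random_op(), random_op()
np.testing.assert_array_equal(a, b)  # identical across calls
```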

torchax/test/test_core_aten_ops.py (1 addition, 1 deletion)

```diff
@@ -39,7 +39,7 @@ def run_export_and_compare(testcase,
   args2, kwargs2 = testcase.env.to_xla((args, kwargs))
   with testcase.env:
     res2 = func(*args2, **kwargs2)
-  res2 = pytree.tree_map_only(tensor.XLATensor2, lambda t: t.torch(), res2)
+  res2 = pytree.tree_map_only(tensor.Tensor, lambda t: t.torch(), res2)
   # import pdb; pdb.set_trace()
   with testcase.subTest("torchax_diff:" + str(atol)):
     if ignore_indices and isinstance(res, tuple) and len(res) == 2:
```
