
Commit 20f5166

Use regular torch.Tensor for CPU tensors (#8416)
1 parent 39e67b5 commit 20f5166

14 files changed, +165 -117 lines changed
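The commit flips torch_xla2's default so that tensors created without an explicit device stay regular torch.Tensors on the CPU; data reaches the JAX backend only through an explicit device='jax' argument or a .to('jax') call. A minimal sketch of the new behavior, assuming the torch_xla2 API that the diffs below exercise:

    import torch
    import torch_xla2

    torch_xla2.enable_globally()  # install the 'jax' backend and environment

    # New default (use_torch_native_for_cpu_tensor = True): stays a plain
    # CPU torch.Tensor.
    x = torch.rand(4, 10)

    # Placed on the JAX backend only when requested explicitly.
    y = torch.rand(4, 10, device='jax')
    z = x.to('jax')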

experimental/torch_xla2/examples/basic_training.py

+16 -15

@@ -51,7 +51,8 @@ def matplotlib_imshow(img, one_channel=False):
         plt.imshow(npimg, cmap="Greys")
     else:
         plt.imshow(np.transpose(npimg, (1, 2, 0)))
-
+#torch_xla2.env.config.debug_print_each_op = True
+#torch_xla2.env.config.debug_mixed_tensor = True
 dataiter = iter(training_loader)
 images, labels = next(dataiter)
 
@@ -80,15 +81,15 @@ def forward(self, x):
         return x
 
 
-model = GarmentClassifier()
+model = GarmentClassifier().to('jax')
 
 loss_fn = torch.nn.CrossEntropyLoss()
 
 # NB: Loss functions expect data in batches, so we're creating batches of 4
 # Represents the model's confidence in each of the 10 classes for a given input
-dummy_outputs = torch.rand(4, 10)
+dummy_outputs = torch.rand(4, 10, device='jax')
 # Represents the correct class among the 10 being tested
-dummy_labels = torch.tensor([1, 5, 3, 7])
+dummy_labels = torch.tensor([1, 5, 3, 7], device='jax')
 
 print(dummy_outputs)
 print(dummy_labels)
@@ -110,6 +111,8 @@ def train_one_epoch(epoch_index, tb_writer=None):
         # Every data instance is an input + label pair
         # NEW: Move model to XLA device
        inputs, labels = data
+        inputs = inputs.to('jax')
+        labels = labels.to('jax')
 
         # Zero your gradients for every batch!
         optimizer.zero_grad()
@@ -162,7 +165,9 @@ def train_one_epoch(epoch_index, tb_writer=None):
     # Disable gradient computation and reduce memory consumption.
     with torch.no_grad():
         for i, vdata in enumerate(validation_loader):
-            # NOTE: move to XLA device
+            vinputs, vlabels = vdata
+            vinputs = vinputs.to('jax')
+            vlabels = vlabels.to('jax')
             voutputs = model(vinputs) # call model's forward
             vloss = loss_fn(voutputs, vlabels)
             running_vloss += vloss
@@ -172,15 +177,11 @@ def train_one_epoch(epoch_index, tb_writer=None):
 
     # Log the running loss averaged per batch
     # for both training and validation
-    writer.add_scalars('Training vs. Validation Loss',
-                       { 'Training' : avg_loss, 'Validation' : avg_vloss },
-                       epoch_number + 1)
-    writer.flush()
-
-    # Track best performance, and save the model's state
-    if avg_vloss < best_vloss:
-        best_vloss = avg_vloss
-        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
-        torch.save(model.state_dict(), model_path)
+
+    # # Track best performance, and save the model's state
+    # if avg_vloss < best_vloss:
+    #     best_vloss = avg_vloss
+    #     model_path = 'model_{}_{}'.format(timestamp, epoch_number)
+    #     torch.save(model.state_dict(), model_path)
 
     epoch_number += 1
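Since DataLoader batches now arrive as plain CPU tensors, the example moves each batch to the JAX device inside the loop. A condensed sketch of that per-step pattern; train_step is a hypothetical helper, the example inlines the same logic in train_one_epoch:

    def train_step(model, optimizer, loss_fn, batch):
        # Batches come off the DataLoader as regular CPU tensors under the
        # new default, so they are moved to the JAX device explicitly.
        inputs, labels = batch
        inputs, labels = inputs.to('jax'), labels.to('jax')
        optimizer.zero_grad()
        outputs = model(inputs)  # model itself was moved with .to('jax')
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        return loss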

experimental/torch_xla2/test/llama/test_llama.py

+86 -86

@@ -12,101 +12,101 @@
 class LlamaTest(test_base.TestCase):
 
   def test_can_run(self):
-    sample_args = (
-        torch.randint(0, 32000, (1, 2048)),
-        torch.arange(0, 2048),
-    )
-    sample_args = pytree.tree_map(tensor.t2j, sample_args)
+    with torch_xla2.default_env():
+      sample_args = (
+          torch.randint(0, 32000, (1, 2048), device='jax:0'),
+          torch.arange(0, 2048, device='jax:0'),
+      )
 
-    model_args = llama_model.ModelArgs(
-        block_size=2048,
-        vocab_size=32000,
-        n_layer=2,
-        n_head=4,
-        dim=256,
-    )
-    m = llama_model.Transformer(model_args)
-    m.to(torch.bfloat16)
-    m.setup_caches(1, 2048)
+      model_args = llama_model.ModelArgs(
+          block_size=2048,
+          vocab_size=32000,
+          n_layer=2,
+          n_head=4,
+          dim=256,
+      )
+      m = llama_model.Transformer(model_args)
+      m.to(torch.bfloat16)
+      m.setup_caches(1, 2048)
+      m = m.to('jax')
+
+      print(m(*sample_args))
 
-    # NOTE: this API does NOT use torch export
-    weights, jax_func = torch_xla2.extract_jax(m)
-    print(jax_func(weights, sample_args))
 
   def test_can_run_exportable(self):
-    model_args = model_exportable.ModelArgs(
-        vocab_size=32000,
-        n_layers=2,
-        n_heads=4,
-        dim=256,
-    )
-    m = model_exportable.Transformer(model_args)
-    context_length = 2048
-    input_shape_prefill = (1, context_length)
-    input_shape_decode = (1, 1)
+    model_args = model_exportable.ModelArgs(
+        vocab_size=32000,
+        n_layers=2,
+        n_heads=4,
+        dim=256,
+    )
+    m = model_exportable.Transformer(model_args)
+    context_length = 2048
+    input_shape_prefill = (1, context_length)
+    input_shape_decode = (1, 1)
 
-    def make_cache(args, batch_size):
-      n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
-      n_local_heads = args.n_heads
-      n_local_kv_heads = n_kv_heads
-      n_rep = n_local_heads // n_local_kv_heads
-      head_dim = args.dim // args.n_heads
-      res = []
-      for i in range(args.n_layers):
-        if batch_size is None:
-          size = (
-              args.max_seq_len,
-              n_local_kv_heads,
-              head_dim,
-          )
-        else:
-          size = (
-              batch_size,
-              args.max_seq_len,
-              n_local_kv_heads,
-              head_dim,
-          )
-        res.append(
-            (torch.zeros(
-                size,
-                dtype=torch.bfloat16 if args.bf16_enable else torch.float),
-             torch.zeros(
-                size,
-                dtype=torch.bfloat16 if args.bf16_enable else torch.float)))
-      return res
+    def make_cache(args, batch_size):
+      n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
+      n_local_heads = args.n_heads
+      n_local_kv_heads = n_kv_heads
+      n_rep = n_local_heads // n_local_kv_heads
+      head_dim = args.dim // args.n_heads
+      res = []
+      for i in range(args.n_layers):
+        if batch_size is None:
+          size = (
+              args.max_seq_len,
+              n_local_kv_heads,
+              head_dim,
+          )
+        else:
+          size = (
+              batch_size,
+              args.max_seq_len,
+              n_local_kv_heads,
+              head_dim,
+          )
+        res.append(
+            (torch.zeros(
+                size,
+                dtype=torch.bfloat16 if args.bf16_enable else torch.float),
+             torch.zeros(
+                size,
+                dtype=torch.bfloat16 if args.bf16_enable else torch.float)))
+      return res
 
-    prefill_caches = make_cache(model_args, 1)
+    prefill_caches = make_cache(model_args, 1)
 
-    sample_input_prefill = (
-        torch.randint(0, 1000, input_shape_prefill,
-                      dtype=torch.int32),  # len seq length
-        torch.arange(0, context_length, dtype=torch.int32),  # input indexes
-        torch.arange(0, context_length, dtype=torch.int32),  # context indexes
-        prefill_caches,
-        True,  # prefil
-    )
-    with torch.no_grad():
-      m_prefill = torch.export.export(m, sample_input_prefill)
+    sample_input_prefill = (
+        torch.randint(0, 1000, input_shape_prefill,
+                      dtype=torch.int32),  # len seq length
+        torch.arange(0, context_length, dtype=torch.int32),  # input indexes
+        torch.arange(0, context_length, dtype=torch.int32),  # context indexes
+        prefill_caches,
+        True,  # prefil
+    )
+    with torch.no_grad():
+      m_prefill = torch.export.export(m, sample_input_prefill)
 
-    weights, mj_prefill = torch_xla2.export.exported_program_to_jax(m_prefill)
-    sample_inputs = pytree.tree_map_only(torch.Tensor, tensor.t2j,
-                                         sample_input_prefill)
-    print('Prefill', mj_prefill(weights, sample_inputs))
+    weights, mj_prefill = torch_xla2.export.exported_program_to_jax(m_prefill)
+    sample_inputs = pytree.tree_map_only(torch.Tensor, tensor.t2j,
+                                         sample_input_prefill)
+    print('Prefill', mj_prefill(weights, sample_inputs))
 
-    sample_input_decode = (
-        torch.randint(0, 1000, input_shape_decode,
-                      dtype=torch.int32),  # len = 1
-        torch.tensor([0], dtype=torch.int32),
-        torch.roll(torch.arange(context_length, dtype=torch.int32), 1, 0),
-        prefill_caches,
-        False  # prefill
-    )
-    with torch.no_grad():
-      m_decode = torch.export.export(m, sample_input_decode)
-    weights, mj_decode = torch_xla2.export.exported_program_to_jax(m_decode)
-    sample_inputs = pytree.tree_map_only(torch.Tensor, tensor.t2j,
-                                         sample_input_decode)
-    print('Decode', mj_decode(weights, sample_inputs))
+    sample_input_decode = (
+        torch.randint(0, 1000, input_shape_decode,
+                      dtype=torch.int32),  # len = 1
+        torch.tensor([0], dtype=torch.int32),
+        torch.roll(torch.arange(context_length, dtype=torch.int32), 1, 0),
+        prefill_caches,
+        False  # prefill
+    )
+    with torch.no_grad():
+      m_decode = torch.export.export(m, sample_input_decode)
+    weights, mj_decode = torch_xla2.export.exported_program_to_jax(m_decode)
+    sample_inputs = pytree.tree_map_only(torch.Tensor, tensor.t2j,
+                                         sample_input_decode)
+    print('Decode', mj_decode(weights, sample_inputs))
 
 
 if __name__ == "__main__":
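The rewritten test builds inputs and model inside torch_xla2.default_env() and moves the weights with .to('jax'), replacing the earlier tensor.t2j / extract_jax conversion. The shape of the new pattern, with build_model() as a hypothetical stand-in for the llama_model.Transformer setup above:

    import torch
    import torch_xla2

    with torch_xla2.default_env():
      # Inputs are created directly on the JAX backend.
      tokens = torch.randint(0, 32000, (1, 2048), device='jax:0')
      positions = torch.arange(0, 2048, device='jax:0')

      model = build_model().to('jax')  # hypothetical; weights move to JAX
      print(model(tokens, positions))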

experimental/torch_xla2/test/test_context.py

+7

@@ -10,6 +10,13 @@
 
 class TestContext(unittest.TestCase):
 
+  def setUp(self):
+    self.old_var = xla_env.config.use_torch_native_for_cpu_tensor
+    xla_env.config.use_torch_native_for_cpu_tensor = False
+
+  def tearDown(self):
+    xla_env.config.use_torch_native_for_cpu_tensor = self.old_var
+
   def test_mode_context_manager(self):
     with xla_env:
       x = torch.full((3, 3), -1)
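The same save-flip-restore dance around use_torch_native_for_cpu_tensor recurs in test_core_aten_ops.py, test_ops.py, and test_unbounded_dynamism.py below. A hypothetical helper, not part of this commit, that captures the pattern as a context manager:

    import contextlib

    @contextlib.contextmanager
    def intercept_cpu_tensors(env):
      # Temporarily route plain CPU tensors through the JAX environment.
      old = env.config.use_torch_native_for_cpu_tensor
      env.config.use_torch_native_for_cpu_tensor = False
      try:
        yield env
      finally:
        # Restore the previous value even if the test body raises.
        env.config.use_torch_native_for_cpu_tensor = old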

experimental/torch_xla2/test/test_core_aten_ops.py

+5

@@ -66,6 +66,11 @@ def setUp(self):
     super().setUp()
     torch.manual_seed(0)
     self.env = tensor.Environment()
+    self.old_var = self.env.config.use_torch_native_for_cpu_tensor
+    self.env.config.use_torch_native_for_cpu_tensor = False
+
+  def tearDown(self):
+    self.env.config.use_torch_native_for_cpu_tensor = self.old_var
 
   def test_aten_abs_0(self):
     args = (torch.randn((10, 10)).to(torch.float32),)

experimental/torch_xla2/test/test_functions.py

+1

@@ -10,6 +10,7 @@ class TestTorchFunctions(parameterized.TestCase):
 
   def setUp(self):
     self.env = torch_xla2.tensor.Environment()
+    self.env.config.use_torch_native_for_cpu_tensor = False
     torch_xla2.enable_accuracy_mode()
 
   @parameterized.named_parameters(

experimental/torch_xla2/test/test_libraries.py

+6 -3

@@ -1,11 +1,9 @@
 import unittest
-import jax
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 from torch.library import Library, impl, impl_abstract
 import torch_xla2
-from torch_xla2 import tensor
+import torch_xla2.export
 from torch_xla2.ops import jaten
 from torch_xla2.ops import jlibrary
 
@@ -56,6 +54,7 @@ class LibraryTest(unittest.TestCase):
 
   def setUp(self):
     torch.manual_seed(0)
+    torch_xla2.default_env().config.use_torch_native_for_cpu_tensor = False
 
   def test_basic_sdpa_library(self):
 
@@ -78,3 +77,7 @@ def forward(self, q,k,v):
     ## stablehlo.composite ops.
     self.assertIn("call @mylib.scaled_dot_product_attention", module_str)
     self.assertIn("call @mylib.softmax", module_str)
+
+
+if __name__ == '__main__':
+  unittest.main()

experimental/torch_xla2/test/test_ops.py

+5

@@ -192,6 +192,11 @@ def setUp(self):
     torch_xla2.enable_accuracy_mode()
     #self.env.config.debug_accuracy_for_each_op = True
     torch.manual_seed(0)
+    self.old_var = self.env.config.use_torch_native_for_cpu_tensor
+    self.env.config.use_torch_native_for_cpu_tensor = False
+
+  def tearDown(self):
+    self.env.config.use_torch_native_for_cpu_tensor = self.old_var
 
   # Replaces all values in the input torch_tensor that are less than the given threshold
   # with the threshold value itself.

experimental/torch_xla2/test/test_tf_integration.py

+1 -1

@@ -1,6 +1,6 @@
-import jax
 import os
 import tempfile
+import numpy as np
 import tensorflow as tf
 import torch
 import torch.nn.functional as F

experimental/torch_xla2/test/test_unbounded_dynamism.py

+9 -1

@@ -2,10 +2,10 @@
 import sys
 import unittest
 
-import numpy as np
 import torch
 from torch.export import Dim, export
 from torch_xla2.export import exported_program_to_stablehlo as exp2shlo
+import torch_xla2
 
 ## This file is copied from `xla/test/stablehlo/test_unbounded_dynamism.py`
 ## To test that torch_xla2 has identical behavior.
@@ -44,6 +44,14 @@ def forward(self, *args):
 
 class UnboundedDynamismExportTest(unittest.TestCase):
 
+  def setUp(self):
+    self.env = torch_xla2.default_env()
+    self.env.config.use_torch_native_for_cpu_tensor = False
+    torch_xla2.enable_accuracy_mode()
+
+  def tearDown(self):
+    self.env.config.use_torch_native_for_cpu_tensor = True
+
   def test_add(self):
     args = (torch.rand((10, 197, 768)), torch.rand((10, 197, 768)))
     dynamic_shapes = (({0: Dim("dim")}, {0: Dim("dim")}),)
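Note that unlike the setUp/tearDown pairs above, this tearDown resets the flag to the new default (True) unconditionally rather than restoring a saved value; the save-and-restore helper sketched after test_context.py would fit here as well.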

experimental/torch_xla2/torch_xla2/__init__.py

+10

@@ -1,3 +1,4 @@
+import contextlib
 from typing import List, Dict, Any, Optional
 import dataclasses
 import jax
@@ -73,6 +74,15 @@ def disable_globally():
   global env
   default_env().__exit__(None, None, None)
 
+@contextlib.contextmanager
+def disable_temporarily():
+  prev = default_env().enabled
+  if prev:
+    disable_globally()
+  yield()
+  if prev:
+    enable_globally()
+
 
 torch.utils.rename_privateuse1_backend('jax')
 unsupported_dtype = [torch.quint8]
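Two details of the new disable_temporarily are worth flagging: yield() yields an empty tuple (a bare yield is the usual idiom for context managers), and without try/finally the environment stays disabled if the with-body raises. A hypothetical exception-safe variant, not part of this commit:

    @contextlib.contextmanager
    def disable_temporarily():
      prev = default_env().enabled
      if prev:
        disable_globally()
      try:
        yield
      finally:
        # Re-enable even if the with-body raised.
        if prev:
          enable_globally()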

experimental/torch_xla2/torch_xla2/config.py

+1 -1

@@ -14,5 +14,5 @@ class Configuration:
 
   # device
   treat_cuda_as_jax_device: bool = True
-  use_torch_native_for_cpu_tensor: bool = False
+  use_torch_native_for_cpu_tensor: bool = True
   internal_respect_torch_return_dtypes: bool = False
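This one-line default flip is the heart of the commit: out of the box, CPU tensors stay native torch.Tensors, and the updated tests opt back into interception explicitly. A minimal sketch of that opt-out, mirroring the setUp changes above:

    import torch_xla2

    env = torch_xla2.default_env()
    # Pre-commit behavior: plain CPU tensor ops are routed through the
    # JAX environment again.
    env.config.use_torch_native_for_cpu_tensor = False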
