Commit a8b269b

fix tests
1 parent 47573f6 commit a8b269b

5 files changed: +93 -85 lines changed


docs/source/features/stablehlo.md

Lines changed: 84 additions & 76 deletions
@@ -84,85 +84,93 @@ For using `composite` we need to use the jax-centric export now. (i.e. no torch.
 We are working on adding support for torch.export now.
 
 ``` python
+import unittest
 import torch
 import torch.nn.functional as F
-import torch_xla2 as tx
-import torch_xla2.interop
+from torch.library import Library, impl, impl_abstract
+import torch_xla2
 import torch_xla2.export
-
-import jax
-import jax.numpy as jnp
-
-
-
-# We will use jax.lax.composite to accomplish this.
-wrap_composite = tx.interop.torch_view(jax.lax.composite)
-
-
-class M(torch.nn.Module):
-
-  def __init__(self):
-    super().__init__()
-    self.q_proj = torch.nn.Linear(128, 128, bias=False)
-    self.k_proj = torch.nn.Linear(128, 128, bias=False)
-    self.v_proj = torch.nn.Linear(128, 128, bias=False)
-
-    self._composite_sdpa = wrap_composite(F.scaled_dot_product_attention, name="test.sdpa")
-
-  def forward(self, x):
-    q = self.q_proj(x)
-    k = self.k_proj(x)
-    v = self.v_proj(x)
-    attn_out = self._composite_sdpa(q, k, v, scale=0.25)
-    return attn_out
-
-weights, jfunc = tx.extract_jax(M())
-stablehlo = jax.export.export(jax.jit(jfunc))(
-    weights, jax.ShapeDtypeStruct((4, 8, 128), jnp.float32.dtype))
-print(stablehlo.mlir_module())
-```
-
-The main StableHLO graph is shown below:
-
-``` none
-module @IrToHlo.56 attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
-  func.func @main(%arg0: tensor<10x8x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>, %arg3: tensor<128x128xf32>) -> tensor<10x8x128xf32> {
-    ...
-    %10 = stablehlo.composite "test.sdpa" %3, %6, %9 {composite_attributes = {other_attr = "val", scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl} : (tensor<10x8x128xf32>, tensor<10x8x128xf32>, tensor<10x8x128xf32>) -> tensor<10x8x128xf32>
-    %11 = stablehlo.add %10, %arg0 : tensor<10x8x128xf32>
-    return %11 : tensor<10x8x128xf32>
-  }
-
-  func.func private @test.sdpa.impl(%arg0: tensor<10x8x128xf32>, %arg1: tensor<10x8x128xf32>, %arg2: tensor<10x8x128xf32>) -> tensor<10x8x128xf32> {
-    // Actual implementation of the composite
-    ...
-    return %11 : tensor<10x8x128xf32>
-  }
-```
-
-The sdpa operation is encapsulated as a stablehlo composite call within
-the main graph. The name and attributes specified in the torch.nn.Module
-are propagated.
-
-``` none
-%12 = stablehlo.composite "test.sdpa" %3, %7, %11 {composite_attributes = {scale = 2.500000e-01 : f64}, decomposition = @test.sdpa} : (tensor<4x8x128xf32>, tensor<4x8x128xf32>, tensor<4x8x128xf32>) -> tensor<4x8x128xf32> loc(#loc95)
-```
-
-The reference PyTorch decomposition of the sdpa operation is captured in
-a StableHLO function:
-
-``` none
-func.func private @test.sdpa.impl(%arg0: tensor<10x8x128xf32>, %arg1: tensor<10x8x128xf32>, %arg2: tensor<10x8x128xf32>) -> tensor<10x8x128xf32> {
-  // Actual implementation of the composite
-  ...
-  return %11 : tensor<10x8x128xf32>
-}
+from torch_xla2.ops import jaten
+from torch_xla2.ops import jlibrary
+
+
+# Create a `mylib` library which has a basic SDPA op.
+m = Library("mylib", "DEF")
+m.define("scaled_dot_product_attention(Tensor q, Tensor k, Tensor v) -> Tensor")
+
+@impl(m, "scaled_dot_product_attention", "CompositeExplicitAutograd")
+def _mylib_scaled_dot_product_attention(q, k, v):
+  """Basic scaled dot product attention without all the flags/features."""
+  q = q.transpose(1, 2)
+  k = k.transpose(1, 2)
+  v = v.transpose(1, 2)
+  y = F.scaled_dot_product_attention(
+      q,
+      k,
+      v,
+      dropout_p=0,
+      is_causal=False,
+      scale=None,
+  )
+  return y.transpose(1, 2)
+
+@impl_abstract("mylib::scaled_dot_product_attention")
+def _mylib_scaled_dot_product_attention_meta(q, k, v):
+  return torch.empty_like(q)
+
+# Register library op as a composite for export using the `@impl` method
+# for a torch decomposition.
+jlibrary.register_torch_composite(
+    "mylib.scaled_dot_product_attention",
+    _mylib_scaled_dot_product_attention,
+    torch.ops.mylib.scaled_dot_product_attention,
+    torch.ops.mylib.scaled_dot_product_attention.default
+)
+
+# Also register ATen softmax as a composite for export in the `mylib` library
+# using the JAX ATen decomposition from `jaten`.
+jlibrary.register_jax_composite(
+    "mylib.softmax",
+    jaten._aten_softmax,
+    torch.ops.aten._softmax,
+    static_argnums=1  # Required by JAX jit
+)
+
+class LibraryTest(unittest.TestCase):
+
+  def setUp(self):
+    torch.manual_seed(0)
+    torch_xla2.default_env().config.use_torch_native_for_cpu_tensor = False
+
+  def test_basic_sdpa_library(self):
+
+    class CustomOpExample(torch.nn.Module):
+      def forward(self, q, k, v):
+        x = torch.ops.mylib.scaled_dot_product_attention(q, k, v)
+        x = x + 1
+        return x
+
+    # Export and check for composite operations
+    model = CustomOpExample()
+    arg = torch.rand(32, 8, 128, 64)
+    args = (arg, arg, arg, )
+
+    exported = torch.export.export(model, args=args)
+    stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
+    module_str = str(stablehlo.mlir_module())
+
+    ## TODO Update this machinery from producing function calls to producing
+    ## stablehlo.composite ops.
+    self.assertIn("call @mylib.scaled_dot_product_attention", module_str)
+    self.assertIn("call @mylib.softmax", module_str)
+
+
+if __name__ == '__main__':
+  unittest.main()
 ```
 
 As shown above, to emit a StableHLO function as a composite, we first write a Python function
-representing the region of code that we want to call, (
-in this case `F.scaled_dot_product_attention` is already such function).
-Then we wrap the function with `wrap_composite`.
-
-NOTE: currently a model with `wrap_composite` call will not work with `torch.export`.
-We are actively working to make it work.
+representing the region of code that we want to outline. Then we register it
+so that PyTorch and jlibrary understand it is a custom region. The
+emitted StableHLO will then have `mylib.scaled_dot_product_attention` and `mylib.softmax`
+as outlined StableHLO functions.
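
For quick reference, the new example condenses to three steps: define a torch library op with a concrete and an abstract implementation, register it with jlibrary, and export. A minimal sketch follows; the `jlibrary.register_torch_composite` call shape is assumed to be exactly as used in the diff above, while the `my_sdpa` op name and the rest of the scaffolding are illustrative only:

``` python
import torch
import torch.nn.functional as F
from torch.library import Library, impl, impl_abstract
from torch_xla2.ops import jlibrary

# Step 1: define a custom torch library op (`my_sdpa` is a hypothetical name)
# with a concrete decomposition and a meta/abstract implementation.
lib = Library("mylib", "DEF")
lib.define("my_sdpa(Tensor q, Tensor k, Tensor v) -> Tensor")

@impl(lib, "my_sdpa", "CompositeExplicitAutograd")
def _my_sdpa(q, k, v):
  return F.scaled_dot_product_attention(q, k, v)

@impl_abstract("mylib::my_sdpa")
def _my_sdpa_meta(q, k, v):
  return torch.empty_like(q)

# Step 2: tell jlibrary to treat the op as a region to outline during export
# (argument order assumed from the example above).
jlibrary.register_torch_composite(
    "mylib.my_sdpa",
    _my_sdpa,
    torch.ops.mylib.my_sdpa,
    torch.ops.mylib.my_sdpa.default,
)

# Step 3: any torch.export-ed model that calls torch.ops.mylib.my_sdpa
# should now yield an outlined @mylib.my_sdpa function in the emitted StableHLO.
```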

experimental/torch_xla2/test/test_exports.py

Lines changed: 4 additions & 4 deletions
@@ -52,7 +52,7 @@ def test_interpolate(self):
     self.assertTrue(torch.allclose(ans, ans2, atol=1e-3))
 
     # Convert to StableHLO
-    stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
+    weights, stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
     module_str = str(stablehlo.mlir_module())
     self.assertIn("func.func public @main", module_str)
     self.assertIn("func.func private @clip(%arg0: tensor<500xf32>", module_str)
@@ -75,7 +75,7 @@ def test_constant(self):
     self.assertTrue(torch.allclose(ans, ans2, atol=1e-5))
 
     # Convert to StableHLO
-    stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
+    weights, stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
     module_str = str(stablehlo.mlir_module())
     self.assertIn("func.func public @main", module_str)
     self.assertIn("stablehlo.divide", module_str)
@@ -89,7 +89,7 @@ def test_interpolate_dynamic(self):
 
     with torch.no_grad():
       exported = torch.export.export(model, arg, dynamic_shapes=dynamic_shapes)
-    stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
+    weights, stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
     module_str = str(stablehlo.mlir_module())
 
     # Look for dynamic shape artifacts
@@ -139,7 +139,7 @@ def test_export_dtypes(self):
     arg = (torch.randn(10).to(torch_dtype),)
     with torch.no_grad():
       exported = torch.export.export(model, arg)
-    stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
+    weights, stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
     module_str = str(stablehlo.mlir_module())
     self.assertIn(DTYPE_TO_MLIR_STR[torch_dtype], module_str)
 
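
The recurring change across these tests is that `torch_xla2.export.exported_program_to_stablehlo` now returns a `(weights, stablehlo)` pair rather than the StableHLO object alone. A minimal usage sketch under that assumption; the `AddOne` module and input shape are illustrative only:

``` python
import torch
import torch_xla2.export

class AddOne(torch.nn.Module):  # hypothetical example module
  def forward(self, x):
    return x + 1

exported = torch.export.export(AddOne(), (torch.randn(10),))
# Unpack both values; the weights now come back alongside the program.
weights, stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
print(str(stablehlo.mlir_module()))
```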

experimental/torch_xla2/test/test_libraries.py

Lines changed: 1 addition & 1 deletion
@@ -70,7 +70,7 @@ def forward(self, q,k,v):
     args = (arg, arg, arg, )
 
     exported = torch.export.export(model, args=args)
-    stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
+    weights, stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
     module_str = str(stablehlo.mlir_module())
 
     ## TODO Update this machinery from producing function calls to producing

experimental/torch_xla2/test/test_symbolic_shapes.py

Lines changed: 3 additions & 3 deletions
@@ -40,7 +40,7 @@ def test_constraints_min_max(self):
 
     with torch.no_grad():
       exported = torch.export.export(model, args=args, dynamic_shapes=dynamic_shapes)
-    stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
+    weights, stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
     module_str = str(stablehlo.mlir_module())
 
     self.assertRegex(module_str, r"stablehlo.constant.*3")
@@ -62,7 +62,7 @@ def test_constraints_multiply(self):
 
     with torch.no_grad():
       exported = torch.export.export(model, args=args, dynamic_shapes=dynamic_shapes)
-    stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
+    weights, stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
     module_str = str(stablehlo.mlir_module())
 
     self.assertRegex(module_str, r"stablehlo.constant.*10")
@@ -84,7 +84,7 @@ def test_constraint_indirection(self):
 
     with torch.no_grad():
       exported = torch.export.export(model, args=args, dynamic_shapes=dynamic_shapes)
-    stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
+    weights, stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
     module_str = str(stablehlo.mlir_module())
 
     self.assertRegex(module_str, r"shape_assertion.*s[0-9]+ <= 10")

experimental/torch_xla2/test/test_unbounded_dynamism.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ def get_stablehlo_text(self):
     return self.export.mlir_module()
 
 def exported_program_to_stablehlo(exported):
-  return ExportAdapter(exp2shlo(exported))
+  return ExportAdapter(exp2shlo(exported)[1])
 
 
 def wrap_func_as_nn_module(f):
