Commit

fix tests

qihqi committed Jan 8, 2025
1 parent 47573f6 commit a8b269b
Showing 5 changed files with 93 additions and 85 deletions.
160 changes: 84 additions & 76 deletions docs/source/features/stablehlo.md
@@ -84,85 +84,93 @@ For using `composite`, we need to use the jax-centric export for now (i.e., not torch.export).
We are working on adding support for torch.export.

``` python
import torch
import torch.nn.functional as F
import torch_xla2 as tx
import torch_xla2.interop

import jax
import jax.numpy as jnp



# We will use jax.lax.composite to accomplish this.
wrap_composite = tx.interop.torch_view(jax.lax.composite)


class M(torch.nn.Module):

def __init__(self):
super().__init__()
self.q_proj = torch.nn.Linear(128, 128, bias=False)
self.k_proj = torch.nn.Linear(128, 128, bias=False)
self.v_proj = torch.nn.Linear(128, 128, bias=False)

self._composite_sdpa = wrap_composite(F.scaled_dot_product_attention, name="test.sdpa")

def forward(self, x):
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
attn_out = self._composite_sdpa(q, k, v, scale=0.25)
return attn_out

weights, jfunc = tx.extract_jax(M())
stablehlo = jax.export.export(jax.jit(jfunc))(
weights, jax.ShapeDtypeStruct((4, 8, 128), jnp.float32.dtype))
print(stablehlo.mlir_module())
```

The main StableHLO graph is shown below:

``` none
module @IrToHlo.56 attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
func.func @main(%arg0: tensor<10x8x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>, %arg3: tensor<128x128xf32>) -> tensor<10x8x128xf32> {
...
%10 = stablehlo.composite "test.sdpa" %3, %6, %9 {composite_attributes = {other_attr = "val", scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl} : (tensor<10x8x128xf32>, tensor<10x8x128xf32>, tensor<10x8x128xf32>) -> tensor<10x8x128xf32>
%11 = stablehlo.add %10, %arg0 : tensor<10x8x128xf32>
return %11 : tensor<10x8x128xf32>
}
func.func private @test.sdpa.impl(%arg0: tensor<10x8x128xf32>, %arg1: tensor<10x8x128xf32>, %arg2: tensor<10x8x128xf32>) -> tensor<10x8x128xf32> {
// Actual implementation of the composite
...
return %11 : tensor<10x8x128xf32>
}
```

The SDPA operation is encapsulated as a `stablehlo.composite` call within
the main graph. The name and attributes specified in the `torch.nn.Module`
are propagated.

``` none
%12 = stablehlo.composite "test.sdpa" %3, %7, %11 {composite_attributes = {scale = 2.500000e-01 : f64}, decomposition = @test.sdpa} : (tensor<4x8x128xf32>, tensor<4x8x128xf32>, tensor<4x8x128xf32>) -> tensor<4x8x128xf32> loc(#loc95)
```

The reference PyTorch decomposition of the sdpa operation is captured in
a StableHLO function:

``` none
func.func private @test.sdpa.impl(%arg0: tensor<10x8x128xf32>, %arg1: tensor<10x8x128xf32>, %arg2: tensor<10x8x128xf32>) -> tensor<10x8x128xf32> {
// Actual implementation of the composite
...
return %11 : tensor<10x8x128xf32>
}
```

An alternative flow uses `jlibrary` to register a torch custom op and mark it as a composite for export:

``` python
import unittest

import torch
import torch.nn.functional as F
from torch.library import Library, impl, impl_abstract

import torch_xla2
import torch_xla2.export
from torch_xla2.ops import jaten
from torch_xla2.ops import jlibrary


# Create a `mylib` library which has a basic SDPA op.
m = Library("mylib", "DEF")
m.define("scaled_dot_product_attention(Tensor q, Tensor k, Tensor v) -> Tensor")

@impl(m, "scaled_dot_product_attention", "CompositeExplicitAutograd")
def _mylib_scaled_dot_product_attention(q, k, v):
"""Basic scaled dot product attention without all the flags/features."""
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
y = F.scaled_dot_product_attention(
q,
k,
v,
dropout_p=0,
is_causal=False,
scale=None,
)
return y.transpose(1, 2)

@impl_abstract("mylib::scaled_dot_product_attention")
def _mylib_scaled_dot_product_attention_meta(q, k, v):
return torch.empty_like(q)

# Register library op as a composite for export using the `@impl` method
# for a torch decomposition.
jlibrary.register_torch_composite(
"mylib.scaled_dot_product_attention",
_mylib_scaled_dot_product_attention,
torch.ops.mylib.scaled_dot_product_attention,
torch.ops.mylib.scaled_dot_product_attention.default
)

# Also register ATen softmax as a composite for export in the `mylib` library
# using the JAX ATen decomposition from `jaten`.
jlibrary.register_jax_composite(
"mylib.softmax",
jaten._aten_softmax,
torch.ops.aten._softmax,
static_argnums=1 # Required by JAX jit
)

class LibraryTest(unittest.TestCase):

def setUp(self):
torch.manual_seed(0)
torch_xla2.default_env().config.use_torch_native_for_cpu_tensor = False

def test_basic_sdpa_library(self):

class CustomOpExample(torch.nn.Module):
def forward(self, q,k,v):
x = torch.ops.mylib.scaled_dot_product_attention(q, k, v)
x = x + 1
return x

# Export and check for composite operations
model = CustomOpExample()
arg = torch.rand(32, 8, 128, 64)
args = (arg, arg, arg, )

exported = torch.export.export(model, args=args)
stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
module_str = str(stablehlo.mlir_module())

## TODO Update this machinery from producing function calls to producing
## stablehlo.composite ops.
self.assertIn("call @mylib.scaled_dot_product_attention", module_str)
self.assertIn("call @mylib.softmax", module_str)


if __name__ == '__main__':
unittest.main()
```

As we see, to emit a StableHLO function as a composite, we first write a Python function
representing the region of code that we want to outline
(in this case `F.scaled_dot_product_attention` is already such a function).
Then we wrap the function with `wrap_composite`.

NOTE: currently a model with a `wrap_composite` call will not work with `torch.export`.
We are actively working to make it work.
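
A more isolated sketch of the same pattern is below. The `test.scaled_silu` name and the
`scaled_silu` function are hypothetical (they are not part of torch_xla2 or the example above);
the sketch assumes, as in the example above, that non-tensor keyword arguments passed to the
wrapped function are traced as composite attributes.

``` python
import torch
import torch_xla2 as tx
import torch_xla2.interop

import jax
import jax.numpy as jnp

# Same interop trick as above: view jax.lax.composite as a torch-callable.
wrap_composite = tx.interop.torch_view(jax.lax.composite)


def scaled_silu(x, *, alpha=1.0):
    # Plain torch code; this becomes the composite's decomposition.
    return alpha * x * torch.sigmoid(x)


scaled_silu_composite = wrap_composite(scaled_silu, name="test.scaled_silu")


class ScaledSilu(torch.nn.Module):

    def forward(self, x):
        # `alpha` is a non-tensor keyword argument, so it should appear under
        # composite_attributes in the emitted StableHLO.
        return scaled_silu_composite(x, alpha=0.5)


weights, jfunc = tx.extract_jax(ScaledSilu())
stablehlo = jax.export.export(jax.jit(jfunc))(
    weights, jax.ShapeDtypeStruct((4, 8), jnp.float32.dtype))
print(stablehlo.mlir_module())
```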
In the alternative `jlibrary`-based flow (the `LibraryTest` example above), we instead register
the function so that PyTorch and `jlibrary` understand it is a custom region; the emitted
StableHLO then contains `mylib.scaled_dot_product_attention` and `mylib.softmax` as outlined
StableHLO functions.
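
The test updates below reflect what appears to be a signature change in
`torch_xla2.export.exported_program_to_stablehlo`: it now returns a `(weights, stablehlo)` pair
rather than the StableHLO object alone. A minimal sketch of the updated call pattern (the toy
`AddOne` module is illustrative only):

``` python
import torch
import torch_xla2.export


class AddOne(torch.nn.Module):

    def forward(self, x):
        return x + 1


exported = torch.export.export(AddOne(), (torch.randn(4),))
# exported_program_to_stablehlo now returns a (weights, stablehlo) pair;
# previously it returned the StableHLO object alone.
weights, stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
print(stablehlo.mlir_module())
```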
8 changes: 4 additions & 4 deletions experimental/torch_xla2/test/test_exports.py
@@ -52,7 +52,7 @@ def test_interpolate(self):
self.assertTrue(torch.allclose(ans, ans2, atol=1e-3))

# Convert to StableHLO
stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
weights, stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
module_str = str(stablehlo.mlir_module())
self.assertIn("func.func public @main", module_str)
self.assertIn("func.func private @clip(%arg0: tensor<500xf32>", module_str)
@@ -75,7 +75,7 @@ def test_constant(self):
self.assertTrue(torch.allclose(ans, ans2, atol=1e-5))

# Convert to StableHLO
stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
weights, stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
module_str = str(stablehlo.mlir_module())
self.assertIn("func.func public @main", module_str)
self.assertIn("stablehlo.divide", module_str)
@@ -89,7 +89,7 @@ def test_interpolate_dynamic(self):

with torch.no_grad():
exported = torch.export.export(model, arg, dynamic_shapes=dynamic_shapes)
stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
weights, stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
module_str = str(stablehlo.mlir_module())

# Look for dynamic shape artifacts
@@ -139,7 +139,7 @@ def test_export_dtypes(self):
arg = (torch.randn(10).to(torch_dtype),)
with torch.no_grad():
exported = torch.export.export(model, arg)
stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
weights, stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
module_str = str(stablehlo.mlir_module())
self.assertIn(DTYPE_TO_MLIR_STR[torch_dtype], module_str)

2 changes: 1 addition & 1 deletion experimental/torch_xla2/test/test_libraries.py
@@ -70,7 +70,7 @@ def forward(self, q,k,v):
args = (arg, arg, arg, )

exported = torch.export.export(model, args=args)
stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
weights, stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
module_str = str(stablehlo.mlir_module())

## TODO Update this machinery from producing function calls to producing
6 changes: 3 additions & 3 deletions experimental/torch_xla2/test/test_symbolic_shapes.py
@@ -40,7 +40,7 @@ def test_constraints_min_max(self):

with torch.no_grad():
exported = torch.export.export(model, args=args, dynamic_shapes=dynamic_shapes)
stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
weights, stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
module_str = str(stablehlo.mlir_module())

self.assertRegex(module_str, r"stablehlo.constant.*3")
Expand All @@ -62,7 +62,7 @@ def test_constraints_multiply(self):

with torch.no_grad():
exported = torch.export.export(model, args=args, dynamic_shapes=dynamic_shapes)
stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
weights, stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
module_str = str(stablehlo.mlir_module())

self.assertRegex(module_str, r"stablehlo.constant.*10")
@@ -84,7 +84,7 @@ def test_constraint_indirection(self):

with torch.no_grad():
exported = torch.export.export(model, args=args, dynamic_shapes=dynamic_shapes)
stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
weights, stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
module_str = str(stablehlo.mlir_module())

self.assertRegex(module_str, r"shape_assertion.*s[0-9]+ <= 10")
2 changes: 1 addition & 1 deletion experimental/torch_xla2/test/test_unbounded_dynamism.py
@@ -31,7 +31,7 @@ def get_stablehlo_text(self):
return self.export.mlir_module()

def exported_program_to_stablehlo(exported):
return ExportAdapter(exp2shlo(exported))
return ExportAdapter(exp2shlo(exported)[1])

def wrap_func_as_nn_module(f):
class M(torch.nn.Module):
