@@ -84,85 +84,93 @@ For using `composite` we need to use the jax-centric export now. (i.e. no torch.
We are working on adding support for torch.export now.

``` python
+ import unittest
import torch
import torch.nn.functional as F
- import torch_xla2 as tx
- import torch_xla2.interop
+ from torch.library import Library, impl, impl_abstract
+ import torch_xla2
import torch_xla2.export
-
- import jax
- import jax.numpy as jnp
-
-
-
- # We will use jax.lax.composite to accomplish this.
- wrap_composite = tx.interop.torch_view(jax.lax.composite)
-
-
- class M(torch.nn.Module):
-
-   def __init__(self):
-     super().__init__()
-     self.q_proj = torch.nn.Linear(128, 128, bias=False)
-     self.k_proj = torch.nn.Linear(128, 128, bias=False)
-     self.v_proj = torch.nn.Linear(128, 128, bias=False)
-
-     self._composite_sdpa = wrap_composite(F.scaled_dot_product_attention, name="test.sdpa")
-
-   def forward(self, x):
-     q = self.q_proj(x)
-     k = self.k_proj(x)
-     v = self.v_proj(x)
-     attn_out = self._composite_sdpa(q, k, v, scale=0.25)
-     return attn_out
-
- weights, jfunc = tx.extract_jax(M())
- stablehlo = jax.export.export(jax.jit(jfunc))(
-     weights, jax.ShapeDtypeStruct((4, 8, 128), jnp.float32.dtype))
- print(stablehlo.mlir_module())
- ```
-
- The main StableHLO graph is shown below:
-
- ``` none
- module @IrToHlo.56 attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
-   func.func @main(%arg0: tensor<10x8x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>, %arg3: tensor<128x128xf32>) -> tensor<10x8x128xf32> {
-     ...
-     %10 = stablehlo.composite "test.sdpa" %3, %6, %9 {composite_attributes = {other_attr = "val", scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl} : (tensor<10x8x128xf32>, tensor<10x8x128xf32>, tensor<10x8x128xf32>) -> tensor<10x8x128xf32>
-     %11 = stablehlo.add %10, %arg0 : tensor<10x8x128xf32>
-     return %11 : tensor<10x8x128xf32>
-   }
-
-   func.func private @test.sdpa.impl(%arg0: tensor<10x8x128xf32>, %arg1: tensor<10x8x128xf32>, %arg2: tensor<10x8x128xf32>) -> tensor<10x8x128xf32> {
-     // Actual implementation of the composite
-     ...
-     return %11 : tensor<10x8x128xf32>
-   }
- ```
-
- The sdpa operation is encapsulated as a stablehlo composite call within
- the main graph. The name and attributes specified in the torch.nn.Module
- are propagated.
-
- ``` none
- %12 = stablehlo.composite "test.sdpa" %3, %7, %11 {composite_attributes = {scale = 2.500000e-01 : f64}, decomposition = @test.sdpa} : (tensor<4x8x128xf32>, tensor<4x8x128xf32>, tensor<4x8x128xf32>) -> tensor<4x8x128xf32> loc(#loc95)
- ```
-
- The reference PyTorch decomposition of the sdpa operation is captured in
- a StableHLO function:
-
- ``` none
- func.func private @test.sdpa.impl(%arg0: tensor<10x8x128xf32>, %arg1: tensor<10x8x128xf32>, %arg2: tensor<10x8x128xf32>) -> tensor<10x8x128xf32> {
-   // Actual implementation of the composite
-   ...
-   return %11 : tensor<10x8x128xf32>
- }
+ from torch_xla2.ops import jaten
+ from torch_xla2.ops import jlibrary
+
+
+ # Create a `mylib` library which has a basic SDPA op.
+ m = Library("mylib", "DEF")
+ m.define("scaled_dot_product_attention(Tensor q, Tensor k, Tensor v) -> Tensor")
+
+ @impl(m, "scaled_dot_product_attention", "CompositeExplicitAutograd")
+ def _mylib_scaled_dot_product_attention(q, k, v):
+   """Basic scaled dot product attention without all the flags/features."""
+   q = q.transpose(1, 2)
+   k = k.transpose(1, 2)
+   v = v.transpose(1, 2)
+   y = F.scaled_dot_product_attention(
+       q,
+       k,
+       v,
+       dropout_p=0,
+       is_causal=False,
+       scale=None,
+   )
+   return y.transpose(1, 2)
+
+ @impl_abstract("mylib::scaled_dot_product_attention")
+ def _mylib_scaled_dot_product_attention_meta(q, k, v):
+   return torch.empty_like(q)
+
+ # Register library op as a composite for export using the `@impl` method
+ # for a torch decomposition.
+ jlibrary.register_torch_composite(
+     "mylib.scaled_dot_product_attention",
+     _mylib_scaled_dot_product_attention,
+     torch.ops.mylib.scaled_dot_product_attention,
+     torch.ops.mylib.scaled_dot_product_attention.default
+ )
+
+ # Also register ATen softmax as a composite for export in the `mylib` library
+ # using the JAX ATen decomposition from `jaten`.
+ jlibrary.register_jax_composite(
+     "mylib.softmax",
+     jaten._aten_softmax,
+     torch.ops.aten._softmax,
+     static_argnums=1  # Required by JAX jit
+ )
+
+ class LibraryTest(unittest.TestCase):
+
+   def setUp(self):
+     torch.manual_seed(0)
+     torch_xla2.default_env().config.use_torch_native_for_cpu_tensor = False
+
+   def test_basic_sdpa_library(self):
+
+     class CustomOpExample(torch.nn.Module):
+       def forward(self, q, k, v):
+         x = torch.ops.mylib.scaled_dot_product_attention(q, k, v)
+         x = x + 1
+         return x
+
+     # Export and check for composite operations
+     model = CustomOpExample()
+     arg = torch.rand(32, 8, 128, 64)
+     args = (arg, arg, arg, )
+
+     exported = torch.export.export(model, args=args)
+     stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
+     module_str = str(stablehlo.mlir_module())
+
+     ## TODO Update this machinery from producing function calls to producing
+     ## stablehlo.composite ops.
+     self.assertIn("call @mylib.scaled_dot_product_attention", module_str)
+     self.assertIn("call @mylib.softmax", module_str)
+
+
+ if __name__ == '__main__':
+   unittest.main()
```

As we can see, to emit a StableHLO composite, we first write a Python function
- representing the region of code that we want to call, (
- in this case `F.scaled_dot_product_attention` is already such function).
- Then we wrap the function with `wrap_composite`.
-
- NOTE: currently a model with `wrap_composite` call will not work with `torch.export`.
- We are actively working to make it work.
+ representing the region of code that we want to call. Then we register it
+ so that PyTorch and jlibrary understand it is a custom region. The emitted
+ StableHLO will then contain `mylib.scaled_dot_product_attention` and `mylib.softmax`
+ as outlined StableHLO functions.
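+
+ Outside the unittest harness, the same flow can be exercised directly. The
+ following is a minimal sketch, assuming the imports, the `jlibrary`
+ registrations, and the `CustomOpExample` module from the example above are
+ already in scope:
+
+ ``` python
+ # Minimal sketch (assumes the registrations and CustomOpExample above).
+ # As in the test's setUp, disable native CPU tensors before exporting.
+ torch_xla2.default_env().config.use_torch_native_for_cpu_tensor = False
+
+ model = CustomOpExample()
+ arg = torch.rand(32, 8, 128, 64)
+ exported = torch.export.export(model, args=(arg, arg, arg))
+ stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
+ module_str = str(stablehlo.mlir_module())
+
+ # The registered composites appear as outlined functions in the module text.
+ print("call @mylib.scaled_dot_product_attention" in module_str)  # True
+ print("call @mylib.softmax" in module_str)                       # True
+ ```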