Skip to content

Commit e795961

Browse files
committed
add tp too
1 parent bad6b82 commit e795961

1 file changed

Lines changed: 20 additions & 1 deletion

File tree

  • openequivariance/openequivariance/jax/jvp

openequivariance/openequivariance/jax/jvp/tp_prim.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,16 @@ def tp_fwd_jvp_abstract_eval(X, Y, W, dX, dY, dW, *, L3_dim, kernel, hash):
132132

133133
tp_fwd_jvp_p.def_impl(tp_fwd_jvp_impl)
134134
tp_fwd_jvp_p.def_abstract_eval(tp_fwd_jvp_abstract_eval)
135+
mlir.register_lowering(
136+
tp_fwd_jvp_p,
137+
mlir.lower_fun(tp_fwd_jvp_impl, multiple_results=False),
138+
platform="cuda",
139+
)
140+
mlir.register_lowering(
141+
tp_fwd_jvp_p,
142+
mlir.lower_fun(tp_fwd_jvp_impl, multiple_results=False),
143+
platform="rocm",
144+
)
135145

136146

137147
# ==============================================================================
@@ -225,7 +235,16 @@ def tp_bwd_jvp_abstract_eval(X, Y, W, dZ, tX, tY, tW, tdZ, *, kernel, hash):
225235

226236
tp_bwd_jvp_p.def_impl(tp_bwd_jvp_impl)
227237
tp_bwd_jvp_p.def_abstract_eval(tp_bwd_jvp_abstract_eval)
228-
238+
mlir.register_lowering(
239+
tp_bwd_jvp_p,
240+
mlir.lower_fun(tp_bwd_jvp_impl, multiple_results=True),
241+
platform="cuda",
242+
)
243+
mlir.register_lowering(
244+
tp_bwd_jvp_p,
245+
mlir.lower_fun(tp_bwd_jvp_impl, multiple_results=True),
246+
platform="rocm",
247+
)
229248

230249
# ==============================================================================
231250
# 9. Transpose Rule for Backward JVP

0 commit comments

Comments
 (0)