Skip to content

Commit 25300b7

Browse files
committed
fix grad
1 parent fc9c28c commit 25300b7

File tree

2 files changed

+4
-13
lines changed

2 files changed

+4
-13
lines changed

exla/lib/exla/mlir/value.ex

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -840,19 +840,6 @@ defmodule EXLA.MLIR.Value do
840840
|> one!()
841841
end
842842

843-
def custom_call(
844-
%Function{} = func,
845-
call_target_name,
846-
operands,
847-
out_typespecs,
848-
extra_attrs \\ []
849-
) do
850-
result_types = typespecs_to_mlir_types(out_typespecs)
851-
attributes = [call_target_name: attr_string(call_target_name), api_version: attr_i32(4), has_side_effect: attr_boolean(true)]
852-
attributes = attributes ++ extra_attrs
853-
op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes)
854-
end
855-
856843
def get_typespec(value) do
857844
EXLA.NIF.mlir_get_typespec(value.ref)
858845
end

nx/lib/nx/defn/grad.ex

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,10 @@ defmodule Nx.Defn.Grad do
122122
acc
123123
end
124124

125+
defp parents_args(:elixir_call, _expr, _id, acc, _parent_vectorized_names) do
126+
acc
127+
end
128+
125129
defp parents_args(
126130
:optional,
127131
%{data: %{args: [call, _expr, callback]}} = t,

0 commit comments

Comments
 (0)