Skip to content

Commit af944e4

Browse files
authored
feat(exla): take advantage of the new LU impl (#1590)
1 parent 522eb27 commit af944e4

File tree

1 file changed

+8
-14
lines changed

1 file changed

+8
-14
lines changed

exla/lib/exla/defn.ex

+8-14
Original file line numberDiff line numberDiff line change
@@ -546,21 +546,15 @@ defmodule EXLA.Defn do
546546

547547
defp cached_recur_operator(
548548
:lu,
549-
%T{data: %Expr{args: [{p_expr, l_expr, u_expr}, tensor, _opts]}},
550-
state,
549+
%T{
550+
data: %Expr{args: [{p_expr, l_expr, u_expr}, %{type: {type_kind, _}} = tensor, _opts]}
551+
},
552+
%{client: %{platform: :host}} = state,
551553
cache
552-
) do
553-
%{type: {p_type_kind, _}} = p_expr
554-
%{type: {out_type_kind, _}} = l_expr
555-
556-
if state.client.platform != :host do
557-
raise ArgumentError, "XLA does not currently support the LU operation on non-host devices"
558-
end
559-
560-
if p_type_kind == :c or out_type_kind == :c do
561-
raise ArgumentError, "XLA does not currently support the LU operation for complex inputs"
562-
end
563-
554+
)
555+
when type_kind != :c do
556+
# We only want to accelerate the LU operation for real inputs on the host device.
557+
# Otherwise, we use the default implementation in Nx.
564558
{tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!()
565559

566560
tensor =

0 commit comments

Comments
 (0)