diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex
index fbf392862f..3a676b5942 100644
--- a/exla/lib/exla/defn.ex
+++ b/exla/lib/exla/defn.ex
@@ -546,21 +546,15 @@ defmodule EXLA.Defn do
 
   defp cached_recur_operator(
          :lu,
-         %T{data: %Expr{args: [{p_expr, l_expr, u_expr}, tensor, _opts]}},
-         state,
+         %T{
+           data: %Expr{args: [{p_expr, l_expr, u_expr}, %{type: {type_kind, _}} = tensor, _opts]}
+         },
+         %{client: %{platform: :host}} = state,
          cache
-       ) do
-    %{type: {p_type_kind, _}} = p_expr
-    %{type: {out_type_kind, _}} = l_expr
-
-    if state.client.platform != :host do
-      raise ArgumentError, "XLA does not currently support the LU operation on non-host devices"
-    end
-
-    if p_type_kind == :c or out_type_kind == :c do
-      raise ArgumentError, "XLA does not currently support the LU operation for complex inputs"
-    end
-
+       )
+       when type_kind != :c do
+    # We only want to accelerate the LU operation for real inputs on the host device.
+    # Otherwise, we use the default implementation in Nx.
     {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!()
 
     tensor =