Skip to content

Commit 7591682

Browse files
authored
fix: added do_shard_jit_apply/4 for __shard_jit__ (#1650)
1 parent 277c12f commit 7591682

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

nx/lib/nx/defn.ex

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -854,11 +854,11 @@ defmodule Nx.Defn do
854854
end
855855

856856
def shard_jit_apply(fun, mesh, args, opts \\ [])
857-
when is_function(fun) and is_list(args) and is_list(opts) do
858-
{on_conflict, opts} = Keyword.pop(opts, :on_conflict, :raise)
857+
when is_function(fun) and is_list(args) and is_list(opts) do
858+
{on_conflict, opts} = Keyword.pop(opts, :on_conflict, :raise)
859859

860860
cond do
861-
Nx.Defn.current() == nil ->
861+
Nx.Defn.Compiler.current() == nil ->
862862
do_shard_jit_apply(fun, mesh, args, opts)
863863

864864
on_conflict == :raise ->
@@ -870,7 +870,14 @@ defmodule Nx.Defn do
870870
on_conflict == :reuse ->
871871
apply(fun, args)
872872
end
873-
end
873+
end
874+
875+
defp do_shard_jit_apply(fun, mesh, args, opts) do
876+
opts = prepare_options(opts)
877+
{fun, params, _templates, flatten} = Nx.Defn.Compiler.to_lazy_params(fun, args)
878+
[res] = Nx.Defn.Compiler.__shard_jit__(fun, mesh, params, [flatten], opts)
879+
res
880+
end
874881

875882
defp compile_error!(env, description) do
876883
raise CompileError, line: env.line, file: env.file, description: description

nx/lib/nx/defn/compiler.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ defmodule Nx.Defn.Compiler do
293293
end
294294

295295
def __shard_jit__(fun, mesh, params, args_list, opts) do
296-
{module, runtime_fun, opts} = prepare_options(fun, mesh, opts)
296+
{module, runtime_fun, opts} = prepare_options(fun, opts)
297297
module.__shard_jit__(fun, mesh, params, runtime_fun, args_list, opts)
298298
rescue
299299
e in [UndefinedFunctionError] ->

0 commit comments

Comments
 (0)