@@ -854,11 +854,11 @@ defmodule Nx.Defn do
854854 end
855855
856856 def shard_jit_apply ( fun , mesh , args , opts \\ [ ] )
857- when is_function ( fun ) and is_list ( args ) and is_list ( opts ) do
858- { on_conflict , opts } = Keyword . pop ( opts , :on_conflict , :raise )
857+ when is_function ( fun ) and is_list ( args ) and is_list ( opts ) do
858+ { on_conflict , opts } = Keyword . pop ( opts , :on_conflict , :raise )
859859
860860 cond do
861- Nx.Defn . current ( ) == nil ->
861+ Nx.Defn.Compiler . current ( ) == nil ->
862862 do_shard_jit_apply ( fun , mesh , args , opts )
863863
864864 on_conflict == :raise ->
@@ -870,7 +870,14 @@ defmodule Nx.Defn do
870870 on_conflict == :reuse ->
871871 apply ( fun , args )
872872 end
873- end
873+ end
874+
875+ defp do_shard_jit_apply ( fun , mesh , args , opts ) do
876+ opts = prepare_options ( opts )
877+ { fun , params , _templates , flatten } = Nx.Defn.Compiler . to_lazy_params ( fun , args )
878+ [ res ] = Nx.Defn.Compiler . __shard_jit__ ( fun , mesh , params , [ flatten ] , opts )
879+ res
880+ end
874881
875882 defp compile_error! ( env , description ) do
876883 raise CompileError , line: env . line , file: env . file , description: description
0 commit comments