@@ -34,7 +34,7 @@ defmodule Nx.Defn.Graph do
3434 end
3535
3636 @ doc """
37- Splits the received Nx.Defn.Expr into stages given the rules .
37+ Splits the received Nx.Defn.Expr into stages based on each tensor .
3838
3939 `expr_split_fn` is a function that receives an `Nx.Tensor` containing an `Nx.Defn.Expr`
4040 and returns `true` when a split must happen, and `false` otherwise.
@@ -65,7 +65,30 @@ defmodule Nx.Defn.Graph do
6565 >
6666 """
6767 def split ( expr , expr_split_fn ) when is_function ( expr_split_fn , 1 ) do
68- { chain , _ , _ } = __split__ ( expr , expr_split_fn )
68+ { chain , _ , _ } = __split__ ( expr , nil , fn node , acc -> { expr_split_fn . ( node ) , acc } end )
69+ chain
70+ end
71+
72+ @ doc """
73+ Splits the received Nx.Defn.Expr into stages based on each tensor and the accumulator.
74+
75+ `expr_split_fn` is a function that receives an `Nx.Tensor` and the accumulator,
76+ returning `{true, new_acc}` when a split must happen, and `{false, new_acc}`
77+ otherwise.
78+
79+ The decision to split is made based on the expression and the accumulator.
80+ This allows for more complex decisions to be made, such as splitting every 3 operations as in the example below.
81+
82+ # Count operations and split every 3 operations
83+ split_fn = fn _tensor, count ->
84+ new_count = count + 1
85+ {count > 0 and rem(new_count, 3) == 0, new_count}
86+ end
87+
88+ stages = Nx.Defn.Graph.split(expr, 0, split_fn)
89+ """
90+ def split ( expr , initial_acc , expr_split_fn ) when is_function ( expr_split_fn , 2 ) do
91+ { chain , _ , _ } = __split__ ( expr , initial_acc , expr_split_fn )
6992 chain
7093 end
7194
@@ -106,13 +129,14 @@ defmodule Nx.Defn.Graph do
106129 end
107130
108131 @ doc false
109- def __split__ ( expr , expr_split_fn ) do
132+ def __split__ ( expr , initial_acc , expr_split_fn ) do
110133 # state.expression_chain is a reverse accumulation of the stages and
111134 # snapshots of the state at each one so that we can properly remap parameters for each stage.
112135 state = % {
113136 expression_chain: [ ] ,
114137 nodes_to_replace: % { } ,
115138 expr_split_fn: expr_split_fn ,
139+ split_acc: initial_acc ,
116140 # args is a map of id -> {stage_id, output_container_position}
117141 args: % { }
118142 }
@@ -142,24 +166,38 @@ defmodule Nx.Defn.Graph do
142166 % { state | nodes_to_replace: nodes_to_replace }
143167 )
144168
145- arg_remapping =
169+ { arg_remapping , _ , _ } =
146170 used_args
147171 |> Enum . sort_by ( fn { _id , % T { data: % Expr { op: :parameter , args: [ idx ] } } } -> idx end )
148- |> Enum . with_index ( fn
149- { id , expr } , idx ->
150- { id , put_in ( expr . data . args , [ idx ] ) }
172+ |> Enum . reduce ( { % { } , % { } , 0 } , fn
173+ { id , expr } , { acc , sources , idx } ->
174+ # For replacement parameters, use the original parameter ID to find the source
175+ id = if Map . has_key? ( state . args , expr . data . id ) , do: expr . data . id , else: id
176+ source = Map . fetch! ( state . args , id )
177+
178+ if visited_expr = Map . get ( sources , source ) do
179+ { Map . put ( acc , id , visited_expr ) , sources , idx }
180+ else
181+ expr = put_in ( expr . data . args , [ idx ] )
182+ { Map . put ( acc , id , expr ) , Map . put ( sources , source , expr ) , idx + 1 }
183+ end
151184 end )
152- |> Map . new ( )
153185
154186 { expr , _ } =
155187 composite_rewrite_subtree ( expr , % { state | nodes_to_replace: arg_remapping } )
156188
189+ # Create arguments list from final remapping, preserving the deduplicated order
157190 arguments =
158191 arg_remapping
159- |> Enum . map ( fn { _id , arg_expr } ->
160- id = arg_expr . data . id
192+ |> Enum . map ( fn { original_id , arg_expr } ->
161193 [ idx ] = arg_expr . data . args
162- source = Map . fetch! ( state . args , id )
194+ # Use the same logic as above to find the correct source
195+ source_id =
196+ if Map . has_key? ( state . args , arg_expr . data . id ) ,
197+ do: arg_expr . data . id ,
198+ else: original_id
199+
200+ source = Map . fetch! ( state . args , source_id )
163201 { idx , % { source: source } }
164202 end )
165203 |> Enum . sort_by ( fn { idx , _ } -> idx end )
@@ -193,10 +231,29 @@ defmodule Nx.Defn.Graph do
193231 { res , { cache , state } }
194232
195233 _ ->
196- if state . expr_split_fn . ( ans ) do
197- split_expr ( ans , { cache , state } )
198- else
199- eval_apply ( op , ans , { cache , state } )
234+ case op do
235+ :parameter ->
236+ eval_apply ( :parameter , ans , { cache , state } )
237+
238+ :elem ->
239+ eval_apply ( :elem , ans , { cache , state } )
240+
241+ _ ->
242+ # First process the arguments with the current accumulator
243+ { args , { cache , state } } = Nx.Defn.Tree . apply_args ( ans , { cache , state } , & eval / 2 )
244+
245+ # Then check if we should split based on this node
246+ { should_split? , new_acc } = state . expr_split_fn . ( ans , state . split_acc )
247+ state = % { state | split_acc: new_acc }
248+
249+ if should_split? do
250+ # Use the already processed args for splitting
251+ split_expr_with_args ( ans , args , { cache , state } )
252+ else
253+ # Apply the operation with the processed args
254+ ans = put_in ( ans . data . args , args )
255+ { ans , { Map . put ( cache , ans . data . id , ans ) , state } }
256+ end
200257 end
201258 end
202259 end
@@ -205,8 +262,7 @@ defmodule Nx.Defn.Graph do
205262 { other , { cache , state } }
206263 end
207264
208- defp split_expr ( expr , { cache , state } ) do
209- { args , { cache , state } } = Nx.Defn.Tree . apply_args ( expr , { cache , state } , & eval / 2 )
265+ defp split_expr_with_args ( expr , args , { cache , state } ) do
210266 # We need to save this so that each previous stage
211267 # isn't affected by following ones
212268 nodes_to_replace = state . nodes_to_replace
@@ -215,6 +271,20 @@ defmodule Nx.Defn.Graph do
215271
216272 { args , { tensor_args , _out_position , state } } =
217273 Enum . map_reduce ( args , { [ ] , 0 , state } , fn
274+ % T { data: % Expr { op: :parameter } } = arg , { tensor_args , out_position , state } ->
275+ # Parameters are not computed values, so don't add them to tensor_args
276+ # Just update the state if needed
277+ state =
278+ case Map . has_key? ( state . args , arg . data . id ) do
279+ false ->
280+ % { state | args: Map . put ( state . args , arg . data . id , { stage_id , out_position } ) }
281+
282+ true ->
283+ state
284+ end
285+
286+ { arg , { tensor_args , out_position , state } }
287+
218288 % T { } = expr , { tensor_args , out_position , state } ->
219289 arg = Expr . parameter ( expr , map_size ( state . args ) )
220290
@@ -232,18 +302,52 @@ defmodule Nx.Defn.Graph do
232302
233303 new_expr = put_in ( expr . data . args , args )
234304
305+ # When we split, decide what to include in the stage and create parameter replacement
306+ { stage_expr , result_expr } =
307+ case tensor_args do
308+ [ ] ->
309+ # No intermediate computations - create a parameter for this split operation
310+ # The current expression will be computed in the next stage
311+ param = Expr . parameter ( new_expr , map_size ( state . args ) )
312+ { { param } , param }
313+
314+ _ ->
315+ # There are intermediate computations - only include those in the current stage
316+ # The current expression will be computed in the next stage using these outputs
317+ stage_expr = List . to_tuple ( Enum . reverse ( tensor_args ) )
318+ { stage_expr , new_expr }
319+ end
320+
321+ # Update state with parameter mapping if we created one
322+ state =
323+ case tensor_args do
324+ [ ] ->
325+ # Add parameter mapping and node replacement for the split operation
326+ # Extract the parameter from the tuple
327+ param = elem ( stage_expr , 0 )
328+
329+ % {
330+ state
331+ | args: Map . put ( state . args , param . data . id , { stage_id , 0 } ) ,
332+ nodes_to_replace: Map . put ( state . nodes_to_replace , new_expr . data . id , param )
333+ }
334+
335+ _ ->
336+ state
337+ end
338+
235339 state =
236340 update_in (
237341 state . expression_chain ,
238342 & [
239- { stage_id , List . to_tuple ( Enum . reverse ( tensor_args ) ) , nodes_to_replace }
343+ { stage_id , stage_expr , nodes_to_replace }
240344 | & 1
241345 ]
242346 )
243347
244- cache = Map . put ( cache , new_expr . data . id , new_expr )
348+ cache = Map . put ( cache , result_expr . data . id , result_expr )
245349
246- { new_expr , { cache , state } }
350+ { result_expr , { cache , state } }
247351 end
248352
249353 defp eval_apply ( :parameter , % T { data: % Expr { id: id , args: [ idx ] } } = expr , { cache , state } ) do
@@ -285,9 +389,14 @@ defmodule Nx.Defn.Graph do
285389 defp rewrite_subtree ( % T { data: % Expr { id: id , op: :parameter } } = expr , state , acc ) do
286390 case state . nodes_to_replace do
287391 % { ^ id => res } ->
288- { res , put_in ( acc . used_args [ id ] , res ) }
392+ # This parameter is being replaced by a stage output - collect the replacement
393+ # We collect both the original id and the replacement id to ensure proper tracking
394+ acc = put_in ( acc . used_args [ id ] , res )
395+ acc = put_in ( acc . used_args [ res . data . id ] , res )
396+ { res , acc }
289397
290398 _ ->
399+ # This is an original parameter - collect it
291400 { expr , put_in ( acc . used_args [ id ] , expr ) }
292401 end
293402 end
0 commit comments