@@ -302,10 +302,14 @@ defmodule Nx.Defn.Graph do
302302
303303 new_expr = put_in ( expr . data . args , args )
304304
305- # When we split, decide what to include in the stage and create parameter replacement
306305 { stage_expr , result_expr } =
307- case tensor_args do
308- [ ] ->
306+ case { tensor_args , expr . data . op , args } do
307+ { _ , :metadata , [ wrapped_expr , _ ] } when is_tuple ( wrapped_expr ) ->
308+ # We're effectively splitting on a tuple, so we need to create a
309+ # stage output for each element
310+ { wrapped_expr , new_expr }
311+
312+ { [ ] , _ , _ } ->
309313 # No intermediate computations - create a parameter for this split operation
310314 # The current expression will be computed in the next stage
311315 param = Expr . parameter ( new_expr , map_size ( state . args ) )
@@ -320,8 +324,30 @@ defmodule Nx.Defn.Graph do
320324
321325 # Update state with parameter mapping if we created one
322326 state =
323- case tensor_args do
324- [ ] ->
327+ case { tensor_args , expr . data . op , args } do
328+ { _ , :metadata , [ wrapped_expr , _ ] } when is_tuple ( wrapped_expr ) ->
329+ # Register each tuple element as a stage output and create a replacement parameter
330+ { state , _ } =
331+ wrapped_expr
332+ |> Tuple . to_list ( )
333+ |> Enum . reduce ( { state , 0 } , fn % T { } = elem_expr , { state , index } ->
334+ param = Expr . parameter ( elem_expr , index )
335+
336+ state = % {
337+ state
338+ | args:
339+ state . args
340+ |> Map . put ( elem_expr . data . id , { stage_id , index } )
341+ |> Map . put ( param . data . id , { stage_id , index } ) ,
342+ nodes_to_replace: Map . put ( state . nodes_to_replace , elem_expr . data . id , param )
343+ }
344+
345+ { state , index + 1 }
346+ end )
347+
348+ state
349+
350+ { [ ] , _ , _ } ->
325351 # Add parameter mapping and node replacement for the split operation
326352 # Extract the parameter from the tuple
327353 param = elem ( stage_expr , 0 )
@@ -355,9 +381,15 @@ defmodule Nx.Defn.Graph do
355381 { expr , { Map . put ( cache , id , expr ) , state } }
356382 end
357383
358- defp eval_apply ( :elem , % T { data: % Expr { id: id , args: [ tuple , i ] } } , { cache , state } ) do
359- { tuple , cache } = composite_eval ( tuple , state , cache )
360- res = elem ( tuple , i )
384+ defp eval_apply ( :elem , % T { data: % Expr { id: id , args: [ tuple , i ] } } = expr , { cache , state } ) do
385+ { tuple , { cache , state } } = composite_eval ( tuple , state , cache )
386+
387+ res =
388+ case tuple do
389+ t when is_tuple ( t ) -> elem ( t , i )
390+ % T { } -> put_in ( expr . data . args , [ tuple , i ] )
391+ end
392+
361393 { res , { Map . put ( cache , id , res ) , state } }
362394 end
363395
@@ -420,6 +452,42 @@ defmodule Nx.Defn.Graph do
420452 end
421453 end
422454
455+ defp rewrite_subtree (
456+ % T { data: % Expr { id: id , op: :elem , args: [ tuple_expr , index ] } } = expr ,
457+ state ,
458+ acc
459+ ) do
460+ case state . nodes_to_replace do
461+ % { ^ id => res } ->
462+ { res , put_in ( acc . used_args [ id ] , res ) }
463+
464+ _ ->
465+ { tuple_expr , acc } = rewrite_subtree ( tuple_expr , state , acc )
466+
467+ case tuple_expr do
468+ # Literal tuple: turn elem into a parameter for that element
469+ t when is_tuple ( t ) ->
470+ elem_expr = elem ( t , index )
471+ param = Expr . parameter ( elem_expr , index )
472+ { param , put_in ( acc . used_args [ elem_expr . data . id ] , param ) }
473+
474+ # Metadata-wrapped tuple: same as above
475+ % T { data: % Expr { op: :metadata , args: [ wrapped , _ ] } } when is_tuple ( wrapped ) ->
476+ elem_expr = elem ( wrapped , index )
477+ param = Expr . parameter ( elem_expr , index )
478+ { param , put_in ( acc . used_args [ elem_expr . data . id ] , param ) }
479+
480+ # Tuple tensor: create a parameter pointing to this index
481+ % T { type: { :tuple , _ } } ->
482+ param = Expr . parameter ( expr , index )
483+ { param , put_in ( acc . used_args [ param . data . id ] , param ) }
484+
485+ _ ->
486+ { put_in ( expr . data . args , [ tuple_expr , index ] ) , acc }
487+ end
488+ end
489+ end
490+
423491 defp rewrite_subtree ( % T { data: % Expr { id: id , args: args } } = expr , state , acc ) do
424492 case state . nodes_to_replace do
425493 % { ^ id => res } ->
0 commit comments