Skip to content

Commit e97afd0

Browse files
authored
fix: support splitting on tuples directly (#1626)
1 parent cd95fd7 commit e97afd0

File tree

2 files changed

+111
-8
lines changed

2 files changed

+111
-8
lines changed

nx/lib/nx/defn/graph.ex

Lines changed: 76 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -302,10 +302,14 @@ defmodule Nx.Defn.Graph do
302302

303303
new_expr = put_in(expr.data.args, args)
304304

305-
# When we split, decide what to include in the stage and create parameter replacement
306305
{stage_expr, result_expr} =
307-
case tensor_args do
308-
[] ->
306+
case {tensor_args, expr.data.op, args} do
307+
{_, :metadata, [wrapped_expr, _]} when is_tuple(wrapped_expr) ->
308+
# We're effectively splitting on a tuple, so we need to create a
309+
# stage output for each element
310+
{wrapped_expr, new_expr}
311+
312+
{[], _, _} ->
309313
# No intermediate computations - create a parameter for this split operation
310314
# The current expression will be computed in the next stage
311315
param = Expr.parameter(new_expr, map_size(state.args))
@@ -320,8 +324,30 @@ defmodule Nx.Defn.Graph do
320324

321325
# Update state with parameter mapping if we created one
322326
state =
323-
case tensor_args do
324-
[] ->
327+
case {tensor_args, expr.data.op, args} do
328+
{_, :metadata, [wrapped_expr, _]} when is_tuple(wrapped_expr) ->
329+
# Register each tuple element as a stage output and create a replacement parameter
330+
{state, _} =
331+
wrapped_expr
332+
|> Tuple.to_list()
333+
|> Enum.reduce({state, 0}, fn %T{} = elem_expr, {state, index} ->
334+
param = Expr.parameter(elem_expr, index)
335+
336+
state = %{
337+
state
338+
| args:
339+
state.args
340+
|> Map.put(elem_expr.data.id, {stage_id, index})
341+
|> Map.put(param.data.id, {stage_id, index}),
342+
nodes_to_replace: Map.put(state.nodes_to_replace, elem_expr.data.id, param)
343+
}
344+
345+
{state, index + 1}
346+
end)
347+
348+
state
349+
350+
{[], _, _} ->
325351
# Add parameter mapping and node replacement for the split operation
326352
# Extract the parameter from the tuple
327353
param = elem(stage_expr, 0)
@@ -355,9 +381,15 @@ defmodule Nx.Defn.Graph do
355381
{expr, {Map.put(cache, id, expr), state}}
356382
end
357383

358-
defp eval_apply(:elem, %T{data: %Expr{id: id, args: [tuple, i]}}, {cache, state}) do
359-
{tuple, cache} = composite_eval(tuple, state, cache)
360-
res = elem(tuple, i)
384+
defp eval_apply(:elem, %T{data: %Expr{id: id, args: [tuple, i]}} = expr, {cache, state}) do
385+
{tuple, {cache, state}} = composite_eval(tuple, state, cache)
386+
387+
res =
388+
case tuple do
389+
t when is_tuple(t) -> elem(t, i)
390+
%T{} -> put_in(expr.data.args, [tuple, i])
391+
end
392+
361393
{res, {Map.put(cache, id, res), state}}
362394
end
363395

@@ -420,6 +452,42 @@ defmodule Nx.Defn.Graph do
420452
end
421453
end
422454

455+
defp rewrite_subtree(
456+
%T{data: %Expr{id: id, op: :elem, args: [tuple_expr, index]}} = expr,
457+
state,
458+
acc
459+
) do
460+
case state.nodes_to_replace do
461+
%{^id => res} ->
462+
{res, put_in(acc.used_args[id], res)}
463+
464+
_ ->
465+
{tuple_expr, acc} = rewrite_subtree(tuple_expr, state, acc)
466+
467+
case tuple_expr do
468+
# Literal tuple: turn elem into a parameter for that element
469+
t when is_tuple(t) ->
470+
elem_expr = elem(t, index)
471+
param = Expr.parameter(elem_expr, index)
472+
{param, put_in(acc.used_args[elem_expr.data.id], param)}
473+
474+
# Metadata-wrapped tuple: same as above
475+
%T{data: %Expr{op: :metadata, args: [wrapped, _]}} when is_tuple(wrapped) ->
476+
elem_expr = elem(wrapped, index)
477+
param = Expr.parameter(elem_expr, index)
478+
{param, put_in(acc.used_args[elem_expr.data.id], param)}
479+
480+
# Tuple tensor: create a parameter pointing to this index
481+
%T{type: {:tuple, _}} ->
482+
param = Expr.parameter(expr, index)
483+
{param, put_in(acc.used_args[param.data.id], param)}
484+
485+
_ ->
486+
{put_in(expr.data.args, [tuple_expr, index]), acc}
487+
end
488+
end
489+
end
490+
423491
defp rewrite_subtree(%T{data: %Expr{id: id, args: args}} = expr, state, acc) do
424492
case state.nodes_to_replace do
425493
%{^id => res} ->

nx/test/nx/defn/graph_test.exs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,41 @@ defmodule Nx.Defn.GraphTest do
354354
assert %T{data: %Expr{op: :sum, args: [a, [axes: [1], keep_axes: false]]}} = left
355355
assert %T{data: %Expr{op: :parameter, args: [1]}} = a
356356
end
357+
358+
test "supports splitting on tuples with metadata" do
359+
expr =
360+
Nx.Defn.debug_expr(fn x ->
361+
y = Nx.add(x, 1)
362+
z = Nx.add(x, 2)
363+
w = {Nx.add(y, 3), Nx.add(z, 4)}
364+
{a, b} = Nx.Defn.Expr.metadata(w, %{split: true})
365+
Nx.add(a, b)
366+
end).(Nx.tensor([1, 2, 3]))
367+
368+
split_fn = fn
369+
%T{data: %Expr{op: :metadata, args: [_expr, %{split: true}]}} -> true
370+
_ -> false
371+
end
372+
373+
assert [%Stage{} = stage_0, %Stage{} = stage_1] = Graph.split(expr, split_fn)
374+
375+
assert [%{source: {nil, 0}}] = stage_0.arguments
376+
assert {add_y, add_z} = stage_0.expr
377+
378+
assert %T{data: %Expr{op: :add, args: [%T{data: %Expr{op: :constant, args: [4]}}, y]}} =
379+
add_y
380+
381+
assert %T{data: %Expr{op: :parameter, args: [0]}} = y
382+
383+
assert %T{data: %Expr{op: :add, args: [%T{data: %Expr{op: :constant, args: [6]}}, ^y]}} =
384+
add_z
385+
386+
assert stage_1.arguments == [%{source: {stage_0.id, 0}}, %{source: {stage_0.id, 1}}]
387+
assert %T{data: %Expr{op: :add, args: [add_y, add_z]}} = stage_1.expr
388+
389+
assert %T{data: %Expr{op: :parameter, args: [0]}} = add_y
390+
assert %T{data: %Expr{op: :parameter, args: [1]}} = add_z
391+
end
357392
end
358393

359394
describe "split/3" do

0 commit comments

Comments
 (0)