Skip to content

Commit 6dd8b49

Browse files
authored
feat(Nx.Defn.Graph): allow splitting based on an accumulator (#1624)
1 parent 7637677 commit 6dd8b49

File tree

2 files changed

+232
-29
lines changed

2 files changed

+232
-29
lines changed

nx/lib/nx/defn/graph.ex

Lines changed: 130 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@ defmodule Nx.Defn.Graph do
3434
end
3535

3636
@doc """
37-
Splits the received Nx.Defn.Expr into stages given the rules.
37+
Splits the received Nx.Defn.Expr into stages based on each tensor.
3838
3939
`expr_split_fn` is a function that receives an `Nx.Tensor` containing an `Nx.Defn.Expr`
4040
and returns `true` when a split must happen, and `false` otherwise.
@@ -65,7 +65,30 @@ defmodule Nx.Defn.Graph do
6565
>
6666
"""
6767
def split(expr, expr_split_fn) when is_function(expr_split_fn, 1) do
68-
{chain, _, _} = __split__(expr, expr_split_fn)
68+
{chain, _, _} = __split__(expr, nil, fn node, acc -> {expr_split_fn.(node), acc} end)
69+
chain
70+
end
71+
72+
@doc """
73+
Splits the received Nx.Defn.Expr into stages based on each tensor and the accumulator.
74+
75+
`expr_split_fn` is a function that receives an `Nx.Tensor` and the accumulator,
76+
returning `{true, new_acc}` when a split must happen, and `{false, new_acc}`
77+
otherwise.
78+
79+
The decision to split is made based on the expression and the accumulator.
80+
This allows for more complex decisions to be made, such as splitting every 3 operations as in the example below.
81+
82+
# Count operations and split every 3 operations
83+
split_fn = fn _tensor, count ->
84+
new_count = count + 1
85+
{count > 0 and rem(new_count, 3) == 0, new_count}
86+
end
87+
88+
stages = Nx.Defn.Graph.split(expr, 0, split_fn)
89+
"""
90+
def split(expr, initial_acc, expr_split_fn) when is_function(expr_split_fn, 2) do
91+
{chain, _, _} = __split__(expr, initial_acc, expr_split_fn)
6992
chain
7093
end
7194

@@ -106,13 +129,14 @@ defmodule Nx.Defn.Graph do
106129
end
107130

108131
@doc false
109-
def __split__(expr, expr_split_fn) do
132+
def __split__(expr, initial_acc, expr_split_fn) do
110133
# state.expression_chain is a reverse accumulation of the stages and
111134
# snapshots of the state at each one so that we can properly remap parameters for each stage.
112135
state = %{
113136
expression_chain: [],
114137
nodes_to_replace: %{},
115138
expr_split_fn: expr_split_fn,
139+
split_acc: initial_acc,
116140
# args is a map of id -> {stage_id, output_container_position}
117141
args: %{}
118142
}
@@ -142,24 +166,38 @@ defmodule Nx.Defn.Graph do
142166
%{state | nodes_to_replace: nodes_to_replace}
143167
)
144168

145-
arg_remapping =
169+
{arg_remapping, _, _} =
146170
used_args
147171
|> Enum.sort_by(fn {_id, %T{data: %Expr{op: :parameter, args: [idx]}}} -> idx end)
148-
|> Enum.with_index(fn
149-
{id, expr}, idx ->
150-
{id, put_in(expr.data.args, [idx])}
172+
|> Enum.reduce({%{}, %{}, 0}, fn
173+
{id, expr}, {acc, sources, idx} ->
174+
# For replacement parameters, use the original parameter ID to find the source
175+
id = if Map.has_key?(state.args, expr.data.id), do: expr.data.id, else: id
176+
source = Map.fetch!(state.args, id)
177+
178+
if visited_expr = Map.get(sources, source) do
179+
{Map.put(acc, id, visited_expr), sources, idx}
180+
else
181+
expr = put_in(expr.data.args, [idx])
182+
{Map.put(acc, id, expr), Map.put(sources, source, expr), idx + 1}
183+
end
151184
end)
152-
|> Map.new()
153185

154186
{expr, _} =
155187
composite_rewrite_subtree(expr, %{state | nodes_to_replace: arg_remapping})
156188

189+
# Create arguments list from final remapping, preserving the deduplicated order
157190
arguments =
158191
arg_remapping
159-
|> Enum.map(fn {_id, arg_expr} ->
160-
id = arg_expr.data.id
192+
|> Enum.map(fn {original_id, arg_expr} ->
161193
[idx] = arg_expr.data.args
162-
source = Map.fetch!(state.args, id)
194+
# Use the same logic as above to find the correct source
195+
source_id =
196+
if Map.has_key?(state.args, arg_expr.data.id),
197+
do: arg_expr.data.id,
198+
else: original_id
199+
200+
source = Map.fetch!(state.args, source_id)
163201
{idx, %{source: source}}
164202
end)
165203
|> Enum.sort_by(fn {idx, _} -> idx end)
@@ -193,10 +231,29 @@ defmodule Nx.Defn.Graph do
193231
{res, {cache, state}}
194232

195233
_ ->
196-
if state.expr_split_fn.(ans) do
197-
split_expr(ans, {cache, state})
198-
else
199-
eval_apply(op, ans, {cache, state})
234+
case op do
235+
:parameter ->
236+
eval_apply(:parameter, ans, {cache, state})
237+
238+
:elem ->
239+
eval_apply(:elem, ans, {cache, state})
240+
241+
_ ->
242+
# First process the arguments with the current accumulator
243+
{args, {cache, state}} = Nx.Defn.Tree.apply_args(ans, {cache, state}, &eval/2)
244+
245+
# Then check if we should split based on this node
246+
{should_split?, new_acc} = state.expr_split_fn.(ans, state.split_acc)
247+
state = %{state | split_acc: new_acc}
248+
249+
if should_split? do
250+
# Use the already processed args for splitting
251+
split_expr_with_args(ans, args, {cache, state})
252+
else
253+
# Apply the operation with the processed args
254+
ans = put_in(ans.data.args, args)
255+
{ans, {Map.put(cache, ans.data.id, ans), state}}
256+
end
200257
end
201258
end
202259
end
@@ -205,8 +262,7 @@ defmodule Nx.Defn.Graph do
205262
{other, {cache, state}}
206263
end
207264

208-
defp split_expr(expr, {cache, state}) do
209-
{args, {cache, state}} = Nx.Defn.Tree.apply_args(expr, {cache, state}, &eval/2)
265+
defp split_expr_with_args(expr, args, {cache, state}) do
210266
# We need to save this so that each previous stage
211267
# isn't affected by following ones
212268
nodes_to_replace = state.nodes_to_replace
@@ -215,6 +271,20 @@ defmodule Nx.Defn.Graph do
215271

216272
{args, {tensor_args, _out_position, state}} =
217273
Enum.map_reduce(args, {[], 0, state}, fn
274+
%T{data: %Expr{op: :parameter}} = arg, {tensor_args, out_position, state} ->
275+
# Parameters are not computed values, so don't add them to tensor_args
276+
# Just update the state if needed
277+
state =
278+
case Map.has_key?(state.args, arg.data.id) do
279+
false ->
280+
%{state | args: Map.put(state.args, arg.data.id, {stage_id, out_position})}
281+
282+
true ->
283+
state
284+
end
285+
286+
{arg, {tensor_args, out_position, state}}
287+
218288
%T{} = expr, {tensor_args, out_position, state} ->
219289
arg = Expr.parameter(expr, map_size(state.args))
220290

@@ -232,18 +302,52 @@ defmodule Nx.Defn.Graph do
232302

233303
new_expr = put_in(expr.data.args, args)
234304

305+
# When we split, decide what to include in the stage and create parameter replacement
306+
{stage_expr, result_expr} =
307+
case tensor_args do
308+
[] ->
309+
# No intermediate computations - create a parameter for this split operation
310+
# The current expression will be computed in the next stage
311+
param = Expr.parameter(new_expr, map_size(state.args))
312+
{{param}, param}
313+
314+
_ ->
315+
# There are intermediate computations - only include those in the current stage
316+
# The current expression will be computed in the next stage using these outputs
317+
stage_expr = List.to_tuple(Enum.reverse(tensor_args))
318+
{stage_expr, new_expr}
319+
end
320+
321+
# Update state with parameter mapping if we created one
322+
state =
323+
case tensor_args do
324+
[] ->
325+
# Add parameter mapping and node replacement for the split operation
326+
# Extract the parameter from the tuple
327+
param = elem(stage_expr, 0)
328+
329+
%{
330+
state
331+
| args: Map.put(state.args, param.data.id, {stage_id, 0}),
332+
nodes_to_replace: Map.put(state.nodes_to_replace, new_expr.data.id, param)
333+
}
334+
335+
_ ->
336+
state
337+
end
338+
235339
state =
236340
update_in(
237341
state.expression_chain,
238342
&[
239-
{stage_id, List.to_tuple(Enum.reverse(tensor_args)), nodes_to_replace}
343+
{stage_id, stage_expr, nodes_to_replace}
240344
| &1
241345
]
242346
)
243347

244-
cache = Map.put(cache, new_expr.data.id, new_expr)
348+
cache = Map.put(cache, result_expr.data.id, result_expr)
245349

246-
{new_expr, {cache, state}}
350+
{result_expr, {cache, state}}
247351
end
248352

249353
defp eval_apply(:parameter, %T{data: %Expr{id: id, args: [idx]}} = expr, {cache, state}) do
@@ -285,9 +389,14 @@ defmodule Nx.Defn.Graph do
285389
defp rewrite_subtree(%T{data: %Expr{id: id, op: :parameter}} = expr, state, acc) do
286390
case state.nodes_to_replace do
287391
%{^id => res} ->
288-
{res, put_in(acc.used_args[id], res)}
392+
# This parameter is being replaced by a stage output - collect the replacement
393+
# We collect both the original id and the replacement id to ensure proper tracking
394+
acc = put_in(acc.used_args[id], res)
395+
acc = put_in(acc.used_args[res.data.id], res)
396+
{res, acc}
289397

290398
_ ->
399+
# This is an original parameter - collect it
291400
{expr, put_in(acc.used_args[id], expr)}
292401
end
293402
end

nx/test/nx/defn/graph_test.exs

Lines changed: 102 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@ defmodule Nx.Defn.GraphTest do
99

1010
doctest Nx.Defn.Graph
1111

12-
describe "traverse/1" do
12+
describe "split/2" do
1313
test "simple expression with 1 split and no common nodes" do
1414
expr =
1515
Nx.Defn.debug_expr(fn arg0, arg1 ->
@@ -21,11 +21,11 @@ defmodule Nx.Defn.GraphTest do
2121
end).(Nx.tensor([1, 2]), Nx.tensor([3, 4]))
2222

2323
split_fn = fn
24-
%T{data: %Expr{op: :dot}} -> true
25-
_ -> false
24+
%T{data: %Expr{op: :dot}}, acc -> {true, acc}
25+
_, acc -> {false, acc}
2626
end
2727

28-
{chain, cache, state} = Graph.__split__(expr, split_fn)
28+
{chain, cache, state} = Graph.__split__(expr, nil, split_fn)
2929

3030
assert [
3131
%Stage{
@@ -134,12 +134,12 @@ defmodule Nx.Defn.GraphTest do
134134
end).(Nx.tensor([[1, 2]]), Nx.tensor([[3], [4]]), Nx.tensor([5, 6]))
135135

136136
split_fn = fn
137-
%T{data: %Expr{op: :dot}} -> true
138-
%T{data: %Expr{op: :sum}} -> true
139-
_ -> false
137+
%T{data: %Expr{op: :dot}}, acc -> {true, acc}
138+
%T{data: %Expr{op: :sum}}, acc -> {true, acc}
139+
_, acc -> {false, acc}
140140
end
141141

142-
{chain, cache, state} = Graph.__split__(expr, split_fn)
142+
{chain, cache, state} = Graph.__split__(expr, nil, split_fn)
143143

144144
assert [
145145
%Stage{
@@ -356,6 +356,100 @@ defmodule Nx.Defn.GraphTest do
356356
end
357357
end
358358

359+
describe "split/3" do
360+
test "splits with accumulator" do
361+
expr =
362+
Nx.Defn.debug_expr(fn x0, x1, x2, x3, x4 ->
363+
x10 = Nx.add(x0, Nx.add(x1, x2))
364+
x20 = Nx.add(x10, Nx.add(x10, x3))
365+
x30 = Nx.add(x20, Nx.add(x20, x4))
366+
{x10, x20, x30}
367+
end).(1, 2, 3, 4, 5)
368+
369+
split_fn = fn
370+
_node, acc ->
371+
{acc > 0 and rem(acc, 2) == 0, acc + 1}
372+
end
373+
374+
chain = Graph.split(expr, 0, split_fn)
375+
376+
assert [stage_0, stage_1, stage_2] = chain
377+
378+
assert stage_0.arguments == [%{source: {nil, 0}}, %{source: {nil, 1}}, %{source: {nil, 2}}]
379+
380+
assert {
381+
%T{
382+
data: %Expr{
383+
op: :add,
384+
args: [
385+
%T{data: %Expr{op: :parameter, args: [0]}},
386+
%T{
387+
data: %Expr{
388+
args: [
389+
%T{data: %Expr{op: :parameter, args: [1]}},
390+
%T{data: %Expr{op: :parameter, args: [2]}}
391+
],
392+
op: :add
393+
}
394+
}
395+
]
396+
}
397+
}
398+
} = stage_0.expr
399+
400+
assert stage_1.arguments == [
401+
%{source: {nil, 3}},
402+
%{source: {stage_0.id, 0}}
403+
]
404+
405+
assert {
406+
%T{
407+
data: %Expr{
408+
op: :add,
409+
args: [
410+
%T{data: %Expr{op: :parameter, args: [1]}},
411+
%T{
412+
data: %Expr{
413+
args: [
414+
%T{data: %Expr{op: :parameter, args: [1]}},
415+
%T{data: %Expr{op: :parameter, args: [0]}}
416+
],
417+
op: :add
418+
}
419+
}
420+
]
421+
}
422+
}
423+
} = stage_1.expr
424+
425+
assert stage_2.arguments == [
426+
%{source: {nil, 4}},
427+
%{source: {stage_0.id, 0}},
428+
%{source: {stage_1.id, 0}}
429+
]
430+
431+
assert {%T{data: %Expr{op: :parameter, args: [1]}},
432+
%T{data: %Expr{op: :parameter, args: [2]}},
433+
%T{
434+
data: %Expr{
435+
op: :add,
436+
args: [
437+
%T{data: %Expr{op: :parameter, args: [2]}},
438+
%T{
439+
data: %Expr{
440+
args: [
441+
%T{data: %Expr{op: :parameter, args: [2]}},
442+
%T{data: %Expr{op: :parameter, args: [0]}}
443+
],
444+
op: :add
445+
}
446+
}
447+
]
448+
}
449+
}} = stage_2.expr
450+
end
451+
end
452+
359453
describe "run/2" do
360454
test "executes the stages chain and returns the correct result" do
361455
function = fn arg0, arg1 ->

0 commit comments

Comments
 (0)