Skip to content

Commit 5f22d83

Browse files
cigrainger and claude committed
feat: shuffle join routing for large right sides in distributed joins
When the right side of a distributed join exceeds the broadcast threshold (256MB by default), the Coordinator now automatically splits the pipeline into stages and delegates to Shuffle.execute/3:

1. Execute the pre-join ops distributed → left result
2. Shuffle-join the left result with the right side
3. Apply the post-join ops to the shuffle result

Also adds a :broadcast_threshold option to Coordinator.execute for testing, plus four new shuffle routing tests that use broadcast_threshold: 0 to force the shuffle path on small data.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 5ad1576 commit 5f22d83

2 files changed

Lines changed: 260 additions & 73 deletions

File tree

lib/dux/remote/coordinator.ex

Lines changed: 136 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ defmodule Dux.Remote.Coordinator do
1919
The result is a `%Dux{}` struct with the merged data.
2020
"""
2121

22-
alias Dux.Remote.{Merger, Partitioner, PipelineSplitter, Worker}
22+
alias Dux.Remote.{Merger, Partitioner, PipelineSplitter, Shuffle, Worker}
2323
import Dux.SQL.Helpers, only: [qi: 1]
2424

2525
# Broadcast threshold: 256MB serialized Arrow IPC
@@ -43,6 +43,7 @@ defmodule Dux.Remote.Coordinator do
4343
workers = Keyword.get_lazy(opts, :workers, &Worker.list/0)
4444
timeout = Keyword.get(opts, :timeout, :infinity)
4545
strategy = Keyword.get(opts, :strategy, :round_robin)
46+
bcast_threshold = Keyword.get(opts, :broadcast_threshold, @broadcast_threshold)
4647

4748
if workers == [] do
4849
raise ArgumentError, "no workers available for distributed execution"
@@ -53,36 +54,29 @@ defmodule Dux.Remote.Coordinator do
5354
PipelineSplitter.split(pipeline.ops)
5455

5556
# Preprocess joins: broadcast/shuffle right sides that aren't worker-safe
56-
{processed_ops, broadcast_tables} = preprocess_joins(worker_ops, workers, timeout)
57+
case preprocess_joins(worker_ops, workers, timeout, bcast_threshold) do
58+
{:ok, processed_ops, broadcast_tables} ->
59+
# All joins handled inline (broadcast or push-down)
60+
worker_pipeline = %{pipeline | ops: processed_ops}
5761

58-
worker_pipeline = %{pipeline | ops: processed_ops}
59-
60-
try do
61-
# Partition the worker pipeline across workers
62-
assignments = Partitioner.assign(worker_pipeline, workers, strategy: strategy)
63-
64-
# Fan out: each worker executes its partition
65-
results = fan_out(assignments, timeout)
66-
67-
# Collect successful results, handle failures
68-
{successes, failures} = partition_results(results)
69-
70-
if successes == [] do
71-
reasons = Enum.map(failures, fn {:error, reason} -> reason end)
72-
raise ArgumentError, "all workers failed: #{inspect(reasons)}"
73-
end
74-
75-
# Merge partial results on coordinator
76-
merged = Merger.merge_to_dux(successes, worker_pipeline)
77-
78-
# Apply AVG rewrites if any
79-
merged = apply_avg_rewrites(merged, rewrites)
62+
try do
63+
result = execute_fan_out(worker_pipeline, workers, strategy, timeout)
64+
result = apply_avg_rewrites(result, rewrites)
65+
apply_coordinator_ops(result, coord_ops)
66+
after
67+
cleanup_broadcast_tables(workers, broadcast_tables)
68+
end
8069

81-
# Apply coordinator-only ops (slice, pivot, etc.)
82-
apply_coordinator_ops(merged, coord_ops)
83-
after
84-
# Clean up broadcast tables on all workers
85-
cleanup_broadcast_tables(workers, broadcast_tables)
70+
{:shuffle, ops_before, {right_computed, how, on_cols, suffix}, ops_after, broadcast_tables} ->
71+
# Pipeline needs a shuffle stage: execute pre-join → shuffle → post-join
72+
try do
73+
execute_with_shuffle(
74+
pipeline, ops_before, right_computed, how, on_cols, suffix,
75+
ops_after, coord_ops, rewrites, workers, strategy, timeout
76+
)
77+
after
78+
cleanup_broadcast_tables(workers, broadcast_tables)
79+
end
8680
end
8781
end
8882

@@ -103,6 +97,64 @@ defmodule Dux.Remote.Coordinator do
10397
# Internal
10498
# ---------------------------------------------------------------------------
10599

100+
# Runs a single distributed stage end to end: `Partitioner.assign/3` splits the
# pipeline across `workers`, `fan_out/2` executes every assignment, and
# `Merger.merge_to_dux/2` combines the successful partial results.
#
# Raises `ArgumentError` when not a single worker returned a successful
# result. As long as at least one worker succeeded, failed partitions are
# left out of the merge without raising.
defp execute_fan_out(worker_pipeline, workers, strategy, timeout) do
  {ok_results, failed_results} =
    worker_pipeline
    |> Partitioner.assign(workers, strategy: strategy)
    |> fan_out(timeout)
    |> partition_results()

  case ok_results do
    [] ->
      reasons = Enum.map(failed_results, fn {:error, reason} -> reason end)
      raise ArgumentError, "all workers failed: #{inspect(reasons)}"

    _ ->
      Merger.merge_to_dux(ok_results, worker_pipeline)
  end
end
113+
114+
# Multi-stage execution for a join whose right side is too large to broadcast:
#
#   1. Run the ops that precede the join (`ops_before`) as a distributed
#      query. With no such ops, the bare source is computed instead.
#   2. Shuffle-join that result with `right_computed` via `Shuffle.execute/3`.
#   3. Run `ops_after` on the shuffle result with `Dux.compute/1`, then apply
#      the AVG rewrites and the coordinator-only ops.
#
# `on_cols` is a list of `{left_col, right_col}` pairs. `Shuffle.execute/3` is
# called with a single `:on` column, so only a one-pair join can be routed
# here. Any other shape raises `ArgumentError` before stage 1 starts, instead
# of silently joining on the first pair only and returning wrong rows.
#
# NOTE(review): the right-hand column name of the pair and `_suffix` are not
# forwarded to the shuffle. This presumably assumes both sides use the same
# key name and that the shuffle applies its own suffixing; confirm against
# `Dux.Remote.Shuffle`.
defp execute_with_shuffle(
       pipeline, ops_before, right_computed, how, on_cols, _suffix,
       ops_after, coord_ops, rewrites, workers, strategy, timeout
     ) do
  # Validate first so an unsupported join fails before any distributed work.
  left_col = shuffle_join_column!(on_cols)

  # Stage 1: everything that precedes the join
  left_result =
    case ops_before do
      [] -> Dux.compute(%{pipeline | ops: []})
      _ -> execute_fan_out(%{pipeline | ops: ops_before}, workers, strategy, timeout)
    end

  # Stage 2: shuffle join on the single key column.
  # NOTE(review): `String.to_atom/1` creates atoms that are never garbage
  # collected. Acceptable while column names come from developer-written
  # pipelines; revisit if they can originate from untrusted input.
  shuffle_result =
    Shuffle.execute(left_result, right_computed,
      on: String.to_atom(left_col),
      how: how,
      workers: workers,
      timeout: timeout
    )

  # Stage 3: ops that followed the join run on the shuffle result
  joined =
    case ops_after do
      [] -> shuffle_result
      _ -> Dux.compute(%{shuffle_result | ops: ops_after})
    end

  joined
  |> apply_avg_rewrites(rewrites)
  |> apply_coordinator_ops(coord_ops)
end

# Returns the left-hand key of a single-pair join; raises for any other shape.
defp shuffle_join_column!([{left_col, _right_col}]), do: left_col

defp shuffle_join_column!(on_cols) do
  raise ArgumentError,
        "shuffle join supports exactly one join column pair, got: #{inspect(on_cols)}. " <>
          "Join on a single key, or raise :broadcast_threshold so the right side " <>
          "is broadcast instead."
end
157+
106158
defp fan_out(assignments, timeout) do
107159
# Execute in parallel via Task.async_stream
108160
assignments
@@ -136,56 +188,69 @@ defmodule Dux.Remote.Coordinator do
136188
# ---------------------------------------------------------------------------
137189

138190
# Scan worker ops for joins where the right side holds a source that
139-
# workers can't resolve (e.g. a local {:table, ref}). For each such join,
140-
# compute the right side on the coordinator, broadcast it to all workers,
141-
# and replace the join op with a join against the broadcast table.
191+
# workers can't resolve (e.g. a local {:table, ref}). For each such join:
142192
#
143-
# Returns {processed_ops, broadcast_table_names} where broadcast_table_names
144-
# is a list of names to clean up after query execution.
145-
defp preprocess_joins(ops, workers, timeout) do
146-
Enum.map_reduce(ops, [], fn
147-
{:join, %Dux{} = right, how, on_cols, suffix}, broadcast_names ->
148-
if worker_safe_source?(right.source) and Enum.all?(right.ops, &worker_safe_op?/1) do
149-
# Right side is resolvable on workers — pass through unchanged
150-
{{:join, right, how, on_cols, suffix}, broadcast_names}
151-
else
152-
# Right side has a local source — need to broadcast or shuffle
153-
route_join(right, how, on_cols, suffix, workers, timeout, broadcast_names)
154-
end
193+
# - Small right side → broadcast to workers, replace join with broadcast ref
194+
# - Large right side → signal a pipeline split for shuffle join
195+
#
196+
# Returns:
197+
# {:ok, processed_ops, broadcast_names} — all joins handled inline
198+
# {:shuffle, ops_before, join_info, ops_after, broadcast_names} — need a shuffle stage
199+
# Entry point for join preprocessing: starts the walk over `ops` with an empty
# list of processed ops and an empty list of broadcast table names.
defp preprocess_joins(ops, workers, timeout, threshold),
  do: do_preprocess(ops, [], [], workers, timeout, threshold)
155202

156-
op, broadcast_names ->
157-
{op, broadcast_names}
158-
end)
203+
# Terminal clause: every op was consumed without needing a shuffle. The ops
# were accumulated by prepending, so they are reversed back into pipeline
# order before being returned with the broadcast table names.
defp do_preprocess([], acc, bcast_names, _workers, _timeout, _threshold),
  do: {:ok, Enum.reverse(acc), bcast_names}
160207

161-
# Route a join with a non-worker-safe right side to broadcast or shuffle.
162-
defp route_join(right, how, on_cols, suffix, workers, timeout, broadcast_names) do
163-
# Compute the right side on the coordinator
164-
right_computed = Dux.compute(right)
165-
{:table, right_ref} = right_computed.source
166-
right_ipc = Dux.Native.table_to_ipc(right_ref)
167-
168-
if byte_size(right_ipc) <= @broadcast_threshold do
169-
# Small enough to broadcast
170-
broadcast_name = "__bcast_#{:erlang.unique_integer([:positive])}"
171-
broadcast_to_workers(workers, broadcast_name, right_ipc, timeout)
172-
173-
# Replace right side with a query against the broadcast table
174-
broadcast_right = Dux.from_query("SELECT * FROM #{qi(broadcast_name)}")
175-
{{:join, broadcast_right, how, on_cols, suffix}, [broadcast_name | broadcast_names]}
208+
# Join clause. A right side whose source and ops are all worker-safe is kept
# unchanged. Otherwise the right side is computed on the coordinator and
# serialized to IPC, and its byte size picks the route:
#
#   * at or below `threshold` → broadcast it to the workers under a unique
#     `__bcast_*` name and rewrite the join to select from that table
#   * above `threshold` → stop here and return a `:shuffle` split, handing
#     back the ops seen so far, the join details, and the untouched remainder
defp do_preprocess(
       [{:join, %Dux{} = right, how, on_cols, suffix} = join_op | tail],
       acc,
       bcast_names,
       workers,
       timeout,
       threshold
     ) do
  resolvable_on_workers? =
    worker_safe_source?(right.source) and Enum.all?(right.ops, &worker_safe_op?/1)

  if resolvable_on_workers? do
    do_preprocess(tail, [join_op | acc], bcast_names, workers, timeout, threshold)
  else
    right_computed = Dux.compute(right)
    {:table, right_ref} = right_computed.source
    ipc = Dux.Native.table_to_ipc(right_ref)

    if byte_size(ipc) > threshold do
      # Too big to ship to every worker: split the pipeline at this join.
      {:shuffle, Enum.reverse(acc), {right_computed, how, on_cols, suffix}, tail, bcast_names}
    else
      table_name = "__bcast_#{:erlang.unique_integer([:positive])}"
      broadcast_to_workers(workers, table_name, ipc, timeout)

      rewritten_join =
        {:join, Dux.from_query("SELECT * FROM #{qi(table_name)}"), how, on_cols, suffix}

      do_preprocess(
        tail,
        [rewritten_join | acc],
        [table_name | bcast_names],
        workers,
        timeout,
        threshold
      )
    end
  end
end
188248

249+
# Any op not matched by the join clause needs no routing decision: push it
# onto the accumulator and continue with the remaining ops.
defp do_preprocess([other | tail], acc, bcast_names, workers, timeout, threshold),
  do: do_preprocess(tail, [other | acc], bcast_names, workers, timeout, threshold)
253+
189254
defp broadcast_to_workers(workers, name, ipc_binary, _timeout) do
190255
tasks =
191256
Enum.map(workers, fn worker ->

test/dux/distributed_join_routing_test.exs

Lines changed: 124 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -329,8 +329,9 @@ defmodule Dux.DistributedJoinRoutingTest do
329329
|> Dux.distribute(workers)
330330

331331
# Self-join: the right side is the same pipeline but computed locally
332-
right = Dux.from_list([%{id: 1, parent: 2}, %{id: 2, parent: 3}, %{id: 3, parent: nil}])
333-
|> Dux.compute()
332+
right =
333+
Dux.from_list([%{id: 1, parent: 2}, %{id: 2, parent: 3}, %{id: 3, parent: nil}])
334+
|> Dux.compute()
334335

335336
result =
336337
df
@@ -343,4 +344,125 @@ defmodule Dux.DistributedJoinRoutingTest do
343344
assert ids == [1, 2]
344345
end
345346
end
347+
348+
# ---------------------------------------------------------------------------
349+
# Shuffle join: distributed left + large local right
350+
# ---------------------------------------------------------------------------
351+
352+
describe "shuffle join (forced via broadcast_threshold: 0)" do
353+
test "large right side triggers shuffle join" do
  workers = start_workers(2)

  left = Dux.from_list(for i <- 1..20, do: %{id: i, val: i * 10})

  right =
    for(i <- 1..20, do: %{id: i, tag: "item_#{i}"})
    |> Dux.from_list()
    |> Dux.compute()

  # A zero broadcast threshold forces the shuffle route for this small table
  rows =
    left
    |> Dux.join(right, on: :id)
    |> Dux.Remote.Coordinator.execute(workers: workers, broadcast_threshold: 0)
    |> Dux.sort_by(:id)
    |> Dux.to_rows()

  # Each of the 20 ids has a partner on the right side
  sorted_ids = rows |> Enum.map(& &1["id"]) |> Enum.sort()
  assert sorted_ids == Enum.to_list(1..20)
  assert Enum.all?(rows, fn row -> String.starts_with?(row["tag"], "item_") end)
end
378+
379+
test "shuffle join matches local join" do
  workers = start_workers(2)

  left = Dux.from_list(for k <- 1..10, do: %{key: k, left_val: k * 10})

  right =
    [%{key: 2, right_val: 200}, %{key: 5, right_val: 500}, %{key: 8, right_val: 800}]
    |> Dux.from_list()
    |> Dux.compute()

  joined = Dux.join(left, right, on: :key)

  # Reference result: the same join evaluated locally
  expected =
    joined
    |> Dux.sort_by(:key)
    |> Dux.to_rows()

  # Same join, forced through the shuffle path
  actual =
    joined
    |> Dux.Remote.Coordinator.execute(workers: workers, broadcast_threshold: 0)
    |> Dux.sort_by(:key)
    |> Dux.to_rows()

  # Both paths must yield the same keys and the same right-side values
  assert Enum.map(expected, & &1["key"]) == Enum.map(actual, & &1["key"])
  assert Enum.map(expected, & &1["right_val"]) == Enum.map(actual, & &1["right_val"])
end
416+
417+
test "left join via shuffle" do
  workers = start_workers(2)

  left = Dux.from_list(Enum.map(1..3, &%{id: &1}))

  right =
    [%{id: 1, name: "Alice"}]
    |> Dux.from_list()
    |> Dux.compute()

  rows =
    left
    |> Dux.join(right, on: :id, how: :left)
    |> Dux.Remote.Coordinator.execute(workers: workers, broadcast_threshold: 0)
    |> Dux.sort_by(:id)
    |> Dux.to_rows()

  # A left join keeps every left row, matched or not
  sorted_ids = rows |> Enum.map(& &1["id"]) |> Enum.sort()
  assert sorted_ids == [1, 2, 3]

  # id 1 is the only row with a right-side match
  with_name = Enum.reject(rows, &is_nil(&1["name"]))
  assert length(with_name) == 1
  assert hd(with_name)["name"] == "Alice"
end
442+
443+
test "shuffle with ops before the join" do
  workers = start_workers(2)

  left = Dux.from_list(for i <- 1..10, do: %{id: i, val: i * 10})

  right =
    [%{id: 3, tag: "three"}, %{id: 7, tag: "seven"}]
    |> Dux.from_list()
    |> Dux.compute()

  # The filter precedes the join, so it runs as stage 1 of the shuffle plan
  rows =
    left
    |> Dux.filter_with("val > 20")
    |> Dux.join(right, on: :id)
    |> Dux.Remote.Coordinator.execute(workers: workers, broadcast_threshold: 0)
    |> Dux.sort_by(:id)
    |> Dux.to_rows()

  # ids 3 and 7 both pass the filter (val 30 and 70) and both have a match.
  # The source is replicated, so each of the 2 workers emits the matching
  # rows; compare distinct ids only.
  distinct_ids = rows |> Enum.map(& &1["id"]) |> Enum.uniq() |> Enum.sort()
  assert distinct_ids == [3, 7]
  assert Enum.all?(rows, fn row -> Map.has_key?(row, "tag") end)
end
467+
end
346468
end

0 commit comments

Comments
 (0)