@@ -19,7 +19,7 @@ defmodule Dux.Remote.Coordinator do
1919 The result is a `%Dux{}` struct with the merged data.
2020 """
2121
22- alias Dux.Remote . { Merger , Partitioner , PipelineSplitter , Worker }
22+ alias Dux.Remote . { Merger , Partitioner , PipelineSplitter , Shuffle , Worker }
2323 import Dux.SQL.Helpers , only: [ qi: 1 ]
2424
2525 # Broadcast threshold: 256MB serialized Arrow IPC
@@ -43,6 +43,7 @@ defmodule Dux.Remote.Coordinator do
4343 workers = Keyword . get_lazy ( opts , :workers , & Worker . list / 0 )
4444 timeout = Keyword . get ( opts , :timeout , :infinity )
4545 strategy = Keyword . get ( opts , :strategy , :round_robin )
46+ bcast_threshold = Keyword . get ( opts , :broadcast_threshold , @ broadcast_threshold )
4647
4748 if workers == [ ] do
4849 raise ArgumentError , "no workers available for distributed execution"
@@ -53,36 +54,29 @@ defmodule Dux.Remote.Coordinator do
5354 PipelineSplitter . split ( pipeline . ops )
5455
5556 # Preprocess joins: broadcast/shuffle right sides that aren't worker-safe
56- { processed_ops , broadcast_tables } = preprocess_joins ( worker_ops , workers , timeout )
57+ case preprocess_joins ( worker_ops , workers , timeout , bcast_threshold ) do
58+ { :ok , processed_ops , broadcast_tables } ->
59+ # All joins handled inline (broadcast or push-down)
60+ worker_pipeline = % { pipeline | ops: processed_ops }
5761
58- worker_pipeline = % { pipeline | ops: processed_ops }
59-
60- try do
61- # Partition the worker pipeline across workers
62- assignments = Partitioner . assign ( worker_pipeline , workers , strategy: strategy )
63-
64- # Fan out: each worker executes its partition
65- results = fan_out ( assignments , timeout )
66-
67- # Collect successful results, handle failures
68- { successes , failures } = partition_results ( results )
69-
70- if successes == [ ] do
71- reasons = Enum . map ( failures , fn { :error , reason } -> reason end )
72- raise ArgumentError , "all workers failed: #{ inspect ( reasons ) } "
73- end
74-
75- # Merge partial results on coordinator
76- merged = Merger . merge_to_dux ( successes , worker_pipeline )
77-
78- # Apply AVG rewrites if any
79- merged = apply_avg_rewrites ( merged , rewrites )
62+ try do
63+ result = execute_fan_out ( worker_pipeline , workers , strategy , timeout )
64+ result = apply_avg_rewrites ( result , rewrites )
65+ apply_coordinator_ops ( result , coord_ops )
66+ after
67+ cleanup_broadcast_tables ( workers , broadcast_tables )
68+ end
8069
81- # Apply coordinator-only ops (slice, pivot, etc.)
82- apply_coordinator_ops ( merged , coord_ops )
83- after
84- # Clean up broadcast tables on all workers
85- cleanup_broadcast_tables ( workers , broadcast_tables )
70+ { :shuffle , ops_before , { right_computed , how , on_cols , suffix } , ops_after , broadcast_tables } ->
71+ # Pipeline needs a shuffle stage: execute pre-join → shuffle → post-join
72+ try do
73+ execute_with_shuffle (
74+ pipeline , ops_before , right_computed , how , on_cols , suffix ,
75+ ops_after , coord_ops , rewrites , workers , strategy , timeout
76+ )
77+ after
78+ cleanup_broadcast_tables ( workers , broadcast_tables )
79+ end
8680 end
8781 end
8882
@@ -103,6 +97,64 @@ defmodule Dux.Remote.Coordinator do
10397 # Internal
10498 # ---------------------------------------------------------------------------
10599
# Standard distributed execution: assign partitions to workers, run them in
# parallel, then merge whatever succeeded into a single result.
#
# Raises `ArgumentError` when no worker produced a successful result.
defp execute_fan_out(worker_pipeline, workers, strategy, timeout) do
  {ok_results, failed} =
    worker_pipeline
    |> Partitioner.assign(workers, strategy: strategy)
    |> fan_out(timeout)
    |> partition_results()

  case ok_results do
    [] ->
      reasons = Enum.map(failed, fn {:error, reason} -> reason end)
      raise ArgumentError, "all workers failed: #{inspect(reasons)}"

    _ ->
      # The merge only sees the successful partials.
      Merger.merge_to_dux(ok_results, worker_pipeline)
  end
end
113+
# Multi-stage execution for shuffle joins:
#   1. Run the ops that precede the join (distributed when there are any).
#   2. Shuffle-join that result with the already-computed right side.
#   3. Run the ops that follow the join, then AVG rewrites and coordinator ops.
#
# `on_cols` is a list of `{left_col, right_col}` pairs. Only one column can be
# forwarded through the `on:` option used below, so anything other than exactly
# one pair raises `ArgumentError` instead of silently joining on the first pair
# alone (which would return rows that do not satisfy the full join condition).
#
# NOTE(review): `suffix` is accepted but not forwarded to `Shuffle.execute/3`,
# and the right-hand column name of the pair is unused — confirm the shuffle
# handles differently named keys and colliding column names as intended.
defp execute_with_shuffle(
       pipeline,
       ops_before,
       right_computed,
       how,
       on_cols,
       _suffix,
       ops_after,
       coord_ops,
       rewrites,
       workers,
       strategy,
       timeout
     ) do
  # Validate the join key up front so an unsupported join fails before any
  # distributed work is started.
  left_col =
    case on_cols do
      [{col, _right_col}] ->
        col

      other ->
        raise ArgumentError,
              "distributed shuffle join supports exactly one join column pair, " <>
                "got: #{inspect(other)}"
    end

  # Stage 1: execute pre-join ops distributed
  left_result =
    if ops_before == [] do
      # No ops before the join — just compute the source
      Dux.compute(%{pipeline | ops: []})
    else
      execute_fan_out(%{pipeline | ops: ops_before}, workers, strategy, timeout)
    end

  # Stage 2: shuffle join on the validated key column
  # NOTE(review): `String.to_atom/1` creates atoms dynamically; presumably
  # column names come from the developer-written pipeline — verify they are
  # never derived from untrusted input.
  shuffle_result =
    Shuffle.execute(left_result, right_computed,
      on: String.to_atom(left_col),
      how: how,
      workers: workers,
      timeout: timeout
    )

  # Stage 3: ops that followed the join in the worker list run on the
  # shuffle result
  result =
    if ops_after == [] do
      shuffle_result
    else
      Dux.compute(%{shuffle_result | ops: ops_after})
    end

  result
  |> apply_avg_rewrites(rewrites)
  |> apply_coordinator_ops(coord_ops)
end
157+
106158 defp fan_out ( assignments , timeout ) do
107159 # Execute in parallel via Task.async_stream
108160 assignments
@@ -136,56 +188,69 @@ defmodule Dux.Remote.Coordinator do
136188 # ---------------------------------------------------------------------------
137189
138190 # Scan worker ops for joins where the right side holds a source that
139- # workers can't resolve (e.g. a local {:table, ref}). For each such join,
140- # compute the right side on the coordinator, broadcast it to all workers,
141- # and replace the join op with a join against the broadcast table.
191+ # workers can't resolve (e.g. a local {:table, ref}). For each such join:
142192 #
143- # Returns {processed_ops, broadcast_table_names} where broadcast_table_names
144- # is a list of names to clean up after query execution.
145- defp preprocess_joins ( ops , workers , timeout ) do
146- Enum . map_reduce ( ops , [ ] , fn
147- { :join , % Dux { } = right , how , on_cols , suffix } , broadcast_names ->
148- if worker_safe_source? ( right . source ) and Enum . all? ( right . ops , & worker_safe_op? / 1 ) do
149- # Right side is resolvable on workers — pass through unchanged
150- { { :join , right , how , on_cols , suffix } , broadcast_names }
151- else
152- # Right side has a local source — need to broadcast or shuffle
153- route_join ( right , how , on_cols , suffix , workers , timeout , broadcast_names )
154- end
193+ # - Small right side → broadcast to workers, replace join with broadcast ref
194+ # - Large right side → signal a pipeline split for shuffle join
195+ #
196+ # Returns:
197+ # {:ok, processed_ops, broadcast_names} — all joins handled inline
198+ # {:shuffle, ops_before, join_info, ops_after, broadcast_names} — need a shuffle stage
#
# Entry point: starts the recursive walk with an empty list of processed ops
# and an empty list of broadcast table names.
defp preprocess_joins(ops, workers, timeout, threshold),
  do: do_preprocess(ops, [], [], workers, timeout, threshold)
155202
156- op , broadcast_names ->
157- { op , broadcast_names }
158- end )
# Terminal clause: the op list is exhausted without hitting an oversized join,
# so no shuffle stage is needed. The accumulator was built by prepending, hence
# the reverse to restore pipeline order.
defp do_preprocess([], acc, bcast_names, _workers, _timeout, _threshold),
  do: {:ok, Enum.reverse(acc), bcast_names}
160207
161- # Route a join with a non-worker-safe right side to broadcast or shuffle.
162- defp route_join ( right , how , on_cols , suffix , workers , timeout , broadcast_names ) do
163- # Compute the right side on the coordinator
164- right_computed = Dux . compute ( right )
165- { :table , right_ref } = right_computed . source
166- right_ipc = Dux.Native . table_to_ipc ( right_ref )
167-
168- if byte_size ( right_ipc ) <= @ broadcast_threshold do
169- # Small enough to broadcast
170- broadcast_name = "__bcast_#{ :erlang . unique_integer ( [ :positive ] ) } "
171- broadcast_to_workers ( workers , broadcast_name , right_ipc , timeout )
172-
173- # Replace right side with a query against the broadcast table
174- broadcast_right = Dux . from_query ( "SELECT * FROM #{ qi ( broadcast_name ) } " )
175- { { :join , broadcast_right , how , on_cols , suffix } , [ broadcast_name | broadcast_names ] }
# Join clause. A right side that workers can resolve themselves is kept as-is.
# Otherwise the right side is computed here, serialized to IPC, and its byte
# size decides the route: at or below `threshold` it is broadcast under a
# unique `__bcast_*` name and the join is rewritten to read that table; above
# `threshold` the walk stops and a `{:shuffle, ...}` split is returned.
defp do_preprocess(
       [{:join, %Dux{} = right, how, on_cols, suffix} = join_op | remaining],
       acc,
       bcast_names,
       workers,
       timeout,
       threshold
     ) do
  resolvable_on_workers? =
    worker_safe_source?(right.source) and Enum.all?(right.ops, &worker_safe_op?/1)

  if resolvable_on_workers? do
    do_preprocess(remaining, [join_op | acc], bcast_names, workers, timeout, threshold)
  else
    right_computed = Dux.compute(right)
    {:table, right_ref} = right_computed.source
    right_ipc = Dux.Native.table_to_ipc(right_ref)

    case byte_size(right_ipc) <= threshold do
      true ->
        # Small enough: ship the serialized table to every worker first, then
        # point the join at the broadcast copy.
        name = "__bcast_#{:erlang.unique_integer([:positive])}"
        broadcast_to_workers(workers, name, right_ipc, timeout)
        rewritten = {:join, Dux.from_query("SELECT * FROM #{qi(name)}"), how, on_cols, suffix}

        do_preprocess(
          remaining,
          [rewritten | acc],
          [name | bcast_names],
          workers,
          timeout,
          threshold
        )

      false ->
        # Too large to broadcast: split the pipeline at this join. Ops seen so
        # far go before the shuffle, the untouched remainder goes after it.
        {:shuffle, Enum.reverse(acc), {right_computed, how, on_cols, suffix}, remaining,
         bcast_names}
    end
  end
end
188248
# Fallback clause: any op not matched as a `%Dux{}` join is carried forward
# unchanged.
defp do_preprocess([head | tail], acc, bcast_names, workers, timeout, threshold),
  do: do_preprocess(tail, [head | acc], bcast_names, workers, timeout, threshold)
253+
189254 defp broadcast_to_workers ( workers , name , ipc_binary , _timeout ) do
190255 tasks =
191256 Enum . map ( workers , fn worker ->
0 commit comments