@@ -12,7 +12,13 @@ impl<M: Memo> Optimizer<M> {
1212 /// 1. Computes newly discovered expressions by subtracting existing ones.
1313 /// 2. For each fork task in the group, launches continuation tasks for each new expression.
1414 /// 3. If the task is the principal, launches transform tasks for each new expression and rule.
15- /// 4. Updates the task's dispatched expressions with the full input set.
15+ /// 4. For all related optimize goal tasks, launch implement tasks for each new expression.
16+ /// *NOTE*: This happens before merging goals. While this might be slightly inefficient, as
17+ /// we might implement twice for the same "soon-to-be" merged goals. However it keeps the code
18+ /// cleaner and easier to understand. Also, the performance impact is negligible, as once we
19+ /// merge the goals, we effectively delete the implement tasks and its associated jobs will
20+ /// *not* be launched.
21+ /// 5. Updates the task's dispatched expressions with the full input set.
1622 ///
1723 /// # Arguments
1824 /// * `task_id` - The ID of the group exploration task to update.
@@ -26,45 +32,67 @@ impl<M: Memo> Optimizer<M> {
2632 ) -> Result < ( ) , M :: MemoError > {
2733 let new_exprs = self . compute_new_expressions ( task_id, all_logical_exprs) ;
2834
29- if !new_exprs. is_empty ( ) {
30- let ( group_id, fork_tasks) = {
31- let task = self . get_explore_group_task ( task_id) . unwrap ( ) ;
32- ( task. group_id , task. fork_logical_out . clone ( ) )
33- } ;
35+ if new_exprs. is_empty ( ) {
36+ return Ok ( ( ) ) ;
37+ }
3438
35- for & fork_task_id in & fork_tasks {
36- let continuation = self
37- . get_fork_logical_task ( fork_task_id)
38- . unwrap ( )
39- . continuation
40- . clone ( ) ;
39+ let ( group_id, fork_tasks, optimize_goal_tasks) = {
40+ let task = self . get_explore_group_task ( task_id) . unwrap ( ) ;
41+ (
42+ task. group_id ,
43+ task. fork_logical_out . clone ( ) ,
44+ task. optimize_goal_out . clone ( ) ,
45+ )
46+ } ;
4147
42- let continuation_tasks = self . create_logical_cont_tasks (
43- & new_exprs,
44- group_id,
45- fork_task_id,
46- & continuation,
47- ) ;
48+ // For each fork task, create continuation tasks for each new expression.
49+ fork_tasks. iter ( ) . for_each ( |& fork_task_id| {
50+ let continuation = self
51+ . get_fork_logical_task ( fork_task_id)
52+ . unwrap ( )
53+ . continuation
54+ . clone ( ) ;
4855
49- self . get_fork_logical_task_mut ( fork_task_id)
50- . unwrap ( )
51- . continue_with_logical_in
52- . extend ( continuation_tasks) ;
53- }
56+ let continuation_tasks =
57+ self . create_logical_cont_tasks ( & new_exprs, group_id, fork_task_id, & continuation) ;
5458
55- if principal {
56- let transform_tasks = self . create_transform_tasks ( & new_exprs, group_id, task_id) ;
57- self . get_explore_group_task_mut ( task_id)
58- . unwrap ( )
59- . transform_expr_in
60- . extend ( transform_tasks) ;
61- }
59+ self . get_fork_logical_task_mut ( fork_task_id)
60+ . unwrap ( )
61+ . continue_with_logical_in
62+ . extend ( continuation_tasks) ;
63+ } ) ;
64+
65+ // For each optimize goal task, create implement tasks for each new expression.
66+ optimize_goal_tasks. iter ( ) . for_each ( |& optimize_goal_id| {
67+ let goal_id = self
68+ . get_optimize_goal_task ( optimize_goal_id)
69+ . unwrap ( )
70+ . goal_id ;
71+
72+ let implement_tasks =
73+ self . create_implement_tasks ( & new_exprs, goal_id, optimize_goal_id) ;
74+
75+ self . get_optimize_goal_task_mut ( optimize_goal_id)
76+ . unwrap ( )
77+ . implement_expression_in
78+ . extend ( implement_tasks) ;
79+ } ) ;
80+
81+ // For the principal task, create transform tasks for each new expression.
82+ // We could always do it, but this is a straightforward optimization.
83+ if principal {
84+ let transform_tasks = self . create_transform_tasks ( & new_exprs, group_id, task_id) ;
6285
6386 self . get_explore_group_task_mut ( task_id)
6487 . unwrap ( )
65- . dispatched_exprs = all_logical_exprs. clone ( ) ;
88+ . transform_expr_in
89+ . extend ( transform_tasks) ;
6690 }
6791
92+ self . get_explore_group_task_mut ( task_id)
93+ . unwrap ( )
94+ . dispatched_exprs = all_logical_exprs. clone ( ) ;
95+
6896 Ok ( ( ) )
6997 }
7098
@@ -135,10 +163,7 @@ impl<M: Memo> Optimizer<M> {
135163 ///
136164 /// # Arguments
137165 /// * `task_id` - The ID of the group exploration task to deduplicate.
138- pub ( super ) async fn dedup_tasks (
139- & mut self ,
140- task_id : TaskId ,
141- ) -> Result < ( ) , M :: MemoError > {
166+ pub ( super ) async fn dedup_tasks ( & mut self , task_id : TaskId ) -> Result < ( ) , M :: MemoError > {
142167 let task = self . get_explore_group_task_mut ( task_id) . unwrap ( ) ;
143168 let old_exprs = std:: mem:: take ( & mut task. dispatched_exprs ) ;
144169 let transform_ids: Vec < _ > = task. transform_expr_in . iter ( ) . copied ( ) . collect ( ) ;
0 commit comments