@@ -42,22 +42,25 @@ pub struct OptimizeInputsTask {
4242 expr_id : ExprId ,
4343 continue_from : Option < ContinueTask > ,
4444 pruning : bool ,
45+ upper_bound : Option < f64 > ,
4546}
4647
4748impl OptimizeInputsTask {
48- pub fn new ( expr_id : ExprId , pruning : bool ) -> Self {
49+ pub fn new ( expr_id : ExprId , pruning : bool , upper_bound : Option < f64 > ) -> Self {
4950 Self {
5051 expr_id,
5152 continue_from : None ,
5253 pruning,
54+ upper_bound,
5355 }
5456 }
5557
56- fn continue_from ( & self , cont : ContinueTask , pruning : bool ) -> Self {
58+ fn continue_from ( & self , cont : ContinueTask , pruning : bool , upper_bound : Option < f64 > ) -> Self {
5759 Self {
5860 expr_id : self . expr_id ,
5961 continue_from : Some ( cont) ,
6062 pruning,
63+ upper_bound,
6164 }
6265 }
6366
@@ -153,6 +156,19 @@ impl<T: NodeType, M: Memo<T>> Task<T, M> for OptimizeInputsTask {
153156
154157 trace ! ( event = "task_begin" , task = "optimize_inputs" , expr_id = %self . expr_id, continue_from = %ContinueTaskDisplay ( & self . continue_from) , total_children = %children_group_ids. len( ) ) ;
155158
159+ let upper_bound = if self . pruning {
160+ if let Some ( upper_bound) = self . upper_bound {
161+ Some ( upper_bound)
162+ } else if let Some ( winner) = optimizer. get_group_info ( group_id) . winner . as_full_winner ( )
163+ {
164+ Some ( winner. total_weighted_cost )
165+ } else {
166+ None
167+ }
168+ } else {
169+ None
170+ } ;
171+
156172 if let Some ( ContinueTask {
157173 next_group_idx,
158174 return_from_optimize_group,
@@ -219,9 +235,9 @@ impl<T: NodeType, M: Memo<T>> Task<T, M> for OptimizeInputsTask {
219235 winner_weighted_cost = %trace_fmt( & group_info. winner) ,
220236 current_processing = %next_group_idx,
221237 total_child_groups = %children_group_ids. len( ) ) ;
222- if let Some ( winner ) = group_info . winner . as_full_winner ( ) {
238+ if let Some ( upper_bound ) = upper_bound {
223239 let cost_so_far = cost. weighted_cost ( & total_cost) ;
224- if winner . total_weighted_cost <= cost_so_far {
240+ if upper_bound <= cost_so_far {
225241 trace ! ( event = "task_finish" , task = "optimize_inputs" , expr_id = %self . expr_id, result = "pruned" ) ;
226242 return Ok ( vec ! [ ] ) ;
227243 }
@@ -232,7 +248,7 @@ impl<T: NodeType, M: Memo<T>> Task<T, M> for OptimizeInputsTask {
232248 let child_group_id = children_group_ids[ next_group_idx] ;
233249 let group_idx = next_group_idx;
234250 let child_group_info = optimizer. get_group_info ( child_group_id) ;
235- if ! child_group_info. winner . has_full_winner ( ) {
251+ let Some ( child_winner ) = child_group_info. winner . as_full_winner ( ) else {
236252 if !return_from_optimize_group {
237253 trace ! ( event = "task_yield" , task = "optimize_inputs" , expr_id = %self . expr_id, group_idx = %group_idx, yield_to = "optimize_group" , optimize_group_id = %child_group_id) ;
238254 return Ok ( vec ! [
@@ -242,22 +258,25 @@ impl<T: NodeType, M: Memo<T>> Task<T, M> for OptimizeInputsTask {
242258 return_from_optimize_group: true ,
243259 } ,
244260 self . pruning,
261+ upper_bound,
245262 ) ) as Box <dyn Task <T , M >>,
246- Box :: new( OptimizeGroupTask :: new( child_group_id) ) as Box <dyn Task <T , M >>,
263+ Box :: new( OptimizeGroupTask :: new( child_group_id, upper_bound) )
264+ as Box <dyn Task <T , M >>,
247265 ] ) ;
248266 } else {
249267 self . update_winner_impossible ( optimizer) ;
250268 trace ! ( event = "task_finish" , task = "optimize_inputs" , expr_id = %self . expr_id, result = "impossible" ) ;
251269 return Ok ( vec ! [ ] ) ;
252270 }
253- }
271+ } ;
254272 trace ! ( event = "task_yield" , task = "optimize_inputs" , expr_id = %self . expr_id, group_idx = %group_idx, yield_to = "next_optimize_input" ) ;
255273 Ok ( vec ! [ Box :: new( self . continue_from(
256274 ContinueTask {
257275 next_group_idx: group_idx + 1 ,
258276 return_from_optimize_group: false ,
259277 } ,
260278 self . pruning,
279+ upper_bound. map( |bound| bound - child_winner. total_weighted_cost) ,
261280 ) ) as Box <dyn Task <T , M >>] )
262281 } else {
263282 self . update_winner ( input_statistics_ref, operation_cost, total_cost, optimizer) ;
@@ -272,6 +291,7 @@ impl<T: NodeType, M: Memo<T>> Task<T, M> for OptimizeInputsTask {
272291 return_from_optimize_group: false ,
273292 } ,
274293 self . pruning,
294+ upper_bound,
275295 ) ) as Box <dyn Task <T , M >>] )
276296 }
277297 }
0 commit comments