11#include " SlidingWindow.h"
22
33#include " Bounds.h"
4+ #include " CSE.h"
45#include " CompilerLogger.h"
56#include " Debug.h"
67#include " ExprUsesVar.h"
@@ -86,7 +87,7 @@ class ExpandExpr : public IRMutator {
8687// Perform all the substitutions in a scope
8788Expr expand_expr (const Expr &e, const Scope<Expr> &scope) {
8889 ExpandExpr ee (scope);
89- Expr result = ee (e);
90+ Expr result = common_subexpression_elimination ( ee (e) );
9091 debug (4 ) << " Expanded " << e << " into " << result << " \n " ;
9192 return result;
9293}
@@ -223,6 +224,7 @@ class SlidingWindowOnFunctionAndLoop : public IRMutator {
223224 Expr loop_min;
224225 set<int > &slid_dimensions;
225226 Scope<Expr> scope;
227+ Scope<Interval> &bounds_scope;
226228
227229 // For loops strictly between the loop being slid over and the current
228230 // node (not including the loop being slid over itself).
@@ -282,8 +284,8 @@ class SlidingWindowOnFunctionAndLoop : public IRMutator {
282284 internal_assert (min_val && max_val);
283285 Expr min_req = *min_val;
284286 Expr max_req = *max_val;
285- min_req = expand_expr (min_req, scope);
286- max_req = expand_expr (max_req, scope);
287+ min_req = simplify ( expand_expr (min_req, scope), bounds_scope );
288+ max_req = simplify ( expand_expr (max_req, scope), bounds_scope );
287289
288290 debug (3 ) << func_args[i] << " :" << min_req << " , " << max_req << " \n " ;
289291 if (expr_depends_on_var (min_req, loop_var) ||
@@ -594,7 +596,10 @@ class SlidingWindowOnFunctionAndLoop : public IRMutator {
594596 }
595597
596598 Stmt visit (const LetStmt *op) override {
597- ScopedBinding<Expr> bind (scope, op->name , simplify (expand_expr (op->value , scope)));
599+ ScopedBinding<Interval> bind_bounds (bounds_scope, op->name ,
600+ bounds_of_expr_in_scope (op->value , bounds_scope));
601+ ScopedBinding<Expr> bind (scope, op->name , simplify (expand_expr (op->value , scope), bounds_scope));
602+
598603 Stmt new_body = mutate (op->body );
599604
600605 Expr value = op->value ;
@@ -613,8 +618,10 @@ class SlidingWindowOnFunctionAndLoop : public IRMutator {
613618 }
614619
615620public:
616- SlidingWindowOnFunctionAndLoop (Function f, string v, Expr v_min, set<int > &slid_dimensions)
617- : func(std::move(f)), loop_var(std::move(v)), loop_min(std::move(v_min)), slid_dimensions(slid_dimensions) {
621+ SlidingWindowOnFunctionAndLoop (Function f, string v, Expr v_min, set<int > &slid_dimensions,
622+ Scope<Interval> &bounds_scope)
623+ : func(std::move(f)), loop_var(std::move(v)), loop_min(std::move(v_min)),
624+ slid_dimensions (slid_dimensions), bounds_scope(bounds_scope) {
618625 }
619626
620627 Expr new_loop_min;
@@ -755,9 +762,16 @@ class SlidingWindow : public IRMutator {
755762 // Keep track of realizations we want to slide, from innermost to
756763 // outermost.
757764 list<Function> sliding;
765+ Scope<Interval> bounds_scope;
758766
759767 using IRMutator::visit;
760768
769+ Stmt visit (const LetStmt *op) override {
770+ ScopedBinding<Interval> bind (bounds_scope, op->name ,
771+ bounds_of_expr_in_scope (op->value , bounds_scope));
772+ return IRMutator::visit (op);
773+ }
774+
761775 Stmt visit (const Realize *op) override {
762776 // Find the args for this function
763777 map<string, Function>::const_iterator iter = env.find (op->name );
@@ -827,7 +841,14 @@ class SlidingWindow : public IRMutator {
827841
828842 set<int > &slid_dims = slid_dimensions[func.name ()];
829843 size_t old_slid_dims_size = slid_dims.size ();
830- SlidingWindowOnFunctionAndLoop slider (func, name, prev_loop_min, slid_dims);
844+
845+ Interval min_bounds = bounds_of_expr_in_scope (loop_min, bounds_scope);
846+ Interval max_bounds = bounds_of_expr_in_scope (loop_max, bounds_scope);
847+ ScopedBinding<Interval> bind_bounds (bounds_scope, op->name ,
848+ Interval (min_bounds.min , max_bounds.max ));
849+
850+ SlidingWindowOnFunctionAndLoop slider (func, name, prev_loop_min, slid_dims, bounds_scope);
851+
831852 body = slider (body);
832853
833854 if (func.schedule ().memory_type () == MemoryType::Register &&
0 commit comments