Relax AST Design
Authors (alphabetical): @altanh, @electriclilies, @jroesch, @junrushao1994, @mbs-octoml, @mikepapadim, @tkonolige, @tqchen, @YuchenJin, @ZihengJiang
This document is a design overview of the Relax AST. For broader background on Relax, please refer to the Relax Architecture Overview.
Relax adds the following constructs to the AST. They support two key goals from the architecture overview: G0, support for dynamic-shape workloads, and G1, dataflow blocks as first-class citizens.
```python
from typing import List, Optional

# Schematic field definitions. Expr, PrimExpr, Id, Type, ObjectRef, Node,
# BaseFunc and String are TVM base types.

class ShapeExpr(Expr):
    """A shape whose dimensions are symbolic PrimExprs."""
    values: List[PrimExpr]

class Var(Expr):
    """Vars visible in the global scope."""
    vid: Id
    type_annotation: Optional[Type]

class DataflowVar(Var):
    """Vars visible only within their dataflow scope."""
    pass

class Binding(ObjectRef):
    """The base class of bindings."""
    pass

class VarBinding(Binding):
    """Variable binding: binds value to var."""
    var: Var
    value: Expr

class MatchShape(Binding):
    """Binding that matches the shape of value against pattern and binds the result to var."""
    value: Expr
    pattern: List[PrimExpr]
    var: Var

class BindingBlock(Node):
    """Base class of binding blocks; bindings inside can be impure (side effects or control flow)."""
    bindings: List[Binding]

class DataflowBlock(BindingBlock):
    """Dataflow block; bindings inside are pure (no side effects and no control flow)."""
    pass

class SeqExpr(Expr):
    """A sequence of BindingBlocks; can serve as the body of a Function."""
    blocks: List[BindingBlock]
    body: Expr

class Function(BaseFunc):
    """Represents a Relax function."""
    params: List[Var]
    body: Expr
    ret_type: Type

class ExternFunc(BaseFunc):
    """Extern function, which can represent a TIR PrimFunc or a PackedFunc."""
    global_symbol: String
```
- A Function's body can be a SeqExpr.
- A SeqExpr consists of a list of BindingBlocks.
- A DataflowBlock is a special kind of BindingBlock that is equivalent to a pure computational graph. The bindings inside a DataflowBlock have no side effects and no control flow.
- A BindingBlock consists of a list of Bindings.
- A Binding can be either a VarBinding or a MatchShape.
- The scope of a DataflowVar is its DataflowBlock. A normal Var in a DataflowBlock escapes to the scope containing the block, which could be the function scope or another scope such as an if branch. TIR vars (bound by MatchShape) follow the same scoping rules as normal Vars.
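The nesting and scoping rules above can be modeled with a small sketch. It is a hypothetical stand-in, not TVM's classes. It assumes each binding's value is described only by the vars it reads. It checks two things: a DataflowVar is used only inside its own DataflowBlock, and a normal Var escapes to the enclosing function scope.

```python
from dataclasses import dataclass, field
from typing import List

# Hypothetical, simplified stand-ins for the Relax AST classes (not TVM's classes).
# A binding's value is modeled only as the list of vars it reads.

@dataclass(frozen=True)
class Var:
    name: str

@dataclass(frozen=True)
class DataflowVar(Var):
    pass

@dataclass
class VarBinding:
    var: Var
    reads: List[Var]

@dataclass
class BindingBlock:
    bindings: List[VarBinding] = field(default_factory=list)

@dataclass
class DataflowBlock(BindingBlock):
    pass

@dataclass
class SeqExpr:
    blocks: List[BindingBlock]
    body: Var

def check_scoping(seq: SeqExpr, params: List[Var]) -> List[str]:
    """Report uses of vars outside the scope in which they are visible."""
    errors = []
    visible = set(params)  # vars visible in the enclosing (function) scope
    for block in seq.blocks:
        local = set()  # DataflowVars: visible only inside this block
        for b in block.bindings:
            for used in b.reads:
                if used not in visible and used not in local:
                    errors.append(f"{used.name} used out of scope")
            if isinstance(b.var, DataflowVar):
                if not isinstance(block, DataflowBlock):
                    errors.append(f"{b.var.name} defined outside a DataflowBlock")
                local.add(b.var)
            else:
                visible.add(b.var)  # a normal Var escapes to the enclosing scope
    if seq.body not in visible:
        errors.append(f"{seq.body.name} used out of scope")
    return errors

# A function body with one DataflowBlock followed by one BindingBlock.
x, w = Var("x"), Var("w")
lv0, gv0, gv1 = DataflowVar("lv0"), Var("gv0"), Var("gv1")
func_body = SeqExpr(
    blocks=[
        DataflowBlock([VarBinding(lv0, [x, w]), VarBinding(gv0, [lv0])]),
        BindingBlock([VarBinding(gv1, [gv0])]),
    ],
    body=gv1,
)
```

Here check_scoping(func_body, [x, w]) reports no errors. Reading lv0 from the second block would be flagged, because lv0 does not escape its DataflowBlock.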
Take the following Relax program as an example. relax_func contains a SeqExpr, and the SeqExpr contains a DataflowBlock (with two VarBindings) and a BindingBlock (with one VarBinding).
```python
from tvm.script import relax as R

# n, k, m are symbolic shape variables
@R.func
def relax_func(x: R.Tensor[(n, k), "f32"], w: R.Tensor[(k, m), "f32"]):
    # start a DataflowBlock
    with R.dataflow():  ## <= DataflowBlock
        lv0: R.Tensor[(n, m), "f32"] = R.dot(x, w)  ## <= VarBinding, lv0 is a DataflowVar
        gv0: R.Tensor[(n * m,), "f32"] = R.flatten(lv0)  ## <= VarBinding, gv0 is a Var that escapes to the outer scope
        R.outputs(gv0)

    # start a BindingBlock
    gv1 = R.call_packed("custom_inplace_update", gv0)  ## <= side-effect binding
    return gv1
```

Most pass writers are ML researchers and ML engineers with no compiler or PL background. They therefore write passes under the simple assumption that they are mutating a pure computational graph.

Relay is not explicit about which expressions have side effects and which are pure. As a result, many optimizations are unsound in the presence of side effects. In Relax, a DataflowBlock represents a computational graph region in which all bindings are pure (no side effects, no control flow). Clearly separating this graph region and making it a first-class citizen makes it easy for end users to write graph passes.

Because of this separation between "pure" and "impure" regions, a Function's body can be composed of one or more pure or impure blocks. This is why a SeqExpr holds an Array<BindingBlock>.
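As a toy illustration of why the pure/impure split matters, consider a hypothetical dead-code elimination pass. It may drop an unused binding only inside a dataflow block. A binding in an impure block, such as the call_packed above, may have side effects and must be kept even if its result is unused. The Binding and Block tuples and the op names are simplified assumptions, not Relax classes.

```python
from typing import List, NamedTuple, Set

class Binding(NamedTuple):
    var: str          # name of the bound var
    op: str           # operator name, e.g. "dot" or "call_packed"
    args: List[str]   # names of vars read by this binding

class Block(NamedTuple):
    is_dataflow: bool  # True for a pure DataflowBlock
    bindings: List[Binding]

def eliminate_dead_bindings(blocks: List[Block], live_out: Set[str]) -> List[Block]:
    """Drop unused bindings, but only inside dataflow (pure) blocks."""
    live = set(live_out)
    result = []
    for block in reversed(blocks):  # walk backwards to propagate liveness
        kept = []
        for b in reversed(block.bindings):
            # Only pure bindings can be removed safely when their result is unused.
            if block.is_dataflow and b.var not in live:
                continue
            kept.append(b)
            live.update(b.args)
        result.append(Block(block.is_dataflow, list(reversed(kept))))
    return list(reversed(result))

blocks = [
    Block(True, [Binding("lv0", "dot", ["x", "w"]),
                 Binding("unused", "add", ["x", "x"]),
                 Binding("gv0", "flatten", ["lv0"])]),
    Block(False, [Binding("gv1", "call_packed", ["gv0"]),
                  Binding("log", "call_packed", ["gv0"])]),
]
out = eliminate_dead_bindings(blocks, {"gv1"})
```

The pass drops the pure but unused binding "unused". It keeps "log" even though nothing reads it, because it lives in an impure block.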
MatchShape(value: Expr, pattern: List[PrimExpr], var: Var)
- MatchShape has two overloaded semantics:
  - Suppose x is a 2-D tensor:
    - (1) MatchShape(x.shape, [m, n], var) → matches x.shape to the symbolic variables (m, n) and binds the resulting Shape to var;
    - (2) MatchShape(x, [m, n], var) → matches x.shape to the symbolic variables (m, n) and binds to var a 2-D tensor with the same shape as x, but with the explicit shape field [m, n].
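The matching step itself can be modeled with concrete integer shapes. In the sketch below, the first occurrence of a symbol such as m binds it. Any later occurrence must agree with that binding. This is only a toy model under that assumption, not how Relax implements MatchShape. In variant (1) the matched shape would be bound to var; in variant (2) the tensor x would be.

```python
from typing import Dict, List, Sequence, Union

Dim = Union[int, str]  # an int is a constant dimension; a str names a symbolic var

def match_shape(shape: Sequence[int], pattern: List[Dim], env: Dict[str, int]) -> Dict[str, int]:
    """Match a concrete shape against a pattern, returning an extended symbol environment."""
    if len(shape) != len(pattern):
        raise ValueError(f"rank mismatch: {len(shape)} vs {len(pattern)}")
    new_env = dict(env)
    for actual, expected in zip(shape, pattern):
        if isinstance(expected, int):
            if actual != expected:
                raise ValueError(f"expected {expected}, got {actual}")
        elif expected in new_env:
            # Symbol already bound (e.g. by an earlier match): values must agree.
            if new_env[expected] != actual:
                raise ValueError(f"{expected}={new_env[expected]} but got {actual}")
        else:
            new_env[expected] = actual  # first occurrence binds the symbol
    return new_env
```

For example, match_shape((4, 8), ["m", "n"], {}) binds m to 4 and n to 8. A later match of (8, 2) against ["n", "k"] reuses n and binds k to 2.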
- A DataflowBlock is a self-contained data structure that contains a list of bindings. Pass writers can visit and transform a pure dataflow block through the DataflowMutator interface. This can be friendlier for pass writers who have only an ML background and are familiar with computational dataflow. They only need to understand the simple concept of a DataflowBlock and override visitors in DataflowMutator.
- A BindingBlock is a list of Bindings. The ExprMutator works on programs in A-normal form (ANF). The visitor can therefore traverse the bindings without memoization. There is also no risk of stack overflow, because traversing a flat list of bindings needs no deep recursion.
- The ExprMutator has an internal BlockBuilder that can emit bindings to the newly created block. Why have an internal BlockBuilder in the ExprMutator?
  - BlockBuilder provides APIs for emitting bindings to a block. We often want to fold several bindings into one (n → 1) or rewrite one binding into several (1 → n). Emitting bindings through a BlockBuilder in the visitor makes both easy.
  - The BlockBuilder can do eager shape and type inference. The shape_ and checked_type_ fields of both the lhs var and the rhs expr can therefore be filled when new bindings are emitted.
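The 1 → n rewrite and eager shape inference can be sketched with a hypothetical ToyBlockBuilder. This is not TVM's BlockBuilder API. The fused op name "dot_relu" and the simplified shape rules for dot and flatten are assumptions for illustration. Each emit call appends one binding to a fresh var and fills in its shape immediately.

```python
from dataclasses import dataclass
from typing import Dict, List, Sequence, Tuple

Shape = Tuple[int, ...]

@dataclass(frozen=True)
class Binding:
    var: str
    op: str
    args: Tuple[str, ...]
    shape: Shape  # filled eagerly at emit time (cf. the shape_ field)

class ToyBlockBuilder:
    """Hypothetical builder: collects emitted bindings and infers shapes eagerly."""

    def __init__(self, known_shapes: Dict[str, Shape]):
        self.shapes = dict(known_shapes)  # var name -> inferred shape
        self.bindings: List[Binding] = []
        self._next_id = 0

    def _infer_shape(self, op: str, args: Sequence[str]) -> Shape:
        # Simplified, assumed shape rules for a handful of ops.
        first = self.shapes[args[0]]
        if op == "dot":
            return (first[0], self.shapes[args[1]][1])
        if op == "flatten":
            size = 1
            for dim in first:
                size *= dim
            return (size,)
        return first  # treat everything else as elementwise

    def emit(self, op: str, args: Sequence[str]) -> str:
        """Append a binding for op(args) to a fresh var and return the var's name."""
        name = f"lv{self._next_id}"
        self._next_id += 1
        shape = self._infer_shape(op, args)
        self.shapes[name] = shape
        self.bindings.append(Binding(name, op, tuple(args), shape))
        return name

def expand_dot_relu(bindings, known_shapes):
    """Rewrite each fused 'dot_relu' binding into 'dot' then 'relu' (1 -> n)."""
    bb = ToyBlockBuilder(known_shapes)
    rename: Dict[str, str] = {}  # old var name -> new var name
    for var, op, args in bindings:
        new_args = [rename.get(a, a) for a in args]
        if op == "dot_relu":
            tmp = bb.emit("dot", new_args)
            rename[var] = bb.emit("relu", [tmp])
        else:
            rename[var] = bb.emit(op, new_args)
    return bb.bindings

out = expand_dot_relu(
    [("y", "dot_relu", ("x", "w")), ("z", "flatten", ("y",))],
    {"x": (2, 3), "w": (3, 4)},
)
```

Because shapes are inferred when each binding is emitted, later rewrites in the same visitor can rely on them without a separate inference pass.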
TVM performs IR analysis and rewriting through a recursive-descent visitor pattern over the expression's abstract syntax tree. For example, to write an analysis pass that counts the relax.Var definitions within a relax.Function, we can subclass relax::ExprVisitor, as in class CountVarDef : public relax::ExprVisitor.
The hierarchy of the default visitor pattern is shown in the accompanying figure. To count the variable definitions, we only need to override the VisitVarDef function:
```cpp
#include <tvm/relax/expr_functor.h>
#include <tvm/runtime/registry.h>

namespace tvm {
namespace relax {

// Counts every variable definition (Var and DataflowVar) in a Relax expression.
class CountVarDef : public ExprVisitor {
 public:
  size_t n_vardef{0};
  // Called once for each variable definition encountered during traversal.
  void VisitVarDef(const Var& var) override { ++n_vardef; }
};

// Export to the Python front end. A fresh counter starts at zero on every call.
TVM_REGISTER_GLOBAL("relax.analysis.count_vardef").set_body_typed([](Function f) -> int64_t {
  CountVarDef counter;
  counter.VisitExpr(f);
  return static_cast<int64_t>(counter.n_vardef);
});

}  // namespace relax
}  // namespace tvm
```

Similarly, to count only DataflowVar definitions, we only need to override VisitVarDef_(const DataflowVarNode*). As the figure shows, ExprVisitor's default VisitVarDef dynamically dispatches to the matching VisitVarDef_ overload based on the variable's type, Var or DataflowVar (see the top-left blue blocks).
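The dispatch described above can be mirrored in a small Python sketch. The classes and method names are hypothetical and are not TVM's Python API. A generic visit_var_def hook dispatches to a type-specific hook. A subclass can either override the generic hook to count every definition, or override only the DataflowVar hook to count just those.

```python
class Var:
    def __init__(self, name: str):
        self.name = name

class DataflowVar(Var):
    pass

class ToyExprVisitor:
    """Hypothetical visitor mirroring the VisitVarDef -> VisitVarDef_ dispatch."""

    def visit_var_def(self, var: Var) -> None:
        # Dispatch on the dynamic type; test the subclass first.
        if isinstance(var, DataflowVar):
            self.visit_var_def_dataflow_var(var)
        else:
            self.visit_var_def_var(var)

    def visit_var_def_var(self, var: Var) -> None:
        pass  # default: do nothing

    def visit_var_def_dataflow_var(self, var: DataflowVar) -> None:
        pass  # default: do nothing

    def visit_bindings(self, bindings) -> None:
        # Each binding is modeled as a (var, value) pair; only the var matters here.
        for var, _value in bindings:
            self.visit_var_def(var)

class CountVarDef(ToyExprVisitor):
    """Counts every definition by overriding the generic hook."""

    def __init__(self):
        self.n = 0

    def visit_var_def(self, var: Var) -> None:
        self.n += 1

class CountDataflowVarDef(ToyExprVisitor):
    """Counts only DataflowVar definitions by overriding one type-specific hook."""

    def __init__(self):
        self.n = 0

    def visit_var_def_dataflow_var(self, var: DataflowVar) -> None:
        self.n += 1

bindings = [(DataflowVar("lv0"), "dot"), (Var("gv0"), "flatten"), (DataflowVar("lv1"), "relu")]
```

On these three bindings, CountVarDef counts 3 definitions and CountDataflowVarDef counts 2.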