- Table of contents
- Preliminaries
- Primitive operations
- The syntax for %op
- The syntax for %cd
- Numeric and N-dimensional array literals
- Wildcard bindings
- Inline declarations
- Using OCANNL's generalized einsum notation
- Further features of the syntax extension %cd
- Further features of the syntax extension %op
- The syntax extension %%extend_dsls
- Implementation details
- In a nutshell
- Syntax extension `%cd` stands for "code", to express assignments and computations: `Assignments.comp`.
- Syntax extension `%op` stands for "operation", to express tensors: `Tensor.t`.
- Both extensions use record syntax `{ tensor_name }` or `{ tensor_name = init_expr }` for inline tensor declarations.
- Anti-quotation `%oc` escapes expressions to preserve them as pure OCaml without transformation.
- Syntax extension
OCANNL, and arrayjit specifically, is built around a fixed number of numeric operations, declared in arrayjit/ops.ml. We assign lexical operators to the binary operations, inventing novel operators if needed. For example, Rectified Linear Unit Relu operation, which computes f(x) = max(0,x), is called relu, while the ReLU-Gate Relu_gate operation, which computes f(x,y) = if x > 0.0 then y else 0.0, gets the operator -?/ in addition to name relu_gate. These built-in numeric operations are used to construct assignments (Assignments.t packaged as Assignments.comp). The syntax %cd is needed to build assignments concisely, and the assignment operators always start with = (unlike in C where they end with =). On the other hand, while the syntax %op helps build tensors (Tensor.t), they can be expressed concisely in pure OCaml. Unlike for assignments, the building blocks for tensor expressions are easy to extend. The meaningful basic ones are provided in tensor/operation.ml.
In OCANNL, we call a tensor that is prohibited from propagating gradients, and therefore has neither a gradient node nor backprop code, a non-differentiable tensor. Accordingly, we call the "plain" tensors with a gradient node differentiable tensors. Expressions in the %cd syntax will sometimes build new non-differentiable tensors as components of assignments (they will never build new differentiable tensors). The syntax extensions make the following assumptions:
- `%cd` assumes that any extension point will be in the scope of a module `NTDSL` that provides at least the functionality of `Operation.NTDSL`.
- `%op` assumes that any extension point will be in the scope of a module `TDSL` that provides at least the functionality of `Operation.TDSL`.
- `%op` with inline definitions with initialization expressions assumes that a module `PDSL` is in scope with at least the functionality of `Operation.PDSL`.
- Both extensions assume `Tensor` (from the `Ocannl` wrapper) is in scope.
Functions inside Operation.NTDSL use ~grad_spec:Prohibit_grad when calling into Tensor, making the resulting tensors non-differentiable. Functions inside Operation.TDSL use ~grad_spec:If_needed, which will make the tensors non-differentiable when the gradient is not needed -- except for TDSL.param, which internally sets ~grad_spec:Require_grad. Functions inside Operation.PDSL use ~grad_spec:Require_grad.
The extension points open NTDSL.O, resp. TDSL.O, for the scope of the extension point, to expose the corresponding operators.
Within %op and %cd contexts, expressions typically undergo transformation to build tensors or assignments. However, OCANNL uses two mechanisms to preserve pure OCaml expressions:
In the %op syntax, when a function application contains a unit () argument, all arguments appearing before the unit are automatically preserved as pure OCaml expressions. This aligns with OCANNL's design pattern where configuration happens before the unit parameter:
```ocaml
(* Arguments before () are automatically preserved as OCaml *)
let%op my_fn ~label x =
  other_fn ~label:(("prefix_" ^ name) :: label) ~config:value () x
(* label and config are preserved; x after () is transformed *)
```

For cases where you need explicit control or the heuristic doesn't apply, the `%oc` (mnemonic: "OCaml") anti-quotation escapes from the transformation context:

```ocaml
(* Force preservation even after () or in edge cases *)
let%op special = process_data data [%oc complex_ocaml_expr]
```

The `%oc` extension expects a single expression and returns it unchanged. Use cases:
- Overriding the unit-parameter heuristic when needed
- Preserving expressions in contexts without a unit parameter
- Escaping from the DSL in `%cd` contexts (which don't use the unit heuristic)
To accommodate stylistic preferences, OCANNL supports both curried and uncurried syntaxes for primitive operation application. Binary operators are associated with infix operators, in addition to having alphabetic identifiers. This stems from the following restriction: in the %cd syntax, the assignment is always an infix operator, and it needs to pick the accumulation operation.
The unary primitive operations:
| Identifier | Default projection | Constructor in Ir.Ops |
|---|---|---|
| `id` | pointwise | `Identity` |
| `relu` | pointwise | `Relu` |
| `sat01` | pointwise | `Satur01` |
| `exp` | pointwise | `Exp` |
| `log` | pointwise | `Log` |
| `exp2` | pointwise | `Exp2` |
| `log2` | pointwise | `Log2` |
| `sin` | pointwise | `Sin` |
| `cos` | pointwise | `Cos` |
| `sqrt` | pointwise | `Sqrt` |
| `recip` | pointwise | `Recip` |
| `recip_sqrt` | pointwise | `Recip_sqrt` |
| `neg` | pointwise | `Neg` |
| `tanh` | pointwise | `Tanh_approx` |
| `not` | pointwise | `Not` |
| `uint4x32_to_prec_uniform` | dedicated | `Uint4x32_to_prec_uniform` |
The binary primitive operations:
| Identifier | Infix operator | Default projection | Constructor in Ir.Ops | Assignments |
|---|---|---|---|---|
| `fst` | `-@>` | pointwise | `Arg1` | none |
| `snd` | `-/>` | pointwise | `Arg2` | `=:` |
| `add` | `+` | pointwise | `Add` | `=+`, `=:+` |
| `sub` | `-` | pointwise | `Sub` | `=-`, `=:-` |
| `mul` | `*` | none | `Mul` | `=*`, `=:*` |
| `div` | `/` | none | `Div` | `=/`, `=:/` |
| `pow` | `**` | pointwise | `ToPowOf` | `=**`, `=:**` |
| `relu_gate` | `-?/` | pointwise | `Relu_gate` | `=?/`, `=:?/` |
| `sat01_gate` | `-?^` | pointwise | `Satur01_gate` | `=?^`, `=:?^` |
| `lt` | `<` | pointwise | `Cmplt` | none |
| `eq` | `=` | pointwise | `Cmpeq` | none |
| `ne` | `<>` | pointwise | `Cmpne` | none |
| `or_` | `\|\|` | pointwise | `Or` | `=\|\|`, `=:\|\|` |
| `and_` | `&&` | pointwise | `And` | `=&&`, `=:&&` |
| `mod_` | `%` | pointwise | `Mod` | none |
| `max` | `@^` | pointwise | `Max` | `=@^`, `=:@^` |
| `min` | `@-` | pointwise | `Min` | `=@-`, `=:@-` |
| `threefry4x32` | `^^^^` | pointwise | `Threefry4x32` | `=^^^^`, `=:^^^^` |
The ternary primitive operations:
| Identifier | Default projection | Constructor in Ir.Ops |
|---|---|---|
| `where` | pointwise | `Where` |
| `fma` | compose-accumulate | `FMA` |
The interpretation functions also state the semantics:
```ocaml
let interpret_unop op v =
  let open Float in
  match op with
  | Identity -> v
  | Relu when v >= 0. -> v
  | Relu -> 0.
  | Satur01 when v <= 0. -> 0.
  | Satur01 when v >= 1. -> 1.
  | Satur01 -> v
  | Exp -> exp v
  | Log -> log v
  | Exp2 -> 2. ** v
  | Log2 -> log v / log 2.
  | Sin -> sin v
  | Cos -> cos v
  | Sqrt -> sqrt v
  | Recip -> 1. / v
  | Recip_sqrt -> 1. / sqrt v
  | Neg -> ~-.v
  | Tanh_approx -> tanh v
  | Not -> if v = 0. then 1. else 0.
  | Uint4x32_to_prec_uniform -> failwith "NOT IMPLEMENTED"

let interpret_binop op v1 v2 =
  let open Float in
  match op with
  | Arg1 -> v1
  | Arg2 -> v2
  | Add -> v1 + v2
  | Sub -> v1 - v2
  | Mul -> v1 * v2
  | Div -> v1 / v2
  | ToPowOf when is_integer v2 -> int_pow v1 @@ to_int v2
  | ToPowOf -> v1 ** v2
  | Relu_gate -> if v1 > 0.0 then v2 else 0.0
  | Satur01_gate -> if v1 > 0.0 && v1 < 1.0 then v2 else 0.0
  | Max -> max v1 v2
  | Min -> min v1 v2
  | Mod -> v1 % v2
  | Cmplt -> if v1 < v2 then 1. else 0.
  | Cmpeq -> if v1 = v2 then 1. else 0.
  | Cmpne -> if v1 <> v2 then 1. else 0.
  | Or -> if v1 <> 0. || v2 <> 0. then 1. else 0.
  | And -> if v1 <> 0. && v2 <> 0. then 1. else 0.
  | Threefry4x32 -> ...

let interpret_ternop op v1 v2 v3 =
  let open Float in
  match op with Where -> if v1 <> 0. then v2 else v3 | FMA -> (v1 * v2) + v3
```

The %op syntax is simpler than the %cd syntax since it relies more on regular OCaml expressions. For example, we can write without syntax extensions:
```ocaml
let hid_dim = 8 in
let w = TDSL.param "w" in
let b = TDSL.param ~output_dims:[ hid_dim ] "b" in
let layer x = TDSL.O.( relu(w * x + b) ) in
...
```

Since `TDSL.O` is opened for the scope of an extension point `%op`:

```ocaml
let hid_dim = 8 in
let w = TDSL.param "w" in
let b = TDSL.param ~output_dims:[ hid_dim ] "b" in
let%op layer x = relu(w * x + b) in
...
```

Using inline declarations, this becomes more concise:

```ocaml
let hid_dim = 8 in
let%op mlp_layer x = relu({ w } * x + { b; o = [ hid_dim ] }) in
...
```

When there is a function directly under the %op extension point, like in the example above, or directly under a function taking a unit parameter (), the function parameter (to the right of ()) should be a tensor. That's because %op uses this tensor's (value's) label to enrich the label of the resulting tensor.
When the declaration is followed by a literal float, the float provides the initial value to initialize the tensor. Otherwise, the tensor value cells are initialized randomly with uniform distribution.
The basic building blocks of the %cd syntax are individual assignments, separated by semicolons. The assignments, represented via Assignments.Accum_binop and Assignments.Accum_unop, are in full generality accumulating:
```ocaml
type Assignments.t =
  ...
  | Accum_binop of {
      initialize_neutral : bool;
      accum : Ops.binop;
      op : Ops.binop;
      lhs : Tnode.t;
      rhs1 : buffer;
      rhs2 : buffer;
      projections : Indexing.projections Lazy.t;
    }
  | Accum_unop of {
      initialize_neutral : bool;
      accum : Ops.binop;
      op : Ops.unop;
      lhs : Tnode.t;
      rhs : buffer;
      projections : Indexing.projections Lazy.t;
    }
```

For example the binary case in pseudocode: `if initialize_neutral then lhs = 0; lhs = lhs accum (rhs1 op rhs2)` (assuming the neutral element of `accum` is 0). The representation also has a field `projections` which determines which loops should be run and how the tensor nodes should be indexed to perform the computation.
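The pseudocode above can be modeled in plain OCaml. This is only a sketch and not OCANNL code: `accum_binop`, `dot`, the explicit `~neutral` argument, and the list of index triples standing in for `projections` are hypothetical names introduced for illustration, with float arrays standing in for tensor nodes.

```ocaml
(* A plain-OCaml model of an accumulating binary assignment.
   Hypothetical: [neutral] is passed explicitly, and [indices] is a list of
   (lhs index, rhs1 index, rhs2 index) triples standing in for projections. *)
let accum_binop ~initialize_neutral ~neutral ~accum ~op ~lhs ~rhs1 ~rhs2 indices =
  if initialize_neutral then Array.fill lhs 0 (Array.length lhs) neutral;
  List.iter
    (fun (i, j, k) -> lhs.(i) <- accum lhs.(i) (op rhs1.(j) rhs2.(k)))
    indices

(* Models an assignment of the shape [c =:+ a * b] as a dot product:
   every product is accumulated into the single cell c.(0). *)
let dot a b =
  let c = [| 42.0 |] in
  (* the stale 42.0 is cleared because initialize_neutral is true *)
  let indices = List.init (Array.length a) (fun i -> (0, i, i)) in
  accum_binop ~initialize_neutral:true ~neutral:0.0 ~accum:( +. ) ~op:( *. )
    ~lhs:c ~rhs1:a ~rhs2:b indices;
  c.(0)
```

Passing `~initialize_neutral:false` keeps the previous contents of `lhs`, which corresponds to an assignment operator without the `:`.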
The basic %cd syntax for assignments has the form: <lhs> <asgn-op> <primitive-op-application[rhs1, rhs2?, rhs3?]>. See Primitive operations for the syntax of primitive operation application, where <rhs1>, <rhs2> (for binary and ternary ops), <rhs3> (for ternary ops) are subexpressions. <asgn-op> starts with =, followed by : only if initialize_neutral is true, then followed by the operator syntax variant of a binary primitive operation. The fields <lhs>, <rhs1>, <rhs2>, <rhs3> will often be either special-purpose identifiers (specifically v, t, t1, t2, t3, g, g1, g2, g3) or identifiers bound to tensors. <rhs1>, <rhs2>, <rhs3> will also often be (non-differentiable) tensor expressions. The notation <tensor>.grad stands for the gradient node of the given tensor. For more about "slot fillers", and to learn about the operators +* and ++, see the section further features of the syntax extension %cd.
How is the projections field determined? projections can be given explicitly as a labeled argument ~projections. If they aren't but %cd realizes there is a ~projections parameter in scope, it uses it -- see tensor/operation.ml where this option is used to define tensor operations. If instead of ~projections a ~logic labeled argument is given, the string passed is used to determine projections. ~logic:"." means a pointwise operation. ~logic:"@" means an "output axes of rhs2 match input axes of rhs1" operation (matrix multiplication is a special case). ~logic:"T" means transpose of input and output axes. The string passed to ~logic can also use OCANNL's generalization of the einsum notation, allowing arbitrary permutations and reductions of axes. If no information is given, the default depends on the primitive operation, but it is almost always a pointwise operation.
Here we see an example of tensor multiplication -- extending matrix multiplication to arbitrary number of axes -- multiplying a by b to get c. In =:+, = is required to separate the assigned-to part from the computation, : clears-out c before the computation, + selects addition to accumulate the results.
```ocaml
c =:+ a * b ~logic:"@"
```

Compare the following two ways of updating a parameter `p`:

```ocaml
p =+ learning_rate * p.grad ~logic:"."
```

and:

```ocaml
p =+ learning_rate *. p.grad
```

In the first case, we have a binary assignment calculated pointwise. The resulting representation is `Accum_binop` where `accum` is `Add` and `op` is `Mul` (multiplication). In the second case, `*.` is not recognized as one of the built-in operators. This leaves the expression `learning_rate *. p.grad` un-transformed. Since `(*.)` is bound in `NTDSL.O` to pointwise tensor multiplication, this creates an intermediate tensor, that is then added onto `p`. The resulting representation is `Accum_unop` where `accum` is `Add` and `op` is `Identity`. Both variants end up with the same result, and even with the same computation, because the second variant's computation will get optimized (unless configured not to).
Advanced note: when a ~projections parameter is in scope but no assignment-specific ~projections argument is given -- the typical case in tensor/operation.ml -- the actual projections field for an assignment is computed by transforming the projections parameter according to hints regarding how tensor nodes relate to the given projections. Specifically, the identifiers rhs1, t1, v1, g1 are "slot RHS1" of the projections, rhs2, t2, v2, g2 are "slot RHS2", lhs, t, v, g are "slot LHS". Scalar constants are provided the projection directly, to make the automated derivation more expressive; this is supported both for literals, and (heuristically) for !. and !.. embedding operators.
In addition to the special identifiers (t, t1, t2, lhs, rhs1, etc.), the %cd syntax can detect projection slots from identifier and inline tensor definition names using prefix/suffix patterns. This is essential when defining backpropagation code that needs intermediate tensors with specific projections and shapes.
The naming convention patterns:
| Prefix/Suffix | Detected Slot | Shape Source |
|---|---|---|
| `lhs_*` or `*_lhs` | LHS | output shape (from `t`) |
| `rhs_*` or `*_rhs` | RHS1 | first input shape (from `t1`) |
| `rhs1_*` or `*_rhs1` | RHS1 | first input shape (from `t1`) |
| `rhs2_*` or `*_rhs2` | RHS2 | second input shape (from `t2`) |
| `rhs3_*` or `*_rhs3` | RHS3 | third input shape (from `t3`) |
This applies to:
- Inline tensor definitions: `{ cond_rhs1 }` declares a tensor with slot RHS1 and shape of `t1`
- Identifier references: when `sum_rhs1` is used in an expression, it's recognized as having slot RHS1
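The naming patterns above amount to a prefix/suffix test on the identifier. A minimal sketch in plain OCaml follows; `slot`, `has_prefix`, `has_suffix`, and `slot_of_name` are hypothetical names for illustration and do not reflect how the syntax extension is actually implemented.

```ocaml
type slot = LHS | RHS1 | RHS2 | RHS3

(* Plain string tests, written with the standard library only. *)
let has_prefix p s =
  String.length s >= String.length p && String.sub s 0 (String.length p) = p

let has_suffix p s =
  let lp = String.length p and ls = String.length s in
  ls >= lp && String.sub s (ls - lp) lp = p

(* Hypothetical helper mirroring the naming patterns: returns the detected
   slot, or None when the name matches no pattern. *)
let slot_of_name name =
  let patterns =
    [ ("lhs", LHS); ("rhs1", RHS1); ("rhs2", RHS2); ("rhs3", RHS3); ("rhs", RHS1) ]
  in
  List.find_map
    (fun (key, slot) ->
      if has_prefix (key ^ "_") name || has_suffix ("_" ^ key) name then Some slot
      else None)
    patterns
```

For instance, `slot_of_name "cond_rhs1"` yields `Some RHS1`, while a name such as `"result"` yields `None`.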
Why this matters for gradient computation: In operations like max pooling or tropical convolution, the gradient must flow back to positions that achieved the argmax. Using an intermediate tensor with the wrong shape causes incorrect gradients. For example, in a 4×4 → 2×2 max pooling:
- `*_lhs` suffix gives shape 2×2 (output shape): wrong for tracking per-input-position gradients
- `*_rhs1` suffix gives shape 4×4 (input shape): correct for sparse gradient at argmax positions
Example from the tropical operation's gradient computation:
```ocaml
let%cd grad_asn ~t ~g ~t1 ~t2 ~projections =
  (* Use _rhs1 suffix: gives input shape to track which positions achieved argmax *)
  { sum_rhs1 } =:@^ add (t1, t2); (* max over each input position's window *)
  { cond_rhs1 } =: eq (t, sum_rhs1); (* true where input+kernel achieved the argmax *)
  g1 =+ where cond_rhs1 g 0; (* gradient flows to argmax input positions *)
  g2 =+ where cond_rhs1 g 0 (* gradient flows to argmax kernel positions *)
in
```

For convolution-like operations with einsum `"...|stride*oh<+wh, stride*ow<+ww, ..c..; wh, ww => ...|oh, ow, ..c.."`, the RHS1 index space `(ih, iw)` effectively encodes the outer product of output `(oh, ow)` and kernel `(wh, ww)` dimensions via `ih = stride*oh + wh`. This means using `_rhs1` for intermediate condition tensors correctly tracks which (input position, kernel position) pair achieved the argmax for each output position.
Important: The naming convention affects both projection slot assignment and shape inference. In addition to determining which projection from ~projections to use when indexing the tensor, a shape equality constraint is generated between the inline-defined tensor and the corresponding operation tensor assumed to be in scope: t for *_lhs, t1 for *_rhs and *_rhs1, t2 for *_rhs2, etc. This means the shape from the tensor's initialization is unified with the shape of the operation component.
Both %cd and %op extensions use a shared syntax for N-dimensional array literals. %cd uses NTDSL.number and NTDSL.ndarray functions, while %op uses TDSL.number and TDSL.ndarray functions. (This is just for consistency: TDSL.ndarray invokes Tensor.ndarray ~grad_spec:If_needed, which will figure out the gradient is not needed and will make the tensor non-differentiable.)
Numbers are a special case: an array of (output) dimension 1.
N-dimensional array literals combine the list, tuple and array syntaxes to strictly distinguish between output, input and batch axes:
- The tuple syntax translates to an input axis.
- The list syntax translates to an output axis.
- The array syntax translates to a batch axis.
For example, [ (1, 2, 3); (4, 5, 6) ] is a mathematical matrix converting 3D vectors into 2D vectors.
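To see why this literal reads as a 2×3 matrix, it helps to treat it as ordinary OCaml data. The sketch below uses no OCANNL at all; `m` and `apply` are hypothetical names, and floats are used so the arithmetic is explicit.

```ocaml
(* The literal [ (1, 2, 3); (4, 5, 6) ] as ordinary OCaml data:
   the list gives 2 output positions, each tuple gives 3 input positions. *)
let m = [ (1., 2., 3.); (4., 5., 6.) ]

(* Apply the matrix to a 3D vector: one dot product per output position,
   so the result has as many elements as the list. *)
let apply m (x, y, z) =
  List.map (fun (a, b, c) -> (a *. x) +. (b *. y) +. (c *. z)) m
```

Here `apply m (1., 0., 2.)` produces a two-element list, one value per list item of the literal.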
OCANNL supports dimension labels. The syntax for number allows prefixing a number by a character that stands for the dimension label of the resulting output dimension 1. These labels can then propagate to specify labels of other dimensions in other tensors, via shape inference. Example: let%op y = ({ hey } * 'q' 2.0) + 'p' 1.0 in ...
When an extension is over a wildcard (ignore result) binding: let%cd _ = ... and let%op _ = ..., the generated code is wrapped in Tensor.with_unchanged_roots, to prevent it from upsetting rootness checks. The use-case for writing %op and %cd notations with ignored result is to generate additional shape inference constraints.
Both %cd and %op syntaxes support inline declarations of tensors. For %op these are differentiable, for %cd non-differentiable tensors.
A declaration site uses the record syntax. The key difference between the two extensions:
- `%op`: `{ tensor_name = init_expr }` allows initialization expressions, or `{ tensor_name }` for default initialization (uniform random)
- `%cd`: `{ tensor_name }` requires self-referential syntax (the field name must match the field value identifier), no separate initialization expressions are allowed
Both syntaxes support additional record fields that map directly to labeled arguments of the tensor creation functions (see Tensor module signatures):
- `output_dims` or shorthand `o`: specifies output dimensions
- `input_dims` or shorthand `i`: specifies input dimensions
- `batch_dims` or shorthand `b`: specifies batch dimensions
- Any other labeled argument accepted by `TDSL.param` (for `%op`) or `NTDSL.term` (for `%cd`)
Note: for the %op declarations, if the root operation comes from TDSL.O and is not qualified with a module name, it becomes qualified with PDSL, which ensures that the created tensor will be differentiable (will have gradients) and will be able to take the additional arguments. There are also special cases for literal constants to ensure the resulting tensor is initialized with these constants but is differentiable.
Examples:
- `%op`: `{ x = 5.0 }`, `{ w; o = [hidden_dim] }`, `{ weights = [1.0; 2.0] }`
- `%cd`: `{ temp }`, `{ result; output_dims = [3; 4] }`, `{ x; o = [10] }`
The tensor name is bound to the newly created tensor, and the record expression itself evaluates to the tensor. The scope of the binding is the full scope of the extension point, even if the declaring record appeared in the body of a function that's inside the extension point scope (except for %op there is a special case of functions taking a unit parameter () discussed below -- inline definitions are introduced once () is applied). The first element of the label of the created tensor is the name that introduced it.
For %cd, inline declarations are allowed both in the assigned-to position (left-hand side) of assignments and in standalone tensor expressions. When used in assignments, one of the tensors on the right-hand-side is picked to provide additional label information if possible. In particular, tensors that are function parameters inside the scope of the extension point, cannot be picked to provide label information, as they would escape their scope at the point the tensor is created. Inline declarations are still prohibited within the right-hand side of assignments to discourage over-use in locations with less label information. Example showing two tensor nodes declared inline, both of them include the label of the param p in their labels:
```ocaml
let sgd_one ~learning_rate ?(momentum = 0.0) ?(weight_decay = 0.0) ?(nesterov = false) p =
  [%cd
    { sgd_delta } =: p.grad + (!.weight_decay *. p);
    if Float.(momentum > 0.0) then (
      { sgd_momentum } =: (!.momentum *. sgd_momentum) + sgd_delta;
      if nesterov then sgd_delta =+ !.momentum *. sgd_momentum else sgd_delta =: sgd_momentum);
    p =- learning_rate *. sgd_delta]
```

Inline declarations can also be used outside of assignments for creating non-differentiable tensors, to mimic the behavior of `%op` but without the burden of initialization that a parameter would introduce:

```ocaml
let%cd mlp_result = mlp { point } in
let result_routine =
  Train.to_routine (Context.context sgd_routine) IDX.empty
    [%cd ~~("mlp infer"; mlp_result.forward)]
in
let callback (x, y) =
  Tn.set_values point [| x; y |];
  Train.run ctx result_routine;
  Float.(mlp_result.@[0] >= 0.)
in
```

For `%op`, the declaration is allowed anywhere. If there is a unit `()` parameter in the function, the scope of inline-declared tensors is delimited at that parameter. The tensors are defined right after the unit parameter. If there is a labeled parameter with label `label` before the unit parameter (e.g., `~label`), the inline-declared tensors will use that parameter (which should be of type `string list`) to enrich their labels. Example showing two param tensors declared inline, with scope delimited by `()` and labels enriched by the `label` parameter:

```ocaml
let%op mlp_layer ~label ~hid_dim () x = relu ({ w } * x + { b; o = [ hid_dim ] })
```

To maintain the familiar concise syntax, yet allow for configurability during initialization, the `%op` syntax substitutes the operator function applied at the root of the initialization expression by prefixing the function identifier with `PDSL` (or with `NTDSL` when invoked from the `%%extend_dsls` syntax). Only unqualified identifiers get prefixed, and `%oc` is an escape hatch to prevent prefixing even for unqualified identifiers.
For einops users: If you're familiar with einops, see einops_comparison.md for a side-by-side mapping of einops operations to OCANNL's notation.
As we mentioned above, in the %cd syntax you can set up an arbitrary assignment with projections derived from a generalized einsum specification, by passing the specification as a string with the ~logic label. However, both the %cd and %op syntaxes support built-in operators that take an einsum specification: +* binding to NTDSL.einsum resp. TDSL.einsum, and ++ binding to NTDSL.einsum1 resp. TDSL.einsum1. +* is a "ternary" operator, binary wrt. tensor arguments, and ++ is a binary operator, unary postfix wrt. tensor arguments. There are even more einsum operators: binary @^+ and +++; unary @^^. When the einsum specification is a literal string, we support two syntax patterns: the string can either directly follow the operator (infix-style notation), or the string can follow the second argument (mixfix-style notation). When the spec string is an identifier, it must directly follow the operator.
+*, +++ and ++ use addition for the accumulation operation; @^+ and @^^ use maximum. You can verify that looking at the definitions of Operation.einsum, Operation.einsum1, etc. You can find examples of +* and ++ behavior in the test suite einsum_trivia.ml and in nn_blocks.ml. A frequent use-case for ++ is to sum out all axes of a tensor:
```ocaml
let%op scalar_loss = (margin_loss ++ "...|... => 0") /. !..batch_size in
...
```

where `(!..)` converts an integer into a constant tensor.
The specification syntax has two modes:
- if there is a comma anywhere in a spec, it is the multichar mode: axis identifiers are comma-separated and can have multiple characters;
- otherwise, it is the single-char mode: each alphanumeric character corresponds to an axis.
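The mode rule can be illustrated with a small plain-OCaml function. This is a sketch under simplifying assumptions: `axis_labels` is a hypothetical helper that handles only a bare row of axis identifiers and ignores row variables, `|`, `->`, `=>` and the rest of the real spec syntax.

```ocaml
(* Hypothetical helper: split one row of plain axis identifiers into labels. *)
let axis_labels spec =
  if String.contains spec ',' then
    (* multichar mode: identifiers are comma-separated; surrounding spaces
       and empty fields (e.g. after a trailing comma) are dropped *)
    (String.split_on_char ',' spec
    |> List.map String.trim
    |> List.filter (fun s -> s <> ""))
  else
    (* single-char mode: one axis per alphanumeric character *)
    (String.to_seq spec
    |> Seq.filter (function
         | 'a' .. 'z' | 'A' .. 'Z' | '0' .. '9' -> true
         | _ -> false)
    |> Seq.map (String.make 1)
    |> List.of_seq)
```

Note how a single trailing comma, as in `"a,"`, is enough to switch a spec into multichar mode.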
The syntax of a generalized einsum spec has two variants:
- unary: `<rhs> shape spec => <lhs> shape spec`, specifies a unary assignment `<lhs> <asgn-op> <rhs>` (see syntax for `%cd`),
- binary: `<rhs1> shape spec ; <rhs2> shape spec => <lhs> shape spec`, specifies a binary assignment `<lhs> <asgn-op> <rhs1> <op> <rhs2>` (see syntax for `%cd`).
Recall that a tensor shape is composed of three rows, i.e. sequences of axes: batch, input and output axes. Correspondingly, a shape spec in the notation can be:
- the output row at the end of the spec, or just the output row,
- the input row to the left of `->`, if given,
- the batch row to the left of `|`, if given.
The notation for a row is composed of sequences of axis specs and an optional row variable spec. A row variable tracks broadcasting. The syntax of a row:
- a sequence of axis specs: specifies the rightmost axes, with untracked broadcasting "to the left",
- a row variable spec followed by a sequence of axis specs for the rightmost axes,
- leftmost axes specs, followed by a row variable, followed by rightmost axes specs.
The syntax of a row variable:
- `..variable_id..`: `variable_id` stands for the row variable identifier,
- ellipsis `...` is context dependent: in the batch row it means `..batch..`, in the input row `..input..`, in the output row `..output..`.
The syntax of an axis spec:
- Depending on the mode, either an alphabetic character or an alphanumeric identifier provides an axis variable.
- The underscore `_` is a placeholder to align other axes, but does not specify anything for the given axis (it is not a variable).
- A number specifies the particular dimension within the axis.
- A `+` sign specifies a convolution input axis with the output on the left of `+` and the kernel on the right of `+`.
  - In both the output part and the kernel part you can prefix the axis variable by a constant coefficient with the `*` sign.
  - Written directly, the coefficient can only be an integer, e.g. `"2*i+3*k"`, but under the `%op` and `%cd` syntax extensions, it can also be an identifier of an integer value, e.g. `let stride = 2 and dilation = 3 in [%op "input" +* "stride * a + dilation * b; b=>a," "kernel"]`.
  - Note the comma above. The syntax extension's expansion of stride and dilation respects the "multichar" mode. Without the comma we are limited to single-character identifiers, e.g. `let s = 2 and d = 3 in [%op "input" +* "is*a+d*bc;b=>iac" "kernel"]`.
- The use_padding modifier before `+` specifies whether padding is used:
  - `=+` for padded convolution (use_padding=true): `stride*output=+dilation*kernel`
  - `<+` for valid convolution (use_padding=false): `stride*output<+dilation*kernel`
  - Plain `+` (unspecified): reads the `use_padding` variable from scope (only under `%op` and `%cd` syntax extensions)
Examples:
- `...|...->... => 0`: reduce all axes of the argument into a single number. Useful e.g. for reducing losses to a single number.
- `...|... => 0`, `...->... => 0`, `... => 0` do the same but will fail if the argument has axes of the kind for which the ellipsis is missing.
- `...|...->... => ...|...->...`: fully pointwise unary operation.
- `...->... => ...->...`, `...|... => ...|...`, `... => ...`: fully pointwise but will fail if the argument has axes of the kind for which the ellipsis is missing.
- `...|...->... ; ...|...->... => ...|...->...`: fully pointwise binary operation.
- `...|...->... => ...->...`: reduce the batch axes into the result.
- `2...|...->... => ...|...->...`: slice the tensor at dimension 2 of the leftmost batch axis. Note that the tensor operation `@|` implements slicing at the leftmost batch axis for arbitrary dimension.
- `...|... => ...|...2`: expand the tensor by putting the argument at leftmost output dimension 2 of the result (and reduce input axes if any). `rhs ++ "...|... => ...|...2"` will fill the other cells of the new tensor with zeroes; `[%cd lhs =:* rhs ~logic:"...|... => ...|...2"]` will fill the other cells of `lhs` with ones, since 1 is the neutral element of the assignment (reduction) operator.
- `ijk => kji`: reverse the three output axes, fails if the argument has any other axes.
- `ijk => ki`: as above but also reduce the second-leftmost output axis.
- `..v..|...ijk => ..v..kji`: reverse the three rightmost output axes, reduce any other output axes, pointwise for batch axes, pairing the batch axes with the leftmost output axes of the result. Fails if the argument has input axes.
- `2..v..|... => ..v..`: slice the tensor at dimension 2 of the leftmost batch axis, reduce all its output axes, preserve its other batch axes as output axes. Fails if the argument has input axes.
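As a rough intuition for the permutation and reduction examples, here are two-axis analogues written in plain OCaml over `float array array`. This is a sketch, not OCANNL code: `transpose`, `sum_rows`, and `sum_all` are hypothetical names, the input is assumed non-empty and rectangular, and addition is used for reduction as with the `++` operator.

```ocaml
(* "ij => ji": permute the two axes. *)
let transpose m =
  let rows = Array.length m and cols = Array.length m.(0) in
  Array.init cols (fun j -> Array.init rows (fun i -> m.(i).(j)))

(* "ij => i": axis j is missing from the result, so it is summed out. *)
let sum_rows m = Array.map (Array.fold_left ( +. ) 0.0) m

(* "... => 0": every axis is reduced into a single number. *)
let sum_all m = Array.fold_left ( +. ) 0.0 (sum_rows m)
```

An axis variable that appears on the left of `=>` but not on the right is reduced; one that appears on both sides is only repositioned.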
The affine axis syntax enables convolution and pooling operations directly in einsum notation. The semantics:
- Input index formula: `input_index = stride * output_position + dilation * kernel_position`
- Padded convolution (`=+`): Input and output dimensions are equal (padding compensates for kernel extent)
- Valid convolution (`<+`): No padding. The dimension relationship is: `input_size = stride * (output_size - 1) + effective_kernel_span` where `effective_kernel_span = 1 + (kernel_size - 1) * dilation`.
Important constraint for valid convolution: The formula must hold exactly, meaning (input_size - effective_kernel_span) must be divisible by stride. Otherwise, shape inference will fail with "incompatible stride" error.
General rule (use_padding = false case): (input_size - effective_kernel_span) mod stride = 0
For example, with stride=2, kernel_size=2, dilation=1:
- `effective_kernel_span = 1 + (2-1) * 1 = 2`
- A 4x4 input gives output_size: `4 = 2 * (output - 1) + 2` → `output = 2` ✓
- A 5x5 input would fail: `(5 - 2) mod 2 = 1 ≠ 0` → shape inference error
With stride=2, kernel_size=3, dilation=1:
- `effective_kernel_span = 1 + (3-1) * 1 = 3`
- A 9x9 input works: `(9 - 3) mod 2 = 0` → `output = 4` ✓
- A 10x10 input fails: `(10 - 3) mod 2 = 1 ≠ 0` → shape inference error
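The arithmetic in these examples can be checked with a few lines of plain OCaml. The function below is a sketch that only restates the valid-convolution formula; `valid_output_size` is a hypothetical name, and the "kernel span exceeds input" message is an assumption added for completeness rather than an OCANNL error message.

```ocaml
(* Output size of a valid ("<+") convolution axis, solving
   input_size = stride * (output_size - 1) + effective_kernel_span
   for output_size. Returns Error when the formula cannot hold exactly. *)
let valid_output_size ~input_size ~stride ~kernel_size ~dilation =
  let effective_kernel_span = 1 + ((kernel_size - 1) * dilation) in
  let rest = input_size - effective_kernel_span in
  if rest < 0 then Error "kernel span exceeds input"
  else if rest mod stride <> 0 then Error "incompatible stride"
  else Ok ((rest / stride) + 1)
```

Applied to the cases above, the 4x4 and 9x9 inputs give `Ok 2` and `Ok 4`, while the 5x5 and 10x10 inputs give `Error "incompatible stride"`.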
Examples:
- Max pooling 2x2 with stride 2: `input @^+ "...|2*oh<+wh, 2*ow<+ww, ..c..; wh, ww => ...|oh, ow, ..c.." window`
  - Uses `@^+` (max-reduce) to take maximum over the kernel window
  - `2*oh<+wh` means: for each output position `oh`, access input at `2*oh + kernel_offset`
  - Valid convolution (`<+`) so no padding; output is half the input size
- 2D convolution with stride 1: `input +* "...|oh<+wh, ow<+ww, ..ic..; wh, ww, ic => ...|oh, ow, ..oc.." kernel`
  - Sum-reduces over kernel height, kernel width, and input channels
  - Output channels come from the output shape (typically inferred for the kernel)
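The index formula behind the max pooling example can be traced in one dimension with plain OCaml. This is only a sketch of the semantics: `max_pool_1d` is a hypothetical name, arrays stand in for tensors, and the input length is assumed to satisfy the valid-convolution divisibility constraint.

```ocaml
(* 1D max pooling with window [kernel] and step [stride], no padding:
   output position oh reads input at stride * oh + wh for wh = 0 .. kernel - 1,
   and the window is reduced with max (neutral element: negative infinity). *)
let max_pool_1d ~stride ~kernel input =
  let n = Array.length input in
  let out_size = ((n - kernel) / stride) + 1 in
  Array.init out_size (fun oh ->
      let best = ref neg_infinity in
      for wh = 0 to kernel - 1 do
        best := Float.max !best input.((stride * oh) + wh)
      done;
      !best)
```

With `~stride:2 ~kernel:2`, a length-4 input yields a length-2 output, matching the "output is half the input size" remark.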
The ^ operator in einsum specifications creates a concatenated axis from multiple components. This enables:
- Tensor concatenation: Combine tensors along an axis
- Block tensor construction: Build structured tensors from components
- Axis slicing: Extract or assign to parts of an axis
The syntax a^b (or a^b^c etc.) creates a single axis of iteration that first iterates over component a, then over component b, etc. The components are axis labels (identifiers in multi-char mode, single characters in single-char mode).
Examples of concatenation patterns (using vector notation for simplicity):
| Pattern | Description |
|---|---|
| `a; b => a^b` | Concatenate vectors: result contains all of `a` then all of `b` |
| `a^b => a` | Extract prefix: take the first part of a vector |
| `a^b => b` | Extract suffix: take the last part of a vector |
| `a => a^b` | Replace prefix: assign to first part, leaving suffix unchanged |
| `b => a^b` | Replace suffix: assign to last part, leaving prefix unchanged |
| `a^b^c => b` | Extract middle: requires knowing sizes of `a` and `c` |
| `b => a^b^c` | Replace middle: requires knowing sizes of `a` and `c` |
Shape inference behavior: When the argument and result shapes are both known, prefix and suffix operations (a^b => a, a^b => b, a => a^b, b => a^b) don't need additional dimension information. Middle operations (a^b^c => b, b => a^b^c) require providing dimension constraints for the unmatched components.
Integer constants in concatenation: When used with ^, an integer specifies the size of that axis component rather than indexing into a fixed dimension. For example:
- `3^a => a`: Skip 3 elements at the beginning of the input
- `a => a^3`: Assign to all but the last 3 elements of the result
- `3^a^5 => a`: Extract middle portion, skipping 3 at start and 5 at end
These integer-sized components become fresh internal symbols during projection derivation, causing those components to be skipped.
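To make the iteration order of a concatenated axis concrete, here is a minimal plain-OCaml sketch that models several of the patterns above on vectors represented as `float array`. It is not OCANNL code: the helper names (`concat`, `prefix`, `suffix`, `middle`, `replace_prefix`) are hypothetical, and the component sizes that OCANNL would obtain from shape inference are passed explicitly.

```ocaml
(* Plain-OCaml model of concatenated-axis patterns on vectors (float arrays).
   These helpers are hypothetical illustrations, not OCANNL functions. *)

(* "a; b => a^b": iterate over all of [a], then all of [b]. *)
let concat a b = Array.append a b

(* "a^b => a" with the size of [a] known: extract the prefix. *)
let prefix ~size_a ab = Array.sub ab 0 size_a

(* "a^b => b": extract the suffix that follows the first [size_a] cells. *)
let suffix ~size_a ab = Array.sub ab size_a (Array.length ab - size_a)

(* "skip_start^a^skip_end => a", e.g. "3^a^5 => a": extract the middle. *)
let middle ~skip_start ~skip_end v =
  Array.sub v skip_start (Array.length v - skip_start - skip_end)

(* "a => a^b": overwrite the prefix of [target], leaving the suffix unchanged. *)
let replace_prefix ~target a = Array.blit a 0 target 0 (Array.length a)

let v = Array.init 10 float_of_int
let mid = middle ~skip_start:3 ~skip_end:5 v
let target = [| 9.; 9.; 9.; 9. |]
let () = replace_prefix ~target [| 1.; 2. |]
```

In this model `middle` needs both skip sizes, which mirrors the remark that middle operations require dimension information for the unmatched components.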
N-ary einsum for multiple tensors: The einsum parser supports any number of RHS tensors separated by semicolons. This enables operations like:
```ocaml
(* Concatenate 3 output-axes-only tensors along the first axis *)
(a, b, c) ++^ "x, ...; y, ...; z, ... => x^y^z, ..."
```

Syntax for %op vs %cd: In the %op extension, use the ++^ operator for concatenation
and block specs. For %cd, single-argument concatenation (slicing, partial updates) works with the
existing ~logic:"..." syntax:
```ocaml
(* Assign to prefix of target, leaving suffix unchanged *)
let%cd update_prefix ~target ~source =
  target =: source ~logic:"a => a^b"
```

Multi-argument syntax for %cd (needed for tensor concatenation with multiple sources) is still being designed. The natural choice [rhs1; rhs2] conflicts with the planned block tensor syntax.
The tensor literal syntax will be generalized to support block tensor construction, where tensor literals become a special case with scalar components. The syntax extensions recursively compose argument tensors by introducing and concatenating along a new leading axis:
| Syntax | Axis kind | Example |
|---|---|---|
| `( , )` | Input | `(ta, tb)` concatenates along a new input axis |
| `[ ; ]` | Output | `[ta; tb]` concatenates along a new output axis |
| `[\| ; \|]` | Batch | `[\|ta; tb\|]` concatenates along a new batch axis |
For example, [ta; tb] translates to:
- Introduce a new output axis for each component: `...|...->... => ...|...->0, ...` for `ta` and `tb`
- Concatenate along that axis: `...|...->a,... ; ...|...->b,... => ...|...->a^b,...`
This allows constructing block matrices, block tensors, and other structured tensors from smaller components.
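The two-step translation of `[ta; tb]` can be mimicked in plain OCaml for the special case of two vectors. This is only a sketch of the idea, with hypothetical helper names and a matrix modeled as an array of rows; it says nothing about how OCANNL implements the planned syntax.

```ocaml
(* Plain-OCaml model of [ta; tb] for two vectors, in two steps.
   Hypothetical helpers; a matrix is modeled as an array of rows. *)

(* Step 1: give a vector a new leading axis of size 1 (its only index is 0). *)
let add_leading_axis (t : float array) : float array array = [| t |]

(* Step 2: concatenate along the leading axis, as in "a^b". *)
let concat_leading a b = Array.append a b

let ta = [| 1.; 2. |]
let tb = [| 3.; 4. |]
let block = concat_leading (add_leading_axis ta) (add_leading_axis tb)
```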
The syntaxes +*, ++, and ++^ accept an optional list-of-strings argument after the
specification string. When passed, the strings should be some of the identifiers used in the
specification. Both dimension variable and row variable labels are supported. This introduces
bindings for Indexing.variable_ref objects at the same point as the inline parameter definition
bindings, and passes these objects with the ~capture_dims argument to einsum, einsum1, or
concat, respectively. The bound objects can later be used with Operation.embed_dim or its alias
Operation.TDSL.O.dim to embed the solved dimension of the corresponding variable (as a number)
into a tensor expression. For a row variable, the number will be the product of the dimensions it
resolved into.
The %cd syntax uses record-style notation to point to:
- the value tensor node of a tensor `<tensor>.value`,
- the gradient tensor node of a tensor `<tensor>.grad`,
- the merge buffer of a tensor node `<tensor-node>.merge`; `<tensor>.merge` is a shorthand for `<tensor>.value.merge`,
- the forward code of a tensor `<tensor>.forward`,
- the backprop code of a tensor `<tensor>.backprop`,
- the zeroing gradients code of a tensor `<tensor>.zero_grads`.
The accessor .value can (almost?) always be dropped: by default, tensors in the %cd syntax refer to their value nodes. The forward and backprop code accesses manage roots (via the Tensor.consume_forward_code and Tensor.consume_backprop_code functions).
For example, in a data-parallel computation, gradients of the same param p can be merged across devices using the code p.grad =+ p.grad.merge, combined with an explicit device-to-device transfer.
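As a rough mental model of that merge step, the plain-OCaml sketch below treats a gradient node on each device as a record with its own `grad` array and a `merge` buffer. The record type and both functions are hypothetical stand-ins, not OCANNL's API: `transfer` plays the role of the explicit device-to-device transfer, and `accumulate_merge` plays the role of `p.grad =+ p.grad.merge`.

```ocaml
(* Plain-OCaml model of merging gradients across two devices.
   The record and function names are hypothetical, not OCANNL's API. *)
type node = { grad : float array; merge : float array }

(* Models the explicit device-to-device transfer into the merge buffer. *)
let transfer ~src ~dst = Array.blit src.grad 0 dst.merge 0 (Array.length src.grad)

(* Models [p.grad =+ p.grad.merge]: accumulate the merge buffer into grad. *)
let accumulate_merge p = Array.iteri (fun i m -> p.grad.(i) <- p.grad.(i) +. m) p.merge

let dev0 = { grad = [| 1.; 2. |]; merge = [| 0.; 0. |] }
let dev1 = { grad = [| 10.; 20. |]; merge = [| 0.; 0. |] }

let () =
  transfer ~src:dev1 ~dst:dev0;
  accumulate_merge dev0
```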
The %cd syntax uses the prefix operator (~~) in a semicolon sequence to introduce block comments:
```ocaml
type Assignments.t =
  ...
  | Block_comment of string * t
  ...
```

Schematic example: `~~("space" "separated" "comment" "tensor p debug_name:" p; <scope of the comment>)`. The content of the comment uses application syntax, must be composed of strings, `<tensor>`, `<tensor>.value` (equivalent to `<tensor>`), `<tensor>.grad` components, where `<tensor>` is any tensor expression or tensor identifier.
This syntax used to be very important, because comments in assignments are used to derive file names for generated code. Now, the %cd syntax automatically introduces block comments for code at let-binding points, using the identifier. Currently the comment does not yet incorporate any tensor node labels -- and for that reason we are not yet adding comments around function bodies if a function is annotated with %cd. Moreover, we only automatically add comments for code, not for tensors -- so the ~~ syntax is still helpful when the comment needs to be more precise for debugging or naming purposes, or when %cd is not used with a let binding, or when we want to pass a forward code directly instead of let-binding it. If an explicit comment is provided at the let-binding level, the automatic one is omitted.
When an extension point is applied to a let-binding, e.g. let%op mlp_layer ~label ~hid_dim () x = relu ({ w } * x + { b; o = [ hid_dim ] }), it uses the name of the binding (mlp_layer in the example) for the label of the primary tensor created by the extension, if any. This is why the resulting layer tensor in the example has its label starting with "mlp_layer". If the extension is over a semicolon-separated sequence of expressions, the primary tensor can only be in the last component of the sequence; other syntax constructs are handled analogously.
The example let%op mlp_layer ~label ~hid_dim () x = relu ({ w } * x + { b; o = [ hid_dim ] }) also illustrates providing an additional string list to populate the label of the tensor: label must be of type string list.
The resulting (primary) tensor's label will also have incorporated the label of the input argument, if any. In our example, the resulting mlp_layer tensor will also include the label of the actually applied x. If the function has a unit parameter (), like mlp_layer above, only parameters to the right of () are considered for label extraction.
When there is the unit parameter, and a ~label parameter (specifically a parameter with label label), this label is also incorporated.
In the %op syntax, inline declarations use record syntax with additional fields to configure the tensor:
- Basic declaration with default initialization: `{ tensor_name }` uses OCaml's punning syntax and defaults to uniform random initialization
- Declaration with value initialization: `{ tensor_name = value }` where value can be:
  - A scalar: `{ x = 5.0 }` or `{ y = 42 }`
  - A list/array: `{ weights = [1.0; 2.0; 3.0] }`
  - An initialization function: `{ z = uniform () }`
- Declaration with dimensions: Additional fields specify tensor dimensions:
  - `output_dims` or shorthand `o`: `{ b; output_dims = [ hid_dim ] }` or `{ b; o = [ hid_dim ] }`
  - `input_dims` or shorthand `i`: `{ w; i = [ 3 ]; o = [ 4 ] }`
  - `batch_dims` or shorthand `b`: for batch dimensions (rarely used in `%op`)
A very simple example from micrograd_demo: Micrograd README basic example:
```ocaml
let%op c = { a = [ -4 ] } + { b = [ 2 ] } in
...
```

How does it relate to `let%op c = { a = -4 } + { b = 2 } in ...`? Without brackets, the number is used to initialize all cells of the tensor value, and shape inference decides the shape of the tensor. With brackets, the bracketing specifies both all the cells and the exact shape of the tensor.
If you recall, inline declared param tensors get lifted out of functions to be defined at the point of a unit () parameter. Our example let%op mlp_layer ~label ~hid_dim () x = relu ({ w } * x + { b; o = [ hid_dim ] }) translates as:
```ocaml
let mlp_layer ~label ~hid_dim () =
  let w = TDSL.param ~more_label:label "w" ()
  and b = TDSL.param ~more_label:label ~output_dims:[ hid_dim ] "b" () in
  fun x -> TDSL.O.(relu (w * x + b))
```

For this to work properly, when employing such network blocks, their params also need to be introduced at the right moment. At one point, we tried to do this automatically by the %op syntax, but that was confusing to use. So you need to ensure scoping manually. Consider:
```ocaml
(* FIXME: this is wrong! Doesn't bind the parameters at the right place. *)
let%op three_layer_perceptron ~label ~dim1 ~dim2 ~dim3 () x =
  mlp_layer ~label:[ "L3" ] ~hid_dim:dim3 ()
    (mlp_layer ~label:[ "L2" ] ~hid_dim:dim2 ()
       (mlp_layer ~label:[ "L1" ] ~hid_dim:dim1 () x))
```

This example would work if we used direct inline definitions, but it does not work when the definitions are indirectly in the functions called. We need to write instead:
```ocaml
let three_layer_perceptron ~label ~dim1 ~dim2 ~dim3 () =
  let layer3 = mlp_layer ~label:[ "L3" ] ~hid_dim:dim3 ()
  and layer2 = mlp_layer ~label:[ "L2" ] ~hid_dim:dim2 ()
  and layer1 = mlp_layer ~label:[ "L1" ] ~hid_dim:dim1 () in
  fun x -> layer3 (layer2 (layer1 x))
```

The manual approach naturally extends to programmatic network architectures:
```ocaml
let mlp ~label ~hid_dims () =
  let layers =
    List.mapi hid_dims ~f:(fun i hid_dim ->
        mlp_layer ~label:[ "L" ^ Int.to_string i ] ~hid_dim ())
  in
  fun x -> List.fold layers ~init:x ~f:(fun x layer -> layer x)
```

This syntax extension creates a module DSL_modules with the same submodules as Operation.DSL_modules. It removes the boilerplate associated with introducing new operators into the modules TDSL, NTDSL, PDSL and their O submodules. The payload (i.e. content) of %%extend_dsls must be non-recursive let-bindings. They are parsed using a slight variant of the %op syntax, and are inserted into the DSL modules. The identifiers of the root operator functions of the definitions, if unqualified, are prefixed with the appropriate module, similarly to the behavior of inline definitions. Another unique feature of %%extend_dsls parsing is that inline tensor definitions, like in %cd, do not introduce gradients for the tensors, but, like %op, they do introduce initialization for the inline-defined tensors.
The DSL modules expose the value grad_spec that can be useful for defining operators via a "scheme" function. See the example using the box_muller helper at the beginning of lib/nn_blocks.ml. The definitions there use the %oc escape extension to avoid the prefixing mentioned above.
OCANNL has a built-in numerical binary operation to-power-of: Ops.ToPowOf. As part of assignments, the corresponding operator is **. Here is the full definition of the to-power-of tensor operation from Operation:
```ocaml
let rec pointpow ?(label : string list = []) ~grad_spec p t1 : Tensor.t =
  let module NTDSL = struct
    include Initial_NTDSL
    module O = struct
      include NDO_without_pow
      let ( **. ) ?label base exp = pointpow ?label ~grad_spec:Tensor.Prohibit_grad exp base
    end
  end in
  let p_t = NTDSL.number p in
  let%cd op_asn ~t ~t1 ~t2 ~projections = v =: v1 ** v2 ~projections in
  let%cd grad_asn =
    if Tensor.is_prohibit_grad grad_spec then fun ~v:_ ~g:_ ~t1:_ ~t2:_ ~projections:_ -> Asgns.Noop
    else if Float.equal p 2.0 then fun ~v:_ ~g ~t1 ~t2:_ ~projections -> g1 =+ p_t *. t1 * g
    else if Float.equal p 1.0 then fun ~v:_ ~g ~t1 ~t2:_ ~projections -> g1 =+ g
    else fun ~v:_ ~g ~t1 ~t2:_ ~projections -> g1 =+ p_t *. (t1 **. (p -. 1.)) * g
  in
  Tensor.binop ~label:("**." :: label) ~compose_op:Pointwise_bin ~op_asn ~grad_asn ~grad_spec t1 p_t
```

On the Tensor level, this is implemented as a binary tensor operation, but it is exposed as a unary tensor operation! To avoid the complexities of propagating gradient into the exponent, Operation.pointpow is implemented as a function of only one tensor, the exponent is a number. We hard-code the pointwise-power-of operator NTDSL.O.( **. ), resp. TDSL.O.( **. ), in the %cd and %op syntaxes, to pass the numeric value to pointpow (the second argument of **.) without converting it to a tensor first.
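The three non-trivial branches of `grad_asn` follow the usual power rule for a constant exponent, d/dx x^p = p * x^(p-1), with shortcuts for p = 2 and p = 1. The scalar plain-OCaml sketch below restates that case analysis and compares it with a finite difference. The helpers `pow_grad` and `numeric_grad` are hypothetical and operate on floats rather than tensors.

```ocaml
(* Plain-OCaml model of the gradient contribution for y = x ** p with a
   constant exponent p, mirroring the case analysis in [grad_asn].
   [pow_grad] is a hypothetical scalar helper, not an OCANNL function. *)
let pow_grad ~p ~x ~g =
  if Float.equal p 2.0 then p *. x *. g
  else if Float.equal p 1.0 then g
  else p *. (x ** (p -. 1.)) *. g

(* Central finite difference of x ** p, for comparison. *)
let numeric_grad ~p ~x =
  let h = 1e-6 in
  (((x +. h) ** p) -. ((x -. h) ** p)) /. (2. *. h)

(* With an incoming gradient of 1, the two should agree closely. *)
let analytic = pow_grad ~p:3. ~x:2. ~g:1.
let numeric = numeric_grad ~p:3. ~x:2.
```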
The syntax %cd translator needs to accomplish more than a context-free conversion of a concise notation to an Assignments.comp data-type. In particular:
- It needs to keep track of whether `~projections` is in scope, and it needs to collect the information about an assignment to properly transform the projections from the scope into the projections valid for the particular assignment.
- Whenever the parsed notation uses tensors whose value nodes have not been computed yet, the translator needs to include the "forward" code of the tensors among the generated assignments. Typically this is required for embedded tensor expressions, which create new tensors. The translator puts the forward code in sequence just prior to the assignment that made use of the created tensor. The translator includes the forward code of tensors that are "forward roots" at the time the assignments are constructed (using `Tensor.is_fwd_root`).
- For inline declarations of tensors, the translator needs to pick the right other tensor, if any, to enrich the label information of the created tensor. Mechanisms:
  - Prefer tensors from identifiers (or field dereferences), since labels of tensor expressions (creating new tensors) will typically be overly verbose.
  - Filter out escaping variables (identifiers coming from nested function parameters).
  - Filter out embedded tensor expressions. In principle we could use them -- we already introduce local bindings to avoid recomputing the expressions -- but this would need pulling out these bindings together with the inline definition and does not seem worth the benefit.
  - When one inline declaration uses another inline declaration on its right-hand-side, recall the other declaration's label-enriching-tensor and use it directly.
- The argument slots in `Assignments.Accum_binop` and `Assignments.Accum_unop` can be either regular tensor nodes, or merge buffers of tensor nodes. The translator needs to determine that.
- When a tensor expression is used to create a new tensor, the translator lifts the expression into a let-binding, to be able to refer to the (same) tensor more than once. The created tensor is referred to at least twice: at its use site, and to include its forward code among the assignments.
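The interplay between forward roots and sequencing can be pictured with a toy model. The plain-OCaml sketch below is an assumption-laden illustration, not the translator's actual code: the `code` and `tensor` types, `consume_forward`, and `assign_using` are all hypothetical. It shows a tensor's forward code being placed just before the first assignment that uses the tensor, and not being repeated on later uses.

```ocaml
(* Toy model of sequencing forward code before the assignment that uses a
   freshly created tensor. All types and names here are hypothetical. *)
type code = Noop | Asgn of string | Seq of code * code

type tensor = { name : string; forward : code; mutable is_fwd_root : bool }

(* Take the forward code of a root tensor once; afterwards it is no longer a root. *)
let consume_forward t =
  if t.is_fwd_root then (
    t.is_fwd_root <- false;
    t.forward)
  else Noop

(* Translate "lhs =: <use of t>": forward code of [t] (if still a root) goes first. *)
let assign_using ~lhs t =
  let fwd = consume_forward t in
  Seq (fwd, Asgn (lhs ^ " =: " ^ t.name))

(* The same tensor value is referred to twice, as with a lifted let-binding. *)
let t = { name = "t"; forward = Asgn "t =: a + b"; is_fwd_root = true }
let first = assign_using ~lhs:"x" t
let second = assign_using ~lhs:"y" t
```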
In fact, the syntax %cd produces Assignments.comp values:
```ocaml
type comp = {
  asgns : t;
  embedded_nodes : Set.M(Tnode).t;
}
```

The tensor nodes that are in asgns but not in embedded_nodes, and are on-device, must already be present in contexts with which the computation is linked. Such non-embedded nodes can be seen as inputs to the computation -- except that for backprop code of a tensor, they are actually the outputs! Embedded nodes are closely related to rootness -- when a node has not been used in the code of another tensor, it is a root (a forward root for value nodes and a backprop root for grad nodes). embedded_nodes were roots the first time they were used in asgns. Parameters, as created by Tensor.param, are not embedded in the code that uses them and thus will not be in embedded_nodes of the forward and backprop code over the parameters; however, they will constitute the embedded_nodes of the Tensor.init_params code.
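The relationship between the nodes mentioned by the code and the embedded nodes amounts to a set difference. The plain-OCaml sketch below models it with strings standing in for tensor nodes. The record fields and `non_embedded` are hypothetical, `asgns` is reduced to the set of nodes it mentions, and the on-device condition is ignored.

```ocaml
(* Toy model of [Assignments.comp]: nodes are plain strings here, and
   [asgns] is reduced to the set of nodes it mentions. Hypothetical names. *)
module S = Set.Make (String)

type comp = { asgns_nodes : S.t; embedded_nodes : S.t }

(* Nodes mentioned by the code but not embedded in it must come from outside. *)
let non_embedded c = S.diff c.asgns_nodes c.embedded_nodes

(* Forward code of "w*x": the param "w" and the input "x" are not embedded. *)
let forward =
  { asgns_nodes = S.of_list [ "w"; "x"; "w*x" ]; embedded_nodes = S.of_list [ "w*x" ] }

let inputs = S.elements (non_embedded forward)
```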