You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Add syntactic sugar for constructing block tensors from component tensors
using OCaml list/tuple/array notation inside %op blocks. Lists stack along
a new leading output axis, tuples along input axis (top-level only), and
arrays along batch axis.
Implementation: first-leaf disambiguation (numeric → ndarray constant,
non-numeric → block tensor), two-step unsqueeze via einsum1 + concat.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Multi-argument syntax for `%cd` (needed for tensor concatenation with multiple sources) is still being designed. The natural choice `[rhs1; rhs2]` conflicts with the planned block tensor syntax.
551
551
552
-
### Block tensor syntax (upcoming)
552
+
### Block tensor syntax
553
553
554
-
The tensor literal syntax will be generalized to support block tensor construction, where tensor literals become a special case with scalar components. The syntax extensions recursively compose argument tensors by introducing and concatenating along a new leading axis:
554
+
The tensor literal syntax is generalized to support block tensor construction: when a list, tuple, or
555
+
array literal inside `%op` contains tensor expressions (rather than numeric literals), it desugars
556
+
into `einsum1` (unsqueeze) + `concat` calls that stack the components along a new leading axis.
555
557
556
-
| Syntax | Axis kind | Example |
557
-
|--------|-----------|---------|
558
-
|`( , )`| Input |`(ta, tb)` concatenates along a new input axis |
559
-
|`[ ; ]`| Output |`[ta; tb]` concatenates along a new output axis |
560
-
| `[| ; |]` | Batch | `[|ta; tb|]` concatenates along a new batch axis |
558
+
**Axis mapping by syntax form:**
561
559
562
-
For example, `[ta; tb]`translates to:
563
-
1. Introduce a new output axis for each component: `...|...->... => ...|...->0, ...` for `ta` and `tb`
564
-
2. Concatenate along that axis: `...|...->a,... ; ...|...->b,... => ...|...->a^b,...`
560
+
-`[ta; tb; ...]`-- stacks along a new leading **output** axis
561
+
-`(ta, tb, ...)` -- stacks along a new leading **input** axis (top-level `%op` expressions only)
562
+
-`[|ta; tb; ...|]` -- stacks along a new leading **batch** axis
565
563
566
-
This allows constructing block matrices, block tensors, and other structured tensors from smaller components.
564
+
**Disambiguation rule:** if the first leaf of the nested literal is a numeric constant (int or
565
+
float), the expression is treated as an ndarray constant (existing behavior). Otherwise it is a block
566
+
tensor. To include scalar constants alongside tensors, wrap them with `!.` or `!..`
567
+
(e.g., `[!.1.0; ta]`). Computed-number expressions like `Float.sin 1.0` at the first leaf position
568
+
will trigger block tensor interpretation.
569
+
570
+
**Examples:**
571
+
572
+
```ocaml
573
+
(* Stack two vectors along a new leading output axis *)
574
+
let%op stacked = [v1; v2; v3]
575
+
(* Shape: if each vi has output [d], result has output [3, d] *)
576
+
577
+
(* 2x2 block matrix from scalars -- nesting currently requires let bindings *)
578
+
let%op row1 = [a; b] in
579
+
let%op row2 = [c; d] in
580
+
let%op mat = (row1, row2) ++^ "a; b => a^b"
581
+
582
+
(* Batch two samples *)
583
+
let%op batched = [|sample1; sample2|]
584
+
585
+
(* Stack along input axis (top-level only) *)
586
+
let%op input_block = (x, y)
587
+
```
588
+
589
+
**Notes:**
590
+
591
+
- All components must have the same trailing shape (stacking requires shape compatibility).
592
+
- Single-element block tensors like `[ta]` act as unsqueeze (add a size-1 leading axis).
593
+
- Tuple syntax `(ta, tb)` only works at the top level of a `%op` expression. Tuples inside function
594
+
arguments (e.g., `f (ta, tb)`) are preserved as regular OCaml tuples.
595
+
- Direct nesting like `[[ta; tb]; [tc; td]]` is currently limited by shape inference; use
596
+
intermediate let bindings or explicit `++^` for nested block construction.
567
597
568
598
### Capturing the dimensions of selected axes for further computation or to add shape constraints
0 commit comments