Commit 142e941

lukstafi and claude committed
feat: block tensor literal syntax [ta; tb], (ta, tb), [|ta; tb|] in %op
Add syntactic sugar for constructing block tensors from component tensors using OCaml list/tuple/array notation inside %op blocks. Lists stack along a new leading output axis, tuples along input axis (top-level only), and arrays along batch axis.

Implementation: first-leaf disambiguation (numeric → ndarray constant, non-numeric → block tensor), two-step unsqueeze via einsum1 + concat.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 62d32a2 commit 142e941

6 files changed

Lines changed: 430 additions & 18 deletions

docs/syntax_extensions.md

Lines changed: 41 additions & 11 deletions
@@ -549,21 +549,51 @@ let%cd update_prefix ~target ~source =
 
 Multi-argument syntax for `%cd` (needed for tensor concatenation with multiple sources) is still being designed. The natural choice `[rhs1; rhs2]` conflicts with the planned block tensor syntax.
 
-### Block tensor syntax (upcoming)
+### Block tensor syntax
 
-The tensor literal syntax will be generalized to support block tensor construction, where tensor literals become a special case with scalar components. The syntax extensions recursively compose argument tensors by introducing and concatenating along a new leading axis:
+The tensor literal syntax is generalized to support block tensor construction: when a list, tuple, or
+array literal inside `%op` contains tensor expressions (rather than numeric literals), it desugars
+into `einsum1` (unsqueeze) + `concat` calls that stack the components along a new leading axis.
 
-| Syntax | Axis kind | Example |
-|--------|-----------|---------|
-| `( , )` | Input | `(ta, tb)` concatenates along a new input axis |
-| `[ ; ]` | Output | `[ta; tb]` concatenates along a new output axis |
-| `[| ; |]` | Batch | `[|ta; tb|]` concatenates along a new batch axis |
+**Axis mapping by syntax form:**
 
-For example, `[ta; tb]` translates to:
-1. Introduce a new output axis for each component: `...|...->... => ...|...->0, ...` for `ta` and `tb`
-2. Concatenate along that axis: `...|...->a,... ; ...|...->b,... => ...|...->a^b,...`
+- `[ta; tb; ...]` -- stacks along a new leading **output** axis
+- `(ta, tb, ...)` -- stacks along a new leading **input** axis (top-level `%op` expressions only)
+- `[|ta; tb; ...|]` -- stacks along a new leading **batch** axis
 
-This allows constructing block matrices, block tensors, and other structured tensors from smaller components.
+**Disambiguation rule:** if the first leaf of the nested literal is a numeric constant (int or
+float), the expression is treated as an ndarray constant (existing behavior). Otherwise it is a block
+tensor. To include scalar constants alongside tensors, wrap them with `!.` or `!..`
+(e.g., `[!.1.0; ta]`). Computed-number expressions like `Float.sin 1.0` at the first leaf position
+will trigger block tensor interpretation.
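The first-leaf rule can be tried out in isolation. The sketch below is not the ppx code: it uses a hypothetical toy AST (`Num`, `Other`, `List`, `Tuple`, `Array`) in place of the OCaml `Parsetree` nodes, and `classify`, `Ndarray_constant` and `Block_tensor` are names invented for illustration.

```ocaml
(* Hypothetical toy AST standing in for the Parsetree nodes the ppx inspects. *)
type expr =
  | Num of float (* an int or float literal *)
  | Other of string (* anything else: a tensor identifier, a wrapped scalar, a call *)
  | List of expr list (* [a; b] *)
  | Tuple of expr list (* (a, b) *)
  | Array of expr list (* [|a; b|] *)

type kind = Ndarray_constant | Block_tensor

(* Follow the first element of each nested literal down to a leaf. *)
let rec first_leaf_is_numeric = function
  | Num _ -> true
  | Other _ -> false
  | List [] | Array [] -> true (* empty literals stay on the ndarray path *)
  | Tuple [] -> false
  | List (e :: _) | Tuple (e :: _) | Array (e :: _) -> first_leaf_is_numeric e

let classify e = if first_leaf_is_numeric e then Ndarray_constant else Block_tensor

let () =
  (* A nested numeric literal is still an ndarray constant. *)
  let nested = List [ List [ Num 1.0; Num 2.0 ]; List [ Num 3.0; Num 4.0 ] ] in
  assert (classify nested = Ndarray_constant);
  (* A bare literal in first position selects the ndarray path, even next to a tensor. *)
  assert (classify (List [ Num 1.0; Other "ta" ]) = Ndarray_constant);
  (* Wrapping the scalar makes the first leaf non-numeric. *)
  assert (classify (List [ Other "!.1.0"; Other "ta" ]) = Block_tensor);
  (* A computed number is not a literal, so it also selects the block tensor path. *)
  assert (classify (List [ Other "Float.sin 1.0"; Num 2.0 ]) = Block_tensor)
```

Only the first leaf is inspected, which is why the position of a bare numeric literal matters.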
+
+**Examples:**
+
+```ocaml
+(* Stack two vectors along a new leading output axis *)
+let%op stacked = [v1; v2; v3]
+(* Shape: if each vi has output [d], result has output [3, d] *)
+
+(* 2x2 block matrix from scalars -- nesting currently requires let bindings *)
+let%op row1 = [a; b] in
+let%op row2 = [c; d] in
+let%op mat = (row1, row2) ++^ "a; b => a^b"
+
+(* Batch two samples *)
+let%op batched = [|sample1; sample2|]
+
+(* Stack along input axis (top-level only) *)
+let%op input_block = (x, y)
+```
+
+**Notes:**
+
+- All components must have the same trailing shape (stacking requires shape compatibility).
+- Single-element block tensors like `[ta]` act as unsqueeze (add a size-1 leading axis).
+- Tuple syntax `(ta, tb)` only works at the top level of a `%op` expression. Tuples inside function
+  arguments (e.g., `f (ta, tb)`) are preserved as regular OCaml tuples.
+- Direct nesting like `[[ta; tb]; [tc; td]]` is currently limited by shape inference; use
+  intermediate let bindings or explicit `++^` for nested block construction.
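A rough model of the shape behaviour in these notes, for one axis kind at a time. It is a sketch under simplifying assumptions: a shape is a plain `int list`, OCANNL's shape inference and broadcasting are ignored, and `stack_shape` is a hypothetical helper rather than part of the library.

```ocaml
(* Shape of a block tensor built from components with the given shapes.
   Every component must have the same shape; the result gains a leading
   axis whose size is the number of components. *)
let stack_shape (components : int list list) : (int list, string) result =
  match components with
  | [] -> Error "block tensor requires at least one component"
  | first :: rest ->
      if List.for_all (fun shape -> shape = first) rest then
        Ok (List.length components :: first)
      else Error "components must share the same shape"

let () =
  (* Three vectors of size 3 stack to 3x3. *)
  assert (stack_shape [ [ 3 ]; [ 3 ]; [ 3 ] ] = Ok [ 3; 3 ]);
  (* A single component acts as unsqueeze. *)
  assert (stack_shape [ [ 3 ] ] = Ok [ 1; 3 ]);
  (* Scalars stack to a vector. *)
  assert (stack_shape [ []; []; [] ] = Ok [ 3 ]);
  (* Mismatched components are rejected in this model. *)
  assert (Result.is_error (stack_shape [ [ 3 ]; [ 2 ] ]))
```

These cases correspond to the shapes `0:3,1:3`, `0:1,1:3` and `0:3` that appear in the expected test output of this commit.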
 
 ### Capturing the dimensions of selected axes for further computation or to add shape constraints
 

tensor/ppx_op.ml

Lines changed: 76 additions & 6 deletions
@@ -70,6 +70,55 @@ let make_vb_nd ~no_grad ~opt_label ~init_nd ~extra_args ~loc name =
   let vb = Ast_helper.Vb.mk ~loc pat v in
   (pat, vb)
 
+let rec is_ndarray_constant_expr expr =
+  match expr.pexp_desc with
+  | Pexp_constant (Pconst_float _) | Pexp_constant (Pconst_integer _) -> true
+  | Pexp_tuple (e :: _) | Pexp_array (e :: _) -> is_ndarray_constant_expr e
+  | Pexp_construct ({ txt = Lident "::"; _ }, _) ->
+      let elems = collect_list [] expr in
+      (match elems with e :: _ -> is_ndarray_constant_expr e | [] -> true)
+  | Pexp_construct ({ txt = Lident "[]"; _ }, _) -> true
+  | Pexp_array [] -> true
+  | _ -> false
+
+let translate_block_tensor ~loc ~loop ~label ~opt_label:_ axis_kind elems =
+  match elems with
+  | [] ->
+      ( no_vbs,
+        Ast_builder.Default.pexp_extension ~loc
+        @@ Location.error_extensionf ~loc
+             "ppx_ocannl %%op: block tensor requires at least one component" )
+  | _ ->
+      let vbss, translated = List.unzip (List.map elems ~f:loop) in
+      let unsqueeze_spec_str =
+        match axis_kind with
+        | `Output -> "...|...->... => ...|...->0,..."
+        | `Input -> "...|...->... => ...|0,...->..."
+        | `Batch -> "...|...->... => 0,...|...->..."
+      in
+      let unsqueeze_spec = substitute_identifiers_in_einsum_spec ~loc unsqueeze_spec_str in
+      let unsqueezed =
+        List.map translated ~f:(fun e -> [%expr einsum1 [%e unsqueeze_spec] [%e e]])
+      in
+      let labels = List.mapi elems ~f:(fun i _ -> "bt" ^ Int.to_string i) in
+      let concat_parts = String.concat ~sep:"^" labels in
+      let concat_spec_str =
+        match axis_kind with
+        | `Output ->
+            String.concat ~sep:"; " (List.map labels ~f:(fun l -> l ^ ",..."))
+            ^ " => " ^ concat_parts ^ ",..."
+        | `Input ->
+            String.concat ~sep:"; " (List.map labels ~f:(fun l -> "...|" ^ l ^ ",...->..."))
+            ^ " => ...|" ^ concat_parts ^ ",...->..."
+        | `Batch ->
+            String.concat ~sep:"; " (List.map labels ~f:(fun l -> l ^ ",...|...->..."))
+            ^ " => " ^ concat_parts ^ ",...|...->..."
+      in
+      let concat_spec = substitute_identifiers_in_einsum_spec ~loc concat_spec_str in
+      let rhses_array = Ast_builder.Default.pexp_array ~loc unsqueezed in
+      ( reduce_vbss vbss,
+        [%expr concat ?label:[%e opt_expr ~loc label] [%e concat_spec] [%e rhses_array]] )
+
 let rec translate ~no_grads_for_inline_defs ~num_configs ~is_toplevel ~opt_label ?label expr =
   let loc = expr.pexp_loc in
   let loop = translate ~no_grads_for_inline_defs ~num_configs ~is_toplevel:false ~opt_label in
@@ -350,9 +399,20 @@ let rec translate ~no_grads_for_inline_defs ~num_configs ~is_toplevel ~opt_label
         Ast_builder.Default.pexp_extension ~loc
         @@ Location.error_extensionf ~loc
              "ppx_ocannl %%op: record field label must be a simple identifier" ))
-  | { pexp_desc = Pexp_array _; _ }
-  | { pexp_desc = Pexp_construct ({ txt = Lident "::"; _ }, _); _ } ->
-      (no_vbs, ndarray_op ?label ~ndarray_fn:[%expr TDSL.ndarray] expr)
+  | { pexp_desc = Pexp_construct ({ txt = Lident "::"; _ }, _); _ } as list_expr ->
+      if is_ndarray_constant_expr list_expr then
+        (no_vbs, ndarray_op ?label ~ndarray_fn:[%expr TDSL.ndarray] list_expr)
+      else
+        let elems = collect_list [] list_expr in
+        translate_block_tensor ~loc ~loop ~label ~opt_label `Output elems
+  | { pexp_desc = Pexp_array _; _ } ->
+      if is_ndarray_constant_expr expr then
+        (no_vbs, ndarray_op ?label ~ndarray_fn:[%expr TDSL.ndarray] expr)
+      else
+        let elems =
+          match expr.pexp_desc with Pexp_array elems -> elems | _ -> assert false
+        in
+        translate_block_tensor ~loc ~loop ~label ~opt_label `Batch elems
   | [%expr !.[%e? expr1]] ->
       (* Hardcoding the patterns for (!.), (!..), and ( **. ) to avoid treating the constants as
          already tensors. *)
@@ -364,6 +424,10 @@ let rec translate ~no_grads_for_inline_defs ~num_configs ~is_toplevel ~opt_label
   | [%expr [%e? expr1] **. [%e? expr2]] ->
       let vbs, e1 = loop expr1 in
       (vbs, [%expr TDSL.O.( **. ) ?label:[%e opt_expr ~loc label] [%e e1] [%e expr2]])
+  | { pexp_desc = Pexp_tuple elems; _ } when is_toplevel && List.length elems >= 2 ->
+      if is_ndarray_constant_expr expr then
+        (no_vbs, ndarray_op ?label ~ndarray_fn:[%expr TDSL.ndarray] expr)
+      else translate_block_tensor ~loc ~loop ~label ~opt_label `Input elems
   | [%expr
       [%e? { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ }] ([%e? expr2], [%e? expr3])]
     when Hashtbl.mem binary_ops op_ident ->
@@ -408,10 +472,16 @@ let rec translate ~no_grads_for_inline_defs ~num_configs ~is_toplevel ~opt_label
             | Some unit_pos when i < unit_pos ->
                 (* Before unit: preserve as OCaml expression *)
                 (no_vbs, (arg_label, arg_expr))
-            | _ ->
+            | _ -> (
                 (* After unit or no unit: transform *)
-                let vbs, e = loop arg_expr in
-                (vbs, (arg_label, e)))
+                match arg_expr.pexp_desc with
+                | Pexp_tuple _ ->
+                    (* Preserve tuple arguments to avoid block tensor misinterpretation.
+                       Matches current behavior: tuples fell through catch-all. *)
+                    (no_vbs, (arg_label, arg_expr))
+                | _ ->
+                    let vbs, e = loop arg_expr in
+                    (vbs, (arg_label, e))))
       in
       let all_vbs = reduce_vbss (vbs_fn :: vbs_args) in
       (all_vbs, Ast_builder.Default.pexp_apply ~loc e_fn processed_args)
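
To see which spec strings `translate_block_tensor` builds, the string construction can be reproduced with the standard library alone. This is a sketch, not the ppx code: `axis_kind` is an ordinary variant instead of polymorphic variants, `concat_spec` takes a component count rather than the element list, it uses `Stdlib` instead of `Base`, and the `substitute_identifiers_in_einsum_spec` step is left out.

```ocaml
type axis_kind = Output | Input | Batch

(* Step 1: give every component a new leading axis of the chosen kind. *)
let unsqueeze_spec = function
  | Output -> "...|...->... => ...|...->0,..."
  | Input -> "...|...->... => ...|0,...->..."
  | Batch -> "...|...->... => 0,...|...->..."

(* Step 2: concatenate the n components along that new axis. *)
let concat_spec kind n =
  let labels = List.init n (fun i -> "bt" ^ string_of_int i) in
  let joined = String.concat "^" labels in
  let per_component f = String.concat "; " (List.map f labels) in
  match kind with
  | Output -> per_component (fun l -> l ^ ",...") ^ " => " ^ joined ^ ",..."
  | Input ->
      per_component (fun l -> "...|" ^ l ^ ",...->...")
      ^ " => ...|" ^ joined ^ ",...->..."
  | Batch ->
      per_component (fun l -> l ^ ",...|...->...")
      ^ " => " ^ joined ^ ",...|...->..."

let () =
  assert (concat_spec Output 2 = "bt0,...; bt1,... => bt0^bt1,...");
  assert (
    concat_spec Input 2
    = "...|bt0,...->...; ...|bt1,...->... => ...|bt0^bt1,...->...");
  assert (
    concat_spec Batch 3
    = "bt0,...|...->...; bt1,...|...->...; bt2,...|...->... => bt0^bt1^bt2,...|...->...")
```

With a single component the concat spec degenerates to `bt0,... => bt0,...`, so the result differs from the input only by the axis added in the unsqueeze step.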

test/operations/dune

Lines changed: 11 additions & 0 deletions
@@ -322,6 +322,17 @@
  (preprocess
   (pps ppx_here ppx_ocannl)))
 
+(test
+ (name test_block_tensor)
+ (package neural_nets_lib)
+ (deps
+  ocannl_config
+  (env_var OCANNL_BACKEND))
+ (modules test_block_tensor)
+ (libraries base ocannl stdio)
+ (preprocess
+  (pps ppx_here ppx_ocannl)))
+
 (test
  (name test_random_histograms)
  (package neural_nets_lib)

test/operations/rope_test.expected

Lines changed: 2 additions & 1 deletion
@@ -1,4 +1,5 @@
-
+Retrieving commandline, environment, or config file variable ocannl_log_level
+Found 0, in the config file
 === Test 1: Deinterleave roundtrip ===
 Original: 1 2 3 4 5 6
 Roundtrip: 1 2 3 4 5 6
Lines changed: 165 additions & 0 deletions
@@ -0,0 +1,165 @@
+Retrieving commandline, environment, or config file variable ocannl_log_level
+Found 0, in the config file
+=== Block Tensor Literal Tests ===
+
+--- Test 1: List output axis [x1; x2] ---
+HERE: test/operations/test_block_tensor.ml:29:21
+[8]: ++^_stacked shape 0:2,1:3 [
+[ 1.00 ; 2.00 ; 3.00 ]
+; [ 4.00 ; 5.00 ; 6.00 ]
+]
+
+--- Test 2: Array batch axis [|x1; x2|] ---
+HERE: test/operations/test_block_tensor.ml:42:21
+[18]: ++^_batched shape 0:2|1:2 [|
+[ 10.00 ; 20.00 ]
+; [ 30.00 ; 40.00 ]
+|]
+
+--- Test 3: Tuple input axis (x1, x2) ---
+HERE: test/operations/test_block_tensor.ml:49:21
+[24]: ++^_input_stack shape 1:2->0:3 [
+1.00 , 4.00
+; 2.00 , 5.00
+; 3.00 , 6.00
+]
+
+--- Test 4: 3-way list [x1; x2; x3] ---
+HERE: test/operations/test_block_tensor.ml:59:21
+[34]: ++^_triple shape 0:3,1:3 [
+[ 1.00 ; 2.00 ; 3.00 ]
+; [ 4.00 ; 5.00 ; 6.00 ]
+; [ 7.00 ; 8.00 ; 9.00 ]
+]
+
+--- Test 5: Scalars in block tensor ---
+HERE: test/operations/test_block_tensor.ml:73:21
+[48]: ++^_scalar_stack shape 0:3 [ 1.00 ; 2.00 ; 3.00 ]
+
+--- Test 6: Single element [x1] ---
+HERE: test/operations/test_block_tensor.ml:80:21
+[52]: ++^_unsqueezed shape 0:1,1:3 [ [ 1.00 ; 2.00 ; 3.00 ] ]
+
+--- Test 7: Gradient flow (2-way) ---
+grad_result (sin of stacked):
+HERE: test/operations/test_block_tensor.ml:99:21
+┌────────────────────────────────────┐
+│[64]: sin_grad_result shape 0:2,1:2 │
+│┌──────┬───────────────────┐ │
+││ │axis 1 │ │
+│├──────┼───────────────────┤ │
+││axis 0│ 8.41e-1 9.09e-1 │ │
+││ │ 1.41e-1 -7.56e-1 │ │
+│└──────┴───────────────────┘ │
+└────────────────────────────────────┘
+
+Gradient of g1 (should be cos of original):
+HERE: test/operations/test_block_tensor.ml:101:21
+┌────────────────────┐
+│[54]: 1,2 shape 0:2 │
+│┌┬────────────┐ │
+│││axis 0 │ │
+│├┼────────────┤ │
+│││ 1.00 2.00 │ │
+│└┴────────────┘ │
+└────────────────────┘
+┌─────────────────────────────┐
+│[54]: 1,2 shape 0:2 grad_1,2│
+│┌┬───────────────────┐ │
+│││axis 0 │ │
+│├┼───────────────────┤ │
+│││ 5.40e-1 -4.16e-1 │ │
+│└┴───────────────────┘ │
+└─────────────────────────────┘
+
+Gradient of g2:
+HERE: test/operations/test_block_tensor.ml:103:21
+┌────────────────────┐
+│[56]: 3,4 shape 0:2 │
+│┌┬────────────┐ │
+│││axis 0 │ │
+│├┼────────────┤ │
+│││ 3.00 4.00 │ │
+│└┴────────────┘ │
+└────────────────────┘
+┌─────────────────────────────┐
+│[56]: 3,4 shape 0:2 grad_3,4│
+│┌┬────────────────────┐ │
+│││axis 0 │ │
+│├┼────────────────────┤ │
+│││ -9.89e-1 -6.53e-1 │ │
+│└┴────────────────────┘ │
+└─────────────────────────────┘
+
+--- Test 8: Gradient flow (3-way) ---
+grad3_result (sin of 3-way stacked):
+HERE: test/operations/test_block_tensor.ml:126:21
+┌─────────────────────────────────────┐
+│[84]: sin_grad3_result shape 0:3,1:2 │
+│┌──────┬──────────────────┐ │
+││ │axis 1 │ │
+│├──────┼──────────────────┤ │
+││axis 0│ 4.79e-1 9.97e-1 │ │
+││ │ 8.41e-1 9.09e-1 │ │
+││ │ 1.41e-1 9.98e-2 │ │
+│└──────┴──────────────────┘ │
+└─────────────────────────────────────┘
+
+Gradient of h1:
+HERE: test/operations/test_block_tensor.ml:128:21
+┌────────────────────────┐
+│[70]: 0.5,1.5 shape 0:2 │
+│┌┬───────────────┐ │
+│││axis 0 │ │
+│├┼───────────────┤ │
+│││ 5.00e-1 1.50 │ │
+│└┴───────────────┘ │
+└────────────────────────┘
+┌─────────────────────────────────────┐
+│[70]: 0.5,1.5 shape 0:2 grad_0.5,1.5│
+│┌┬──────────────────┐ │
+│││axis 0 │ │
+│├┼──────────────────┤ │
+│││ 8.77e-1 7.07e-2 │ │
+│└┴──────────────────┘ │
+└─────────────────────────────────────┘
+
+Gradient of h2:
+HERE: test/operations/test_block_tensor.ml:130:21
+┌────────────────────┐
+│[72]: 1,2 shape 0:2 │
+│┌┬────────────┐ │
+│││axis 0 │ │
+│├┼────────────┤ │
+│││ 1.00 2.00 │ │
+│└┴────────────┘ │
+└────────────────────┘
+┌─────────────────────────────┐
+│[72]: 1,2 shape 0:2 grad_1,2│
+│┌┬───────────────────┐ │
+│││axis 0 │ │
+│├┼───────────────────┤ │
+│││ 5.40e-1 -4.16e-1 │ │
+│└┴───────────────────┘ │
+└─────────────────────────────┘
+
+Gradient of h3:
+HERE: test/operations/test_block_tensor.ml:132:21
+┌──────────────────────┐
+│[74]: 3,0.1 shape 0:2 │
+│┌┬───────────────┐ │
+│││axis 0 │ │
+│├┼───────────────┤ │
+│││ 3.00 1.00e-1 │ │
+│└┴───────────────┘ │
+└──────────────────────┘
+┌─────────────────────────────────┐
+│[74]: 3,0.1 shape 0:2 grad_3,0.1│
+│┌┬───────────────────┐ │
+│││axis 0 │ │
+│├┼───────────────────┤ │
+│││ -9.89e-1 9.95e-1 │ │
+│└┴───────────────────┘ │
+└─────────────────────────────────┘
+
+=== Block Tensor Literal Tests Complete ===
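
The gradient values in Tests 7 and 8 can be cross-checked by hand. The test source is not shown on this page, so the sketch below assumes that the backpropagated scalar is the sum of `sin` over the stacked tensor, which makes the gradient at each component entry `x` equal to `cos x`. The helpers `close` and `check_component` are hypothetical. The tolerance is loose because the printed mantissas appear to be truncated rather than rounded (cos 3 ≈ -0.98999 is shown as `-9.89e-1`).

```ocaml
(* Loose tolerance: the expected output keeps three significant digits. *)
let tol = 2e-3
let close a b = Float.abs (a -. b) < tol

(* For one stacked component: the forward row should be sin of the inputs,
   and the gradient should be cos of the inputs (assumed sum-of-sin loss). *)
let check_component inputs printed_sin printed_grad =
  List.for_all2 (fun x s -> close (sin x) s) inputs printed_sin
  && List.for_all2 (fun x g -> close (cos x) g) inputs printed_grad

let () =
  (* Test 7: g1 = [1; 2], g2 = [3; 4] *)
  assert (check_component [ 1.0; 2.0 ] [ 0.841; 0.909 ] [ 0.540; -0.416 ]);
  assert (check_component [ 3.0; 4.0 ] [ 0.141; -0.756 ] [ -0.989; -0.653 ]);
  (* Test 8: h1 = [0.5; 1.5], h2 = [1; 2], h3 = [3; 0.1] *)
  assert (check_component [ 0.5; 1.5 ] [ 0.479; 0.997 ] [ 0.877; 0.0707 ]);
  assert (check_component [ 1.0; 2.0 ] [ 0.841; 0.909 ] [ 0.540; -0.416 ]);
  assert (check_component [ 3.0; 0.1 ] [ 0.141; 0.0998 ] [ -0.989; 0.995 ])
```

Each component receives only the gradient of its own row of the stacked result, which is consistent with the unsqueeze + concat translation routing gradients back to the matching component.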
