Skip to content

Commit f3ace5b

Browse files
lukstaficlaude
andcommitted
fix: support nested block tensor literals [[ta; tb]; [tc; td]]
Add Concat-Concat element-wise unification case to row.ml's unify_dim, allowing the shape solver to equate structurally-matching Concat dimensions from inner block tensor results when used in outer block tensors. Update test and docs to demonstrate direct nesting (2x2 block matrix). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 142e941 commit f3ace5b

4 files changed

Lines changed: 54 additions & 43 deletions

File tree

docs/syntax_extensions.md

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -574,10 +574,8 @@ will trigger block tensor interpretation.
574574
let%op stacked = [v1; v2; v3]
575575
(* Shape: if each vi has output [d], result has output [3, d] *)
576576
577-
(* 2x2 block matrix from scalars -- nesting currently requires let bindings *)
578-
let%op row1 = [a; b] in
579-
let%op row2 = [c; d] in
580-
let%op mat = (row1, row2) ++^ "a; b => a^b"
577+
(* 2x2 block matrix from vectors via direct nesting *)
578+
let%op mat = [[a; b]; [c; d]]
581579
582580
(* Batch two samples *)
583581
let%op batched = [|sample1; sample2|]
@@ -592,8 +590,8 @@ let%op input_block = (x, y)
592590
- Single-element block tensors like `[ta]` act as unsqueeze (add a size-1 leading axis).
593591
- Tuple syntax `(ta, tb)` only works at the top level of a `%op` expression. Tuples inside function
594592
arguments (e.g., `f (ta, tb)`) are preserved as regular OCaml tuples.
595-
- Direct nesting like `[[ta; tb]; [tc; td]]` is currently limited by shape inference; use
596-
intermediate let bindings or explicit `++^` for nested block construction.
593+
- Nesting works: `[[ta; tb]; [tc; td]]` constructs a 2x2 block matrix (two new output axes).
594+
Components at each nesting level must have the same shape.
597595

598596
### Capturing the dimensions of selected axes for further computation or to add shape constraints
599597

tensor/row.ml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1884,6 +1884,11 @@ let%track6_sexp rec unify_dim ~stage origin (eq : dim * dim) env : constraint_ l
18841884
(more_ineqs @ ineqs, env)
18851885
in
18861886
List.fold ~init:(ineqs, env) dim_eqs ~f
1887+
| Concat dims1, Concat dims2 when List.length dims1 = List.length dims2 ->
1888+
(* Unify element-wise: Concat [a; b] = Concat [c; d] iff a = c and b = d *)
1889+
List.fold2_exn dims1 dims2 ~init:([], env) ~f:(fun (ineqs, env) d1 d2 ->
1890+
let more_ineqs, env = unify_dim ~stage origin (d1, d2) env in
1891+
(more_ineqs @ ineqs, env))
18871892
| dim1, dim2 ->
18881893
(* Note: at the unify_dim phase, it's strict equality (no broadcasting). *)
18891894
raise @@ Shape_error ("solved dimensions for axis: mismatch", [ Dim_mismatch [ dim1; dim2 ] ])

test/operations/test_block_tensor.expected

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -32,19 +32,22 @@ HERE: test/operations/test_block_tensor.ml:59:21
3232
; [ 7.00 ; 8.00 ; 9.00 ]
3333
]
3434

35-
--- Test 5: Scalars in block tensor ---
36-
HERE: test/operations/test_block_tensor.ml:73:21
37-
[48]: ++^_scalar_stack shape 0:3 [ 1.00 ; 2.00 ; 3.00 ]
35+
--- Test 5: Nested block matrix [[a; b]; [c; d]] ---
36+
HERE: test/operations/test_block_tensor.ml:78:21
37+
[60]: ++^_block_matrix shape 0:2,1:2,2:2 [
38+
[ [ 1.00 ; 2.00 ] ; [ 3.00 ; 4.00 ] ]
39+
; [ [ 5.00 ; 6.00 ] ; [ 7.00 ; 8.00 ] ]
40+
]
3841

3942
--- Test 6: Single element [x1] ---
40-
HERE: test/operations/test_block_tensor.ml:80:21
41-
[52]: ++^_unsqueezed shape 0:1,1:3 [ [ 1.00 ; 2.00 ; 3.00 ] ]
43+
HERE: test/operations/test_block_tensor.ml:85:21
44+
[64]: ++^_unsqueezed shape 0:1,1:3 [ [ 1.00 ; 2.00 ; 3.00 ] ]
4245

4346
--- Test 7: Gradient flow (2-way) ---
4447
grad_result (sin of stacked):
45-
HERE: test/operations/test_block_tensor.ml:99:21
48+
HERE: test/operations/test_block_tensor.ml:104:21
4649
┌────────────────────────────────────┐
47-
│[64]: sin_grad_result shape 0:2,1:2 │
50+
│[76]: sin_grad_result shape 0:2,1:2 │
4851
│┌──────┬───────────────────┐ │
4952
││ │axis 1 │ │
5053
│├──────┼───────────────────┤ │
@@ -54,17 +57,17 @@ HERE: test/operations/test_block_tensor.ml:99:21
5457
└────────────────────────────────────┘
5558

5659
Gradient of g1 (should be cos of original):
57-
HERE: test/operations/test_block_tensor.ml:101:21
60+
HERE: test/operations/test_block_tensor.ml:106:21
5861
┌────────────────────┐
59-
│[54]: 1,2 shape 0:2 │
62+
│[66]: 1,2 shape 0:2 │
6063
│┌┬────────────┐ │
6164
│││axis 0 │ │
6265
│├┼────────────┤ │
6366
│││ 1.00 2.00 │ │
6467
│└┴────────────┘ │
6568
└────────────────────┘
6669
┌─────────────────────────────┐
67-
│[54]: 1,2 shape 0:2 grad_1,2│
70+
│[66]: 1,2 shape 0:2 grad_1,2│
6871
│┌┬───────────────────┐ │
6972
│││axis 0 │ │
7073
│├┼───────────────────┤ │
@@ -73,17 +76,17 @@ HERE: test/operations/test_block_tensor.ml:101:21
7376
└─────────────────────────────┘
7477

7578
Gradient of g2:
76-
HERE: test/operations/test_block_tensor.ml:103:21
79+
HERE: test/operations/test_block_tensor.ml:108:21
7780
┌────────────────────┐
78-
│[56]: 3,4 shape 0:2 │
81+
│[68]: 3,4 shape 0:2 │
7982
│┌┬────────────┐ │
8083
│││axis 0 │ │
8184
│├┼────────────┤ │
8285
│││ 3.00 4.00 │ │
8386
│└┴────────────┘ │
8487
└────────────────────┘
8588
┌─────────────────────────────┐
86-
│[56]: 3,4 shape 0:2 grad_3,4│
89+
│[68]: 3,4 shape 0:2 grad_3,4│
8790
│┌┬────────────────────┐ │
8891
│││axis 0 │ │
8992
│├┼────────────────────┤ │
@@ -93,9 +96,9 @@ HERE: test/operations/test_block_tensor.ml:103:21
9396

9497
--- Test 8: Gradient flow (3-way) ---
9598
grad3_result (sin of 3-way stacked):
96-
HERE: test/operations/test_block_tensor.ml:126:21
99+
HERE: test/operations/test_block_tensor.ml:131:21
97100
┌─────────────────────────────────────┐
98-
│[84]: sin_grad3_result shape 0:3,1:2 │
101+
│[96]: sin_grad3_result shape 0:3,1:2 │
99102
│┌──────┬──────────────────┐ │
100103
││ │axis 1 │ │
101104
│├──────┼──────────────────┤ │
@@ -106,17 +109,17 @@ HERE: test/operations/test_block_tensor.ml:126:21
106109
└─────────────────────────────────────┘
107110

108111
Gradient of h1:
109-
HERE: test/operations/test_block_tensor.ml:128:21
112+
HERE: test/operations/test_block_tensor.ml:133:21
110113
┌────────────────────────┐
111-
│[70]: 0.5,1.5 shape 0:2 │
114+
│[82]: 0.5,1.5 shape 0:2 │
112115
│┌┬───────────────┐ │
113116
│││axis 0 │ │
114117
│├┼───────────────┤ │
115118
│││ 5.00e-1 1.50 │ │
116119
│└┴───────────────┘ │
117120
└────────────────────────┘
118121
┌─────────────────────────────────────┐
119-
│[70]: 0.5,1.5 shape 0:2 grad_0.5,1.5│
122+
│[82]: 0.5,1.5 shape 0:2 grad_0.5,1.5│
120123
│┌┬──────────────────┐ │
121124
│││axis 0 │ │
122125
│├┼──────────────────┤ │
@@ -125,17 +128,17 @@ HERE: test/operations/test_block_tensor.ml:128:21
125128
└─────────────────────────────────────┘
126129

127130
Gradient of h2:
128-
HERE: test/operations/test_block_tensor.ml:130:21
131+
HERE: test/operations/test_block_tensor.ml:135:21
129132
┌────────────────────┐
130-
│[72]: 1,2 shape 0:2 │
133+
│[84]: 1,2 shape 0:2 │
131134
│┌┬────────────┐ │
132135
│││axis 0 │ │
133136
│├┼────────────┤ │
134137
│││ 1.00 2.00 │ │
135138
│└┴────────────┘ │
136139
└────────────────────┘
137140
┌─────────────────────────────┐
138-
│[72]: 1,2 shape 0:2 grad_1,2│
141+
│[84]: 1,2 shape 0:2 grad_1,2│
139142
│┌┬───────────────────┐ │
140143
│││axis 0 │ │
141144
│├┼───────────────────┤ │
@@ -144,17 +147,17 @@ HERE: test/operations/test_block_tensor.ml:130:21
144147
└─────────────────────────────┘
145148

146149
Gradient of h3:
147-
HERE: test/operations/test_block_tensor.ml:132:21
150+
HERE: test/operations/test_block_tensor.ml:137:21
148151
┌──────────────────────┐
149-
│[74]: 3,0.1 shape 0:2 │
152+
│[86]: 3,0.1 shape 0:2 │
150153
│┌┬───────────────┐ │
151154
│││axis 0 │ │
152155
│├┼───────────────┤ │
153156
│││ 3.00 1.00e-1 │ │
154157
│└┴───────────────┘ │
155158
└──────────────────────┘
156159
┌─────────────────────────────────┐
157-
│[74]: 3,0.1 shape 0:2 grad_3,0.1│
160+
│[86]: 3,0.1 shape 0:2 grad_3,0.1│
158161
│┌┬───────────────────┐ │
159162
│││axis 0 │ │
160163
│├┼───────────────────┤ │

test/operations/test_block_tensor.ml

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -58,19 +58,24 @@ let () =
5858
let ctx = Train.forward_once ctx triple in
5959
Train.printf ~here:[%here] ~with_code:false ~with_grad:false ~style:`Inline triple;
6060

61-
(* --- Test 5: Scalars in block tensor --- *)
62-
(* NOTE: Nesting like [[s1; s2]; [s3; s4]] or [row1; row2] where row1/row2 are
63-
themselves block tensor results is currently limited by shape inference:
64-
the Concat dimension types produced by inner blocks can't be reconciled
65-
with the row variable (...) in the outer concat spec. Use explicit ++^ for nesting. *)
66-
printf "\n--- Test 5: Scalars in block tensor ---\n%!";
67-
let s1 = PDSL.ndarray [| 1.0 |] ~batch_dims:[] ~input_dims:[] ~output_dims:[] () in
68-
let s2 = PDSL.ndarray [| 2.0 |] ~batch_dims:[] ~input_dims:[] ~output_dims:[] () in
69-
let s3 = PDSL.ndarray [| 3.0 |] ~batch_dims:[] ~input_dims:[] ~output_dims:[] () in
70-
let%op scalar_stack = [s1; s2; s3] in
71-
Train.set_hosted scalar_stack.value;
72-
let ctx = Train.forward_once ctx scalar_stack in
73-
Train.printf ~here:[%here] ~with_code:false ~with_grad:false ~style:`Inline scalar_stack;
61+
(* --- Test 5: Nested block matrix [[a; b]; [c; d]] --- *)
62+
printf "\n--- Test 5: Nested block matrix [[a; b]; [c; d]] ---\n%!";
63+
let s1 =
64+
PDSL.ndarray [| 1.0; 2.0 |] ~batch_dims:[] ~input_dims:[] ~output_dims:[ 2 ] ()
65+
in
66+
let s2 =
67+
PDSL.ndarray [| 3.0; 4.0 |] ~batch_dims:[] ~input_dims:[] ~output_dims:[ 2 ] ()
68+
in
69+
let s3 =
70+
PDSL.ndarray [| 5.0; 6.0 |] ~batch_dims:[] ~input_dims:[] ~output_dims:[ 2 ] ()
71+
in
72+
let s4 =
73+
PDSL.ndarray [| 7.0; 8.0 |] ~batch_dims:[] ~input_dims:[] ~output_dims:[ 2 ] ()
74+
in
75+
let%op block_matrix = [[s1; s2]; [s3; s4]] in
76+
Train.set_hosted block_matrix.value;
77+
let ctx = Train.forward_once ctx block_matrix in
78+
Train.printf ~here:[%here] ~with_code:false ~with_grad:false ~style:`Inline block_matrix;
7479

7580
(* --- Test 6: Single element [x1] — unsqueeze --- *)
7681
printf "\n--- Test 6: Single element [x1] ---\n%!";

0 commit comments

Comments
 (0)