Skip to content

Commit 6eb00e5

Browse files
lukstaficlaude
andcommitted
Add regression test for decoder_only_block/decoder_only API
Exercises the new Nn_blocks.decoder_only helper with a 2-layer stack, causal mask, and forward pass, validating output shape. This ensures the new public API added in the previous commit has CI coverage. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 327a6cf commit 6eb00e5

3 files changed

Lines changed: 65 additions & 0 deletions

File tree

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
Retrieving commandline, environment, or config file variable ocannl_log_level
2+
Found 0, in the config file
3+
Testing decoder_only (2-layer stack)
4+
Output shape:
5+
((batch
6+
((dims
7+
((Dim ((d 2) (label ()) (proj_id ((Proj_id 1)))))
8+
(Dim ((d 4) (label ()) (proj_id ((Proj_id 2)))))))
9+
(bcast Broadcastable) (prov (((sh_id 307) (kind Batch))))))
10+
(input
11+
((dims ()) (bcast Broadcastable) (prov (((sh_id 307) (kind Input))))))
12+
(output
13+
((dims ((Dim ((d 16) (label ()) (proj_id ((Proj_id 3)))))))
14+
(bcast Broadcastable) (prov (((sh_id 307) (kind Output))))))
15+
(batch_padding ()) (input_padding ()) (output_padding ()) (padding_elem ())
16+
(id 307) (debug_name layer_norm))
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
open! Base
2+
open Ocannl.Nn_blocks.DSL_modules
3+
4+
let () =
5+
let ctx = Context.auto () in
6+
let batch_size = 2 in
7+
let seq_len = 4 in
8+
let d_model = 16 in
9+
let num_heads = 2 in
10+
let d_ff = 32 in
11+
12+
Stdio.printf "Testing decoder_only (2-layer stack)\n";
13+
14+
(* decoder_only internally creates decoder_only_block instances,
15+
so this exercises both functions. *)
16+
let stack =
17+
Ocannl.Nn_blocks.decoder_only ~label:[ "test_stack" ] ~num_layers:2 ~num_heads ~d_k:d_model
18+
~d_v:d_model ~d_ff ()
19+
in
20+
21+
let input =
22+
TDSL.range_of_shape ~label:[ "input" ] ~batch_dims:[ batch_size; seq_len ] ~input_dims:[]
23+
~output_dims:[ d_model ] ()
24+
in
25+
26+
let mask =
27+
NTDSL.init ~l:"mask" ~prec:Ir.Ops.single ~b:[ seq_len ] ~i:[ seq_len ] ~o:[]
28+
~f:(function
29+
| [| s; t |] -> if s >= t then 1. else 0.
30+
| _ -> failwith "unexpected mask indices")
31+
()
32+
in
33+
34+
let output = stack ~train_step:None input ~mask in
35+
let _ctx = Ocannl.Train.forward_once ctx output in
36+
37+
Stdio.printf "Output shape:\n%s\n%!"
38+
(Sexp.to_string_hum ([%sexp_of: Shape.t] output.Tensor.shape))

test/operations/dune

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,17 @@
377377
(preprocess
378378
(pps ppx_here ppx_ocannl ppx_sexp_conv)))
379379

380+
(test
381+
(name decoder_only_test)
382+
(package neural_nets_lib)
383+
(deps
384+
ocannl_config
385+
(env_var OCANNL_BACKEND))
386+
(modules decoder_only_test)
387+
(libraries base ocannl)
388+
(preprocess
389+
(pps ppx_here ppx_ocannl ppx_sexp_conv)))
390+
380391
(test
381392
(name test_extend_dsls)
382393
(package neural_nets_lib)

0 commit comments

Comments
 (0)