
Commit 327a6cf

lukstafi and claude committed
Add decoder-only autoregressive transformer example on Names dataset (#57)
Add decoder_only_block and decoder_only to nn_blocks.ml as reusable building blocks for autoregressive language models (masked self-attention + FFN, no cross-attention).

Add test/training/transformer_names.ml: a complete training + generation example using character-level encoding on the Names dataset, with causal masking, SGD training, and autoregressive token-by-token generation.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 62d32a2 commit 327a6cf

4 files changed

Lines changed: 329 additions & 0 deletions


lib/nn_blocks.ml

Lines changed: 28 additions & 0 deletions
@@ -228,6 +228,34 @@ let%op transformer_encoder_block ~label ~num_heads ~d_k ~d_v ~d_ff ?(epsilon = 1
    let x1 = ln1 (input + mha ~train_step input) in
    ln2 (x1 + ffn x1)

(** Decoder-only transformer block: masked self-attention + FFN with post-norm LayerNorm.
    Like {!transformer_encoder_block} but accepts a [~mask] parameter for causal masking.
    No cross-attention — suitable for autoregressive language models. *)
let%op decoder_only_block ~label ~num_heads ~d_k ~d_v ~d_ff ?(epsilon = 1e-5)
    ?(dropout_rate = 0.0) ?(pos_embed = No_pos_embed) () =
  let masked_mha =
    multi_head_attention ~label:("masked_mha" :: label) ~num_heads ~d_k ~d_v ~dropout_rate ~pos_embed
      ()
  in
  let ffn = mlp ~label:("ffn" :: label) ~hid_dims:[ d_ff ] () in
  let ln1 = layer_norm ~label:("ln1" :: label) ~epsilon () in
  let ln2 = layer_norm ~label:("ln2" :: label) ~epsilon () in
  fun ~train_step x ~mask ->
    let x1 = ln1 (x + masked_mha ~train_step ~mask x) in
    ln2 (x1 + ffn x1)

(** Stack of {!decoder_only_block} layers. *)
let decoder_only ~label ~num_layers ~num_heads ~d_k ~d_v ~d_ff ?epsilon ?dropout_rate
    ?(pos_embed = No_pos_embed) () =
  let layers =
    List.init num_layers ~f:(fun i ->
        decoder_only_block
          ~label:(("layer" ^ Int.to_string i) :: label)
          ~num_heads ~d_k ~d_v ~d_ff ?epsilon ?dropout_rate ~pos_embed ())
  in
  fun ~train_step x ~mask ->
    List.fold layers ~init:x ~f:(fun x layer -> layer ~train_step x ~mask)

(* Cross-attention does not apply RoPE — position encoding is for self-attention only. *)
let%op cross_attention ~label ~num_heads ~d_k ~d_v ?temperature ?(dropout_rate = 0.0) () ~train_step
    x ~enc_output =
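For orientation, here is a minimal usage sketch (editor's addition, not part of the commit); the hyperparameter values are arbitrary and the "toy" label is illustrative, assuming only the signatures shown in the diff above:

(* Hypothetical usage sketch: builds a 2-layer decoder stack. *)
let toy_decoder =
  Nn_blocks.decoder_only ~label:[ "toy" ] ~num_layers:2 ~num_heads:2 ~d_k:8 ~d_v:8 ~d_ff:32 ()

(* [toy_decoder] is then applied as [toy_decoder ~train_step x ~mask], threading the
   same causal [mask] through every layer, mirroring transformer_names.ml below. *)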

test/training/dune

Lines changed: 11 additions & 0 deletions
@@ -83,3 +83,14 @@
 (libraries ocannl)
 (preprocess
  (pps ppx_here ppx_ocannl)))

(test
 (name transformer_names)
 (package neural_nets_lib)
 (modules transformer_names)
 (deps
  ocannl_config
  (env_var OCANNL_BACKEND))
 (libraries ocannl dataprep)
 (preprocess
  (pps ppx_here ppx_ocannl)))
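Editor's note (not part of the commit): given this stanza, the test should be runnable through dune's standard test alias, e.g. `OCANNL_BACKEND=cc dune build @test/training/runtest`; the exact backend name to export depends on the OCANNL installation.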
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
Retrieving commandline, environment, or config file variable ocannl_log_level
Found 0, in the config file
Names loaded: 32033
Training examples: 32032
Epoch 0, loss below threshold=true
Epoch 5, loss below threshold=true
Epoch 9, loss below threshold=true

remahitaipt
dava

test/training/transformer_names.ml

Lines changed: 280 additions & 0 deletions
@@ -0,0 +1,280 @@
open Base
open Ocannl
open Stdio
module Tn = Ir.Tnode
module IDX = Train.IDX
open Nn_blocks.DSL_modules
module Asgns = Ir.Assignments

(* === Configuration === *)

let ctx_len = 16
let eff_seq_len = ctx_len - 1
let d_model = 16
let num_heads = 2
let d_k = 8
let d_v = 8
let d_ff = 32
let vocab_size = Dataprep.Names.dict_size
let batch_size = 32
let epochs = 10
let pad_char = ' '
let pad_idx = Dataprep.Names.char_index pad_char
let bos_idx = Dataprep.Names.char_index '.'

(* === Data preparation === *)

(** Convert a name to a fixed-length integer sequence.
    "emma" -> [0; 5; 13; 13; 1; 0; 1; 1; ...] where 0='.' and 1=' ' (padding).
    Returns (input_indices, target_indices) for teacher forcing. *)
let name_to_sequences name =
  let chars = '.' :: (String.to_list name @ [ '.' ]) in
  let len = List.length chars in
  let padded =
    if len >= ctx_len then List.take chars ctx_len
    else chars @ List.init (ctx_len - len) ~f:(fun _ -> pad_char)
  in
  let indices = List.map padded ~f:Dataprep.Names.char_index in
  let input_indices = List.take indices eff_seq_len in
  let target_indices = List.tl_exn (List.take indices ctx_len) in
  (Array.of_list input_indices, Array.of_list target_indices)
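(* Editor's note (not in the commit): with ctx_len = 16, the input is the first
   15 indices of the padded ".name." sequence and the target is the same window
   shifted left by one, so the model at position [t] is trained to predict the
   token at position [t + 1] (teacher forcing). *)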
let prepare_dataset () =
  let names = Dataprep.Names.read_names () in
  let num_names = List.length names in
  printf "Names loaded: %d\n%!" num_names;
  let pairs = List.map names ~f:name_to_sequences in
  let inputs = Array.of_list (List.map pairs ~f:fst) in
  let targets = Array.of_list (List.map pairs ~f:snd) in
  let num_examples = Array.length inputs in
  (* Round down to a multiple of batch_size *)
  let num_examples = num_examples - (num_examples % batch_size) in
  printf "Training examples: %d\n%!" num_examples;
  (inputs, targets, num_examples)

let seqs_to_flat_one_hot (seqs : int array array) ~offset =
  let flat = Array.create ~len:(batch_size * eff_seq_len * vocab_size) 0. in
  for i = 0 to batch_size - 1 do
    for t = 0 to eff_seq_len - 1 do
      let base = ((i * eff_seq_len) + t) * vocab_size in
      flat.(base + seqs.(offset + i).(t)) <- 1.
    done
  done;
  flat

(* === Main === *)

let () =
  Utils.settings.fixed_state_for_init <- Some 3;
  Tensor.unsafe_reinitialize ();

  let train_inputs, train_targets, num_examples = prepare_dataset () in
  let n_batches = num_examples / batch_size in

  let step_n, bindings = IDX.get_static_symbol IDX.empty in
  let total_tokens = batch_size * eff_seq_len in

  (* === Data tensors === *)
  let make_data_tensor label =
    let open Bigarray in
    let ga = Genarray.create Float32 c_layout [| batch_size; eff_seq_len; vocab_size |] in
    Bigarray.Genarray.fill ga 0.;
    let nd = Ir.Ndarray.as_array Ir.Ops.Single ga in
    Tensor.term ~init_data:(Reshape nd) ~grad_spec:If_needed ~label:[ label ]
      ~batch_dims:[ batch_size; eff_seq_len ] ~input_dims:[] ~output_dims:[ vocab_size ] ()
  in
  let input_batch = make_data_tensor "input_batch" in
  let target_batch = make_data_tensor "target_batch" in

  (* === Causal mask === *)
  let mask =
    NTDSL.init ~l:"mask" ~prec:Ir.Ops.single ~b:[ eff_seq_len ] ~i:[ eff_seq_len ] ~o:[]
      ~f:(function
        | [| s; t |] -> if s >= t then 1. else 0.
        | _ -> failwith "unexpected mask indices")
      ()
  in
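  (* Editor's illustration (not in the commit): with eff_seq_len = 4 this mask is
     the lower-triangular matrix
       1 0 0 0
       1 1 0 0
       1 1 1 0
       1 1 1 1
     so a query position [s] may only attend to key positions [t <= s];
     [multi_head_attention] presumably uses it to suppress attention to future
     tokens. *)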
  (* === Model ===
     Decoder-only transformer: masked self-attention + FFN with residual connections.
     Layer norm is omitted for this small model to keep generated code compact and
     avoid the gradient signal issue noted in fsm_transformer.ml.
     Uses multi_head_attention with ~mask for causal masking. *)
  let open Nn_blocks in
  let mha = multi_head_attention ~label:[ "mha" ] ~num_heads ~d_k ~d_v () in
  let ffn = Nn_blocks.mlp ~label:[ "ffn" ] ~hid_dims:[ d_ff ] () in
  let%op build_model () =
    fun ~train_step ~mask input ->
      let embedded = ({ tok_embed; o = [ d_model ] } * input) + { pos_encoding } in
      let x1 = embedded + mha ~train_step ~mask embedded in
      let x2 = x1 + ffn x1 in
      { w_out } * x2
  in
  let model = build_model () in

  (* === Training computation === *)
  let train_logits = model ~train_step:(Some step_n) ~mask input_batch in
  let%op counts = exp train_logits in
  let%op probs = counts /. (counts ++ "...|... => ...|0") in
  let%op output_probs = (probs *. target_batch) ++ "...|... => ...|0" in
  let%op loss = neg (log output_probs) in
  let%op batch_loss = (loss ++ "...|... => 0") /. !..total_tokens in
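  (* Editor's note (not in the commit): the einsum spec "...|... => ...|0" reads as
     a reduction over the vocabulary axis with a broadcastable result, so [probs]
     is a softmax of the logits and [output_probs] is the probability assigned to
     the one-hot target; "...|... => 0" then sums the per-token negative
     log-likelihoods, and dividing by [total_tokens] gives the mean cross-entropy
     per token. *)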
  let update = Train.grad_update batch_loss in
  let steps = epochs * n_batches in
  let%op learning_rate = 0.01 *. ((1.5 *. !..steps) - !@step_n) /. !..steps in
  let sgd = Train.sgd_update ~learning_rate batch_loss in
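  (* Editor's note (not in the commit): this schedule decays linearly from
     0.01 *. 1.5 = 0.015 at step 0 down to 0.01 *. 0.5 = 0.005 at the final step. *)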
  (* === Inference computation (forward-only, shares trained weights) === *)
  let infer_input =
    let open Bigarray in
    let ga = Genarray.create Float32 c_layout [| 1; eff_seq_len; vocab_size |] in
    Bigarray.Genarray.fill ga 0.;
    let nd = Ir.Ndarray.as_array Ir.Ops.Single ga in
    Tensor.term ~init_data:(Reshape nd) ~grad_spec:Prohibit_grad ~label:[ "infer_input" ]
      ~batch_dims:[ 1; eff_seq_len ] ~input_dims:[] ~output_dims:[ vocab_size ] ()
  in
  let counter_n, infer_bindings = IDX.get_static_symbol IDX.empty in
  let%cd infer_logits = model ~train_step:None ~mask infer_input in
  let%cd infer_comp =
    ~~("names infer";
       infer_logits.forward;
       { dice } =: uniform_at !@counter_n)
  in

  (* === Compile === *)
  let ctx = Context.auto () in
  let ctx = Train.init_params ctx bindings batch_loss in
  Train.set_on_host input_batch.value;
  Train.set_on_host target_batch.value;
  (* Recenter all model parameters from uniform [0,1) to [-0.25, 0.25).
     OCANNL's default uniform1 init produces all-positive weights; through the
     transformer's Q*K^T attention scores this causes extreme values and exp overflow.
     Same mitigation as fsm_transformer.ml. *)
  Set.iter batch_loss.Tensor.params ~f:(fun p ->
      let tn = p.Tensor.value in
      Train.set_on_host tn;
      let vals = Tn.get_values tn in
      Array.iteri vals ~f:(fun i v -> vals.(i) <- 0.5 *. (v -. 0.5));
      Tn.set_values tn vals);
  Train.set_on_host infer_logits.value;
  Train.set_on_host infer_input.value;

  (* Compile training routine *)
  let train_comp = Asgns.sequence [ update; sgd ] in
  Set.iter (snd @@ Asgns.collect_nodes_guess_output train_comp.Asgns.asgns) ~f:Train.set_hosted;
  let ctx, sgd_step = Context.compile ctx train_comp bindings in

  (* Compile inference routine *)
  Set.iter (snd @@ Asgns.collect_nodes_guess_output infer_comp.Asgns.asgns) ~f:Train.set_hosted;
  let infer_comp =
    { infer_comp with
      Asgns.embedded_nodes = Set.add infer_comp.Asgns.embedded_nodes mask.value
    }
  in
  let ctx, infer_routine = Context.compile ctx infer_comp infer_bindings in

  let open Operation.At in
  let step_ref = IDX.find_exn (Context.bindings sgd_step) step_n in
  let counter_ref = IDX.find_exn (Context.bindings infer_routine) counter_n in
  counter_ref := 0;
  Train.set_on_host batch_loss.value;

  (* === Training loop ===
     Random baseline: ln(28) ≈ 3.33 per token, epoch sum ≈ 3.33 * n_batches.
     We check the loss at the first, middle, and last epochs. *)
  let epoch_loss_limit_first = 2.0 *. Float.of_int n_batches in
  let epoch_loss_limit_mid = 1.4 *. Float.of_int n_batches in
  let epoch_loss_limit_last = 1.3 *. Float.of_int n_batches in
  for epoch = 0 to epochs - 1 do
    let epoch_loss = ref 0. in
    for batch = 0 to n_batches - 1 do
      let offset = batch * batch_size in
      Tn.set_values input_batch.value (seqs_to_flat_one_hot train_inputs ~offset);
      Tn.set_values target_batch.value (seqs_to_flat_one_hot train_targets ~offset);
      let ctx' = Context.run ctx sgd_step in
      ignore (ctx' : Context.t);
      epoch_loss := !epoch_loss +. batch_loss.@[0];
      Int.incr step_ref
    done;
    if epoch = 0 || epoch = epochs / 2 || epoch = epochs - 1 then (
      let limit =
        if epoch = 0 then epoch_loss_limit_first
        else if epoch = epochs / 2 then epoch_loss_limit_mid
        else epoch_loss_limit_last
      in
      printf "Epoch %d, loss below threshold=%b\n%!" epoch Float.(!epoch_loss < limit))
  done;

  (* === Autoregressive generation ===
     Generate names token-by-token, sampling from the model's output distribution.
     Uses the CDF-based sampling pattern from bigram_mlp.ml. *)
  let set_one_hot_seq context =
    let flat = Array.create ~len:(1 * eff_seq_len * vocab_size) 0. in
    for t = 0 to eff_seq_len - 1 do
      let base = t * vocab_size in
      flat.(base + context.(t)) <- 1.
    done;
    Tn.set_values infer_input.value flat
  in

  let gen_name () =
    let context = Array.create ~len:eff_seq_len pad_idx in
    context.(0) <- bos_idx;
    let rec aux pos =
      if pos >= eff_seq_len then
        (* Max length reached — extract what we have *)
        let name = Buffer.create 16 in
        for i = 1 to eff_seq_len - 1 do
          let c = List.nth_exn Dataprep.Names.letters_with_dot context.(i) in
          if not (Char.equal c '.' || Char.equal c ' ') then Buffer.add_char name c
        done;
        Buffer.contents name
      else begin
        set_one_hot_seq context;
        Int.incr counter_ref;
        let _ctx = Context.run ctx infer_routine in
        let dice_value = dice.@[0] in

        (* Compute softmax probabilities at position (pos-1) in the output
           (the model predicts the token at position pos given input up to pos-1). *)
        let logits =
          Array.init vocab_size ~f:(fun v -> infer_logits.@{[| 0; pos - 1; v |]})
        in
        let max_logit = Array.fold logits ~init:Float.neg_infinity ~f:Float.max in
        let exp_logits = Array.map logits ~f:(fun l -> Float.exp (l -. max_logit)) in
        let sum_exp = Array.fold exp_logits ~init:0. ~f:( +. ) in
        let probs = Array.map exp_logits ~f:(fun e -> e /. sum_exp) in

        (* CDF-based sampling *)
        let max_i = vocab_size - 1 in
        let rec sample i acc =
          if i >= max_i then i
          else
            let new_acc = acc +. probs.(i) in
            if Float.(new_acc > dice_value) then i
            else sample (i + 1) new_acc
        in
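        (* Editor's note (not in the commit): if rounding keeps the accumulated
           CDF at or below [dice_value] through the last probability, [sample]
           falls through to [max_i], so the walk always terminates with a valid
           index. *)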
        let sampled_idx = sample 0 0. in
        let sampled_char = List.nth_exn Dataprep.Names.letters_with_dot sampled_idx in

        if Char.equal sampled_char '.' && pos > 1 then begin
          (* EOS — extract the name from positions 1..pos-1 *)
          let name = Buffer.create 16 in
          for i = 1 to pos - 1 do
            let c = List.nth_exn Dataprep.Names.letters_with_dot context.(i) in
            if not (Char.equal c '.' || Char.equal c ' ') then Buffer.add_char name c
          done;
          Buffer.contents name
        end
        else begin
          context.(pos) <- sampled_idx;
          aux (pos + 1)
        end
      end
    in
    aux 1
  in

  (* Generate very few names because different hardware backends diverge quickly. *)
  let names = Array.init 3 ~f:(fun _ -> gen_name ()) in
  Array.iter names ~f:print_endline
