open Base
open Ocannl
open Stdio
module Tn = Ir.Tnode
module IDX = Train.IDX
open Nn_blocks.DSL_modules
module Asgns = Ir.Assignments

(* === Configuration === *)

let ctx_len = 16
let eff_seq_len = ctx_len - 1
let d_model = 16
let num_heads = 2
let d_k = 8
let d_v = 8
let d_ff = 32
let vocab_size = Dataprep.Names.dict_size
let batch_size = 32
let epochs = 10
let pad_char = ' '
let pad_idx = Dataprep.Names.char_index pad_char
let bos_idx = Dataprep.Names.char_index '.'
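(* Note: eff_seq_len = ctx_len - 1 because inputs are the first ctx_len - 1 tokens of a
   padded name and targets are the same sequence shifted by one (see name_to_sequences
   below). Also num_heads * d_k = num_heads * d_v = d_model (2 * 8 = 16), the usual
   transformer sizing convention. *)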

(* === Data preparation === *)

(** Convert a name to a fixed-length integer sequence: prepend '.' as the start marker,
    append '.' as the end marker, pad with ' ' to [ctx_len], and map each character
    through [Dataprep.Names.char_index]. Returns (input_indices, target_indices), with
    the targets shifted one position ahead of the inputs, for teacher forcing. *)
let name_to_sequences name =
  let chars = '.' :: (String.to_list name @ [ '.' ]) in
  let len = List.length chars in
  let padded =
    if len >= ctx_len then List.take chars ctx_len
    else chars @ List.init (ctx_len - len) ~f:(fun _ -> pad_char)
  in
  let indices = List.map padded ~f:Dataprep.Names.char_index in
  let input_indices = List.take indices eff_seq_len in
  let target_indices = List.tl_exn (List.take indices ctx_len) in
  (Array.of_list input_indices, Array.of_list target_indices)
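(* With the shift above, input_indices.(t) holds the character at position t and
   target_indices.(t) holds the character at position t + 1, so each position is trained
   to predict its successor. *)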

let prepare_dataset () =
  let names = Dataprep.Names.read_names () in
  let num_names = List.length names in
  printf "Names loaded: %d\n%!" num_names;
  let pairs = List.map names ~f:name_to_sequences in
  let inputs = Array.of_list (List.map pairs ~f:fst) in
  let targets = Array.of_list (List.map pairs ~f:snd) in
  let num_examples = Array.length inputs in
  (* Round down to multiple of batch_size *)
  let num_examples = num_examples - (num_examples % batch_size) in
  printf "Training examples: %d\n%!" num_examples;
  (inputs, targets, num_examples)

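(** Flatten a [batch_size]-sized slice of index sequences starting at [offset] into a
    one-hot float array laid out as batch x position x vocabulary, the layout the data
    tensors below are filled with via [Tn.set_values]. *)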
let seqs_to_flat_one_hot (seqs : int array array) ~offset =
  let flat = Array.create ~len:(batch_size * eff_seq_len * vocab_size) 0. in
  for i = 0 to batch_size - 1 do
    for t = 0 to eff_seq_len - 1 do
      let base = ((i * eff_seq_len) + t) * vocab_size in
      flat.(base + seqs.(offset + i).(t)) <- 1.
    done
  done;
  flat

(* === Main === *)

let () =
  Utils.settings.fixed_state_for_init <- Some 3;
  Tensor.unsafe_reinitialize ();

  let train_inputs, train_targets, num_examples = prepare_dataset () in
  let n_batches = num_examples / batch_size in

  let step_n, bindings = IDX.get_static_symbol IDX.empty in
  let total_tokens = batch_size * eff_seq_len in

  (* === Data tensors === *)
  let make_data_tensor label =
    let open Bigarray in
    let ga = Genarray.create Float32 c_layout [| batch_size; eff_seq_len; vocab_size |] in
    Bigarray.Genarray.fill ga 0.;
    let nd = Ir.Ndarray.as_array Ir.Ops.Single ga in
    Tensor.term ~init_data:(Reshape nd) ~grad_spec:If_needed ~label:[ label ]
      ~batch_dims:[ batch_size; eff_seq_len ] ~input_dims:[] ~output_dims:[ vocab_size ] ()
  in
  let input_batch = make_data_tensor "input_batch" in
  let target_batch = make_data_tensor "target_batch" in
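  (* Both tensors are refilled in place with fresh one-hot values before every training
     step, via Tn.set_values in the training loop below. *)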

  (* === Causal mask === *)
  let mask =
    NTDSL.init ~l:"mask" ~prec:Ir.Ops.single ~b:[ eff_seq_len ] ~i:[ eff_seq_len ] ~o:[]
      ~f:(function
        | [| s; t |] -> if s >= t then 1. else 0.
        | _ -> failwith "unexpected mask indices")
      ()
  in
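  (* The mask is lower-triangular: entry (s, t) is 1 exactly when s >= t, so the query at
     position s can only attend to key positions t <= s. For a length-4 sequence it would
     look like
       1 0 0 0
       1 1 0 0
       1 1 1 0
       1 1 1 1 *)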

  (* === Model ===
     Decoder-only transformer: masked self-attention + FFN with residual connections.
     Layer norm is omitted for this small model to keep generated code compact and
     avoid the gradient signal issue noted in fsm_transformer.ml.
     Uses multi_head_attention with ~mask for causal masking. *)
  let open Nn_blocks in
  let mha = multi_head_attention ~label:[ "mha" ] ~num_heads ~d_k ~d_v () in
  let ffn = Nn_blocks.mlp ~label:[ "ffn" ] ~hid_dims:[ d_ff ] () in
  let%op build_model () =
    fun ~train_step ~mask input ->
      let embedded = ({ tok_embed; o = [ d_model ] } * input) + { pos_encoding } in
      let x1 = embedded + mha ~train_step ~mask embedded in
      let x2 = x1 + ffn x1 in
      { w_out } * x2
  in
  let model = build_model () in
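  (* Rough shape flow: a one-hot input row (vocab_size wide) times the tok_embed
     parameter yields a d_model embedding; pos_encoding (declared with the { ... }
     parameter syntax, so positions appear to be learned rather than fixed sinusoids) is
     added; then a residual masked-attention block and a residual MLP block; finally
     w_out projects back to vocab_size logits. *)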

  (* === Training computation === *)
  let train_logits = model ~train_step:(Some step_n) ~mask input_batch in
  let%op counts = exp train_logits in
  let%op probs = counts /. (counts ++ "...|... => ...|0") in
  let%op output_probs = (probs *. target_batch) ++ "...|... => ...|0" in
  let%op loss = neg (log output_probs) in
  let%op batch_loss = (loss ++ "...|... => 0") /. !..total_tokens in
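  (* Together these lines implement softmax cross-entropy: counts exponentiates the
     logits; probs normalizes by the sum over the vocabulary axis (the einsum spec
     "...|... => ...|0" keeps the batch axes and reduces the output axis to a single
     cell); output_probs picks out the probability of the correct class because
     target_batch is one-hot; batch_loss is then the mean negative log-likelihood per
     token. *)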

  let update = Train.grad_update batch_loss in
  let steps = epochs * n_batches in
  let%op learning_rate = 0.01 *. ((1.5 *. !..steps) - !@step_n) /. !..steps in
  let sgd = Train.sgd_update ~learning_rate batch_loss in
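  (* The learning rate decays linearly with the training step: 0.01 *. 1.5 = 0.015 at
     step 0, falling toward 0.01 *. 0.5 = 0.005 as the step count approaches `steps`. *)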

  (* === Inference computation (forward-only, shares trained weights) === *)
  let infer_input =
    let open Bigarray in
    let ga = Genarray.create Float32 c_layout [| 1; eff_seq_len; vocab_size |] in
    Bigarray.Genarray.fill ga 0.;
    let nd = Ir.Ndarray.as_array Ir.Ops.Single ga in
    Tensor.term ~init_data:(Reshape nd) ~grad_spec:Prohibit_grad ~label:[ "infer_input" ]
      ~batch_dims:[ 1; eff_seq_len ] ~input_dims:[] ~output_dims:[ vocab_size ] ()
  in
  let counter_n, infer_bindings = IDX.get_static_symbol IDX.empty in
  let%cd infer_logits = model ~train_step:None ~mask infer_input in
  let%cd infer_comp =
    ~~("names infer";
       infer_logits.forward;
       { dice } =: uniform_at !@counter_n)
  in
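  (* The inference graph is forward-only: it reuses the trained weights on a batch of
     one, and each run also stores a uniform random sample in `dice` (presumably a fresh
     draw per counter_n step, going by the name uniform_at), which the generation loop
     below uses for sampling. *)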

  (* === Compile === *)
  let ctx = Context.auto () in
  let ctx = Train.init_params ctx bindings batch_loss in
  Train.set_on_host input_batch.value;
  Train.set_on_host target_batch.value;
  (* Recenter all model parameters from uniform [0,1) to [-0.25, 0.25).
     OCANNL's default uniform1 init produces all-positive weights; through the
     transformer's Q*K^T attention scores this causes extreme values and exp overflow.
     Same mitigation as fsm_transformer.ml. *)
  Set.iter batch_loss.Tensor.params ~f:(fun p ->
      let tn = p.Tensor.value in
      Train.set_on_host tn;
      let vals = Tn.get_values tn in
      Array.iteri vals ~f:(fun i v -> vals.(i) <- 0.5 *. (v -. 0.5));
      Tn.set_values tn vals);
  Train.set_on_host infer_logits.value;
  Train.set_on_host infer_input.value;

  (* Compile training routine *)
  let train_comp = Asgns.sequence [ update; sgd ] in
  Set.iter (snd @@ Asgns.collect_nodes_guess_output train_comp.Asgns.asgns) ~f:Train.set_hosted;
  let ctx, sgd_step = Context.compile ctx train_comp bindings in
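  (* One call to sgd_step now runs, as a single compiled routine: the forward pass and
     backprop of batch_loss (from Train.grad_update) followed by the SGD parameter
     update (from Train.sgd_update). *)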

  (* Compile inference routine *)
  Set.iter (snd @@ Asgns.collect_nodes_guess_output infer_comp.Asgns.asgns) ~f:Train.set_hosted;
  let infer_comp =
    { infer_comp with
      Asgns.embedded_nodes = Set.add infer_comp.Asgns.embedded_nodes mask.value
    }
  in
  let ctx, infer_routine = Context.compile ctx infer_comp infer_bindings in

  let open Operation.At in
  let step_ref = IDX.find_exn (Context.bindings sgd_step) step_n in
  let counter_ref = IDX.find_exn (Context.bindings infer_routine) counter_n in
  counter_ref := 0;
  Train.set_on_host batch_loss.value;

  (* === Training loop ===
     Random baseline: ln(28) ≈ 3.33 per token, epoch sum ≈ 3.33 * n_batches.
     We check loss at first, middle, and last epochs. *)
  let epoch_loss_limit_first = 2.0 *. Float.of_int n_batches in
  let epoch_loss_limit_mid = 1.4 *. Float.of_int n_batches in
  let epoch_loss_limit_last = 1.3 *. Float.of_int n_batches in
  for epoch = 0 to epochs - 1 do
    let epoch_loss = ref 0. in
    for batch = 0 to n_batches - 1 do
      let offset = batch * batch_size in
      Tn.set_values input_batch.value (seqs_to_flat_one_hot train_inputs ~offset);
      Tn.set_values target_batch.value (seqs_to_flat_one_hot train_targets ~offset);
      let ctx' = Context.run ctx sgd_step in
      ignore (ctx' : Context.t);
      epoch_loss := !epoch_loss +. batch_loss.@[0];
      Int.incr step_ref
    done;
    if epoch = 0 || epoch = epochs / 2 || epoch = epochs - 1 then (
      let limit =
        if epoch = 0 then epoch_loss_limit_first
        else if epoch = epochs / 2 then epoch_loss_limit_mid
        else epoch_loss_limit_last
      in
      printf "Epoch %d, loss below threshold=%b\n%!" epoch Float.(!epoch_loss < limit))
  done;

  (* === Autoregressive generation ===
     Generate names token-by-token, sampling from the model's output distribution.
     Uses the CDF-based sampling pattern from bigram_mlp.ml. *)
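  (* Write the current generation context into infer_input as a single-example one-hot
     tensor; called before every inference step. *)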
  let set_one_hot_seq context =
    let flat = Array.create ~len:(1 * eff_seq_len * vocab_size) 0. in
    for t = 0 to eff_seq_len - 1 do
      let base = t * vocab_size in
      flat.(base + context.(t)) <- 1.
    done;
    Tn.set_values infer_input.value flat
  in

  let gen_name () =
    let context = Array.create ~len:eff_seq_len pad_idx in
    context.(0) <- bos_idx;
    let rec aux pos =
      if pos >= eff_seq_len then
        (* Max length reached -- extract what we have *)
        let name = Buffer.create 16 in
        for i = 1 to eff_seq_len - 1 do
          let c = List.nth_exn Dataprep.Names.letters_with_dot context.(i) in
          if not (Char.equal c '.' || Char.equal c ' ') then Buffer.add_char name c
        done;
        Buffer.contents name
      else begin
        set_one_hot_seq context;
        Int.incr counter_ref;
        let _ctx = Context.run ctx infer_routine in
        let dice_value = dice.@[0] in

        (* Compute softmax probabilities at position (pos-1) in the output
           (the model predicts token at position pos given input up to pos-1). *)
        let logits = Array.init vocab_size ~f:(fun v ->
          infer_logits.@{[| 0; pos - 1; v |]}) in
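        (* Subtract the max logit before exponentiating: the standard numerically stable
           softmax, which leaves the resulting probabilities unchanged. *)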
        let max_logit = Array.fold logits ~init:Float.neg_infinity ~f:Float.max in
        let exp_logits = Array.map logits ~f:(fun l -> Float.exp (l -. max_logit)) in
        let sum_exp = Array.fold exp_logits ~init:0. ~f:( +. ) in
        let probs = Array.map exp_logits ~f:(fun e -> e /. sum_exp) in

        (* CDF-based sampling: walk the cumulative distribution until it exceeds the
           uniform draw; if rounding keeps the running sum below dice_value, fall back
           to the last index. *)
        let max_i = vocab_size - 1 in
        let rec sample i acc =
          if i >= max_i then i
          else
            let new_acc = acc +. probs.(i) in
            if Float.(new_acc > dice_value) then i
            else sample (i + 1) new_acc
        in
        let sampled_idx = sample 0 0. in
        let sampled_char = List.nth_exn Dataprep.Names.letters_with_dot sampled_idx in

        if Char.equal sampled_char '.' && pos > 1 then begin
          (* EOS -- extract name from positions 1..pos-1 *)
          let name = Buffer.create 16 in
          for i = 1 to pos - 1 do
            let c = List.nth_exn Dataprep.Names.letters_with_dot context.(i) in
            if not (Char.equal c '.' || Char.equal c ' ') then Buffer.add_char name c
          done;
          Buffer.contents name
        end
        else begin
          context.(pos) <- sampled_idx;
          aux (pos + 1)
        end
      end
    in
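    (* Start generating at position 1: position 0 already holds the beginning-of-name
       '.' token. *)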
    aux 1
  in

  (* Generate very few names because different hardware backends diverge quickly. *)
  let names = Array.init 3 ~f:(fun _ -> gen_name ()) in
  Array.iter names ~f:print_endline