@@ -38,22 +38,50 @@ module ListMatrix: AbstractMatrix =
3838 let copy m =
3939 Timing. wrap " copy" (copy) m
4040
41- let add_empty_columns m cols =
42- let () = Printf. printf " Before add_empty_columns m:\n %sindices: %s\n " (show m) (Array. fold_right (fun x s -> s ^ (Int. to_string x) ^ " ," ) cols " " ) in
41+ (* This uses the logic from ArrayMatrix to add empty_columns*)
42+ let add_empty_columns (m : t ) (cols : int array ) : t =
43+ let nr = num_rows m in
44+ let nc = if nr = 0 then 0 else num_cols m in
45+ let nnc = Array. length cols in
46+
47+ (* If no rows or no new columns, just return m as is *)
48+ if nr = 0 || nnc = 0 then m
49+ else
50+ (* Create a new list of rows with additional columns. *)
51+ let new_rows =
52+ List. map (fun row ->
53+ let old_row_arr = V. to_array row in
54+ let new_row_arr = Array. make (nc + nnc) A. zero in
55+
56+ let offset = ref 0 in
57+ for j = 0 to nc - 1 do
58+ (* Check if we need to insert zero columns before placing old_row_arr.(j) *)
59+ while ! offset < nnc && ! offset + j = cols.(! offset) do
60+ incr offset
61+ done ;
62+ new_row_arr.(j + ! offset) < - old_row_arr.(j);
63+ done ;
64+
65+ V. of_array new_row_arr
66+ ) m
67+ in
68+
69+ new_rows
70+ (* let () = Printf.printf "Before add_empty_columns m:\n%sindices: %s\n" (show m) (Array.fold_right (fun x s -> (Int.to_string x) ^ "," ^ s) cols "") in
4371 let cols = Array.to_list cols in
4472 let sorted_cols = List.sort Stdlib.compare cols in
4573 let rec count_sorted_occ acc cols last count =
46- match cols with
47- | [] -> if count > 0 then (last, count) :: acc else acc
48- | x :: xs when x = last -> count_sorted_occ acc xs x (count + 1 )
49- | x :: xs -> let acc = if count > 0 then (last, count) :: acc else acc in
50- count_sorted_occ acc xs x 1
74+ match cols with
75+ | [] -> if count > 0 then (last, count) :: acc else acc
76+ | x :: xs when x = last -> count_sorted_occ acc xs x (count + 1)
77+ | x :: xs -> let acc = if count > 0 then (last, count) :: acc else acc in
78+ count_sorted_occ acc xs x 1
5179 in
5280 let occ_cols = List.rev @@ count_sorted_occ [] sorted_cols 0 0 in
5381 (* let () = Printf.printf "sorted cols is: %s\n" (List.fold_right (fun x s -> (Int.to_string x) ^ s) sorted_cols "") in
54- let () = Printf.printf "sorted_occ is: %s\n" (List.fold_right (fun (i, count) s -> "(" ^ (Int.to_string i) ^ "," ^ (Int.to_string count) ^ ")" ^ s) occ_cols "") in*)
82+ let () = Printf.printf "sorted_occ is: %s\n" (List.fold_right (fun (i, count) s -> "(" ^ (Int.to_string i) ^ "," ^ (Int.to_string count) ^ ")" ^ s) occ_cols "") in*)
5583 let () = Printf.printf "After add_empty_columns m:\n%s\n" (show (List.map (fun row -> V.insert_zero_at_indices row occ_cols (List.length cols)) m)) in
56- List. map (fun row -> V. insert_zero_at_indices row occ_cols (List. length cols)) m
84+ List.map (fun row -> V.insert_zero_at_indices row occ_cols (List.length cols)) m*)
5785
5886 let add_empty_columns m cols =
5987 Timing. wrap " add_empty_cols" (add_empty_columns m) cols
@@ -220,7 +248,7 @@ module ListMatrix: AbstractMatrix =
220248 in
221249 List. for_all row_is_valid m in
222250 let rec main_loop m m' row_idx col_idx =
223- if col_idx = (col_count - 2 ) then m (* In this case the whole bottom of the matrix starting from row_index is Zero, so it is normalized *)
251+ if col_idx > = (col_count - 1 ) then m (* In this case the whole bottom of the matrix starting from row_index is Zero, so it is normalized *)
224252 else
225253 match find_first_pivot m' row_idx col_idx with
226254 | None -> m (* No pivot found means already normalized*)
@@ -247,7 +275,7 @@ module ListMatrix: AbstractMatrix =
247275 let rref_vec m v =
248276 let () = Printf. printf " Before rref_vec we have m:\n %sv: %s\n " (show m) (V. show v) in
249277 match normalize @@ append_matrices m (init_with_vec v) with
250- | Some res -> let () = Printf. printf " After rref_vec we have m:\n %s\n " (show res) in
278+ | Some res -> let () = Printf. printf " After rref_vec, before removing zero rows, we have m:\n %s\n " (show res) in
251279 Some (remove_zero_rows res)
252280 | None -> let () = Printf. printf " After rref_vec there is no normalization\n " in None
253281
@@ -256,7 +284,9 @@ module ListMatrix: AbstractMatrix =
256284 (* TODO: OPTIMIZE!*)
257285 let rref_matrix m1 m2 =
258286 let () = Printf. printf " Before rref_matrix m1 m2\n m1: %s\n m2: %s\n " (show m1) (show m2) in
259- normalize @@ append_matrices m1 m2
287+ match normalize @@ append_matrices m1 m2 with
288+ | Some m -> let () = Printf. printf " After rref_matrix m, before removing zero rows:\n %s\n " (show m) in Some (remove_zero_rows m)
289+ | None -> let () = Printf. printf " No normalization for rref_matrix found" in None
260290
261291
262292 let delete_row_with_pivots row pivots m2 =
0 commit comments