@@ -4,6 +4,8 @@ open RatOps
44
55open Batteries
66
7+ module M = Messages
8+
79let timing_wrap = Vector. timing_wrap
810
911module type SparseMatrix =
1618 val rref_vec : t -> vec -> t Option .t
1719
1820 val rref_matrix : t -> t -> t Option .t
21+
22+ val linear_disjunct : t -> t -> t
23+ (* * [linear_disjunct m1 m2] returns a matrix that contains the linear disjunct of [m1] and [m2].
24+ The result is in rref. If [m1] and [m2] are not linearly disjunct, an exception is raised. *)
1925end
2026
2127module type SparseMatrixFunctor =
@@ -296,29 +302,34 @@ module ListMatrix: SparseMatrixFunctor =
296302 @param v A vector with number of entries equal to the number of columns of [v].
297303 *)
298304 let rref_vec m v =
299- if is_empty m then (* In this case, v is normalized and returned *)
300- BatOption. map (fun (_ , value ) -> init_with_vec @@ div_row v value) (V. find_first_non_zero v)
301- else (* We try to normalize v and check if a contradiction arises. If not, we insert v at the appropriate place in m (depending on the pivot) *)
302- let pivot_positions = get_pivot_positions m in
303- (* filtered_pivots are only the pivots which have a non-zero entry in the corresponding column of v. Only those are relevant to subtract from v *)
304- let filtered_pivots = List. rev @@ fst @@ List. fold_left (fun (res , pivs_tail ) (col_idx , value ) ->
305- let pivs_tail = List. drop_while (fun (_ , piv_col , _ ) -> piv_col < col_idx) pivs_tail in (* Skipping until possible match of both cols *)
306- match pivs_tail with
307- | [] -> (res, [] )
308- | (row_idx , piv_col , row ) :: ps when piv_col = col_idx -> ((row_idx, piv_col, row, value) :: res, ps)
309- | _ -> (res, pivs_tail)
310- ) ([] , pivot_positions) (V. to_sparse_list v) in
311- let v_after_elim = List. fold_left (fun acc (row_idx , pivot_position , piv_row , v_at_piv ) ->
312- sub_scaled_row acc piv_row v_at_piv
313- ) v filtered_pivots in
314- match V. find_first_non_zero v_after_elim with (* now we check for contradictions and finally insert v *)
315- | None -> Some m (* v is zero vector and was therefore already covered by m *)
316- | Some (idx , value ) ->
317- if idx = (num_cols m - 1 ) then
318- None
319- else
320- let normalized_v = V. map_f_preserves_zero (fun x -> x /: value) v_after_elim in
321- Some (insert_v_according_to_piv m normalized_v idx pivot_positions)
305+ let res =
306+ if is_empty m then (* In this case, v is normalized and returned *)
307+ BatOption. map (fun (_ , value ) -> init_with_vec @@ div_row v value) (V. find_first_non_zero v)
308+ else (* We try to normalize v and check if a contradiction arises. If not, we insert v at the appropriate place in m (depending on the pivot) *)
309+ let pivot_positions = get_pivot_positions m in
310+ (* filtered_pivots are only the pivots which have a non-zero entry in the corresponding column of v. Only those are relevant to subtract from v *)
311+ let filtered_pivots = List. rev @@ fst @@ List. fold_left (fun (res , pivs_tail ) (col_idx , value ) ->
312+ let pivs_tail = List. drop_while (fun (_ , piv_col , _ ) -> piv_col < col_idx) pivs_tail in (* Skipping until possible match of both cols *)
313+ match pivs_tail with
314+ | [] -> (res, [] )
315+ | (row_idx , piv_col , row ) :: ps when piv_col = col_idx -> ((row_idx, piv_col, row, value) :: res, ps)
316+ | _ -> (res, pivs_tail)
317+ ) ([] , pivot_positions) (V. to_sparse_list v)
318+ in
319+ let v_after_elim = List. fold_left (fun acc (row_idx , pivot_position , piv_row , v_at_piv ) ->
320+ sub_scaled_row acc piv_row v_at_piv
321+ ) v filtered_pivots in
322+ match V. find_first_non_zero v_after_elim with (* now we check for contradictions and finally insert v *)
323+ | None -> Some m (* v is zero vector and was therefore already covered by m *)
324+ | Some (idx , value ) ->
325+ if idx = (num_cols m - 1 ) then
326+ None
327+ else
328+ let normalized_v = V. map_f_preserves_zero (fun x -> x /: value) v_after_elim in
329+ Some (insert_v_according_to_piv m normalized_v idx pivot_positions)
330+ in
331+ if M. tracing then M. trace " rref_vec" " rref_vec: m:\n %s, v: %s => res:\n %s" (show m) (V. show v) (match res with None -> " None" | Some r -> show r);
332+ res
322333
323334 let rref_vec m v = timing_wrap " rref_vec" (rref_vec m) v
324335
@@ -372,4 +383,129 @@ module ListMatrix: SparseMatrixFunctor =
372383
373384 let is_covered_by m1 m2 = timing_wrap " is_covered_by" (is_covered_by m1) m2
374385
386+ (* * Direct implementation of https://doi.org/10.1007/BF00268497 , chapter 5.2 Calculation of Linear Disjunction
387+ also available at https://www-apr.lip6.fr/~mine/enseignement/mpri/attic/2014-2015/exos/karr.pdf
388+ only difference is the implementation being optimized for sparse matrices in row representation
389+ *)
390+ let linear_disjunct m1 m2 =
391+ let maxcols = num_cols m1 in
392+ let inverse_termorder = fun x y -> y - x in
393+ let rev_matrix = List. map (fun x -> V. of_sparse_list (V. length x) (List. rev @@ V. to_sparse_list x) ) in
394+ let del_col m i = List. map (fun v -> V. tail_afterindex v i) m in
395+ let safe_get_row m i =
396+ try List. nth m i with
397+ | Invalid_argument _ -> V. zero_vec (num_cols m) (* if row is empty, we return zero *)
398+ in
399+ let safe_remove_row m i =
400+ try remove_row m i with (* remove_row can fail for sparse representations *)
401+ | Invalid_argument _ -> m (* if row is empty, we return the original matrix *)
402+ in
403+
404+ let col_and_rc m colidx rowidx =
405+ let col = get_col_upper_triangular m colidx in
406+ let rc = try V. nth col rowidx with (* V.nth could be integrated into get_col for the last few bits of performance... *)
407+ | Invalid_argument _ -> A. zero (* if col is empty, we return zero *) in
408+ col, rc
409+ in
410+
411+ let push_col m colidx col =
412+ List. mapi (fun idx row ->
413+ match V. nth col idx with
414+ | valu when A. equal A. zero valu -> row (* if the value is zero, we do not change the row *)
415+ | valu -> V. push_first row colidx valu
416+ | exception _ -> row
417+ ) m
418+ in
419+
420+ let case_two a r col_b =
421+ let a_r = get_row a r in
422+ let res = map2i (fun i x y -> if i < r then
423+ V. map2_f_preserves_zero (fun u j -> u +: y *: j) x a_r
424+ else x) a col_b in
425+ if M. tracing then M. trace " linear_disjunct_cases" " case_two: \n a:\n %s, r:%d,\n col_b: %s, a_r: %s, => res:\n %s" (show a) r (V. show col_b) (V. show a_r) (show res);
426+ res
427+ in
428+
429+ let case_three col1 col2 m1 m2 result ridx cidx = (* no new pivots at ridx/cidx *)
430+ let sub_and_lastterm c1 c2 = (* return last element/idx pair that differs*)
431+ let len = V. length c1 in
432+ let c1 = V. to_sparse_list c1 in
433+ let c2 = V. to_sparse_list c2 in
434+ let rec sub_and_last_aux (acclist ,acc ) c1 c2 =
435+ match c1, c2 with
436+ | (i ,_ )::_ ,_ when i > = ridx -> (acclist,acc) (* we are done, no more entries in c1 that are relevant *)
437+ | (i1 , v1 ) :: xs1 , (i2 , v2 ) :: xs2 when i1 = i2 ->
438+ let res = A. sub v1 v2 in
439+ let acc = if A. equal res A. zero then acc else Some (i1, v1, v2) in
440+ sub_and_last_aux ((i1,res)::acclist,acc) xs1 xs2
441+ | (i1 , v1 ) :: xs1 , (i2 , v2 ) :: xs2 when i1 < i2 -> sub_and_last_aux ((i1,v1)::acclist,Some (i1,v1,A. zero)) xs1 ((i2, v2)::xs2)
442+ | (i1 , v1 ) :: xs1 , (i2 , v2 ) :: xs2 (* when i1 > i2 *) -> sub_and_last_aux ((i2,A. neg v2)::acclist,Some (i2,A. zero,v2)) ((i1, v1)::xs1) xs2
443+ | (i ,v )::xs ,[] -> sub_and_last_aux ((i,v)::acclist,Some (i,v,A. zero)) xs []
444+ | [] , (i ,v )::xs -> sub_and_last_aux ((i,v)::acclist,Some (i,A. zero,v)) [] xs
445+ | [] , [] -> (acclist,acc)
446+ in
447+ let resl,rest = sub_and_last_aux ([] ,None ) c1 c2 in
448+ if M. tracing then M. trace " linear_disjunct_cases" " sub_and_last: ridx: %d c1: %s, c2: %s, resultlist: %s, result_pivot: %s" ridx (V. show col1) (V. show col2) (String. concat " ," (List. map (fun (i ,v ) -> Printf. sprintf " (%d,%s)" i (A. to_string v)) resl)) (match rest with None -> " None" | Some (i ,v1 ,v2 ) -> Printf. sprintf " (%d,%s,%s)" i (A. to_string v1) (A. to_string v2));
449+ V. of_sparse_list len (List. rev resl), rest
450+ in
451+ let coldiff,lastdiff = sub_and_lastterm col1 col2 in
452+ match lastdiff with
453+ | None ->
454+ let sameinboth= get_col_upper_triangular m1 cidx in
455+ if M. tracing then M. trace " linear_disjunct_cases" " case_three: no difference found, cidx: %d, ridx: %d, coldiff: %s, sameinboth: %s" cidx ridx (V. show coldiff) (V. show sameinboth);
456+ (del_col m1 cidx, del_col m2 cidx, push_col result cidx sameinboth, ridx) (* No difference found -> (del_col m1 cidx, del_col m2 cidx, push hd to result, ridx)*)
457+ | Some (idx ,x ,y ) ->
458+ let r1 = safe_get_row m1 idx in
459+ let r2 = safe_get_row m2 idx in
460+ let resrow = safe_get_row result idx in
461+ let diff = x -: y in
462+ let multiply_by_t termorder m t =
463+ map2i (fun i x c -> if i < = ridx then
464+ let beta = c /: diff in
465+ V. map2_f_preserves_zero_helper (termorder) (fun u j -> u -: (beta *: j)) x t
466+ else x) m coldiff
467+ in
468+ let transformed_res = multiply_by_t (inverse_termorder) result resrow in
469+ let transformed_a = multiply_by_t (- ) m1 r1 in
470+ let alpha = get_col_upper_triangular transformed_a cidx in
471+ let res = push_col transformed_res cidx alpha in
472+ if M. tracing then M. trace " linear_disjunct_cases" " case_three: found difference at ridx: %d idx: %d, x: %s, y: %s, diff: %s, m1: \n %s, m2:\n %s, res:\n %s"
473+ ridx idx (A. to_string x) (A. to_string y) (A. to_string diff) (show m1) (show m2) (show @@ rev_matrix res);
474+ safe_remove_row (transformed_a) idx, safe_remove_row (multiply_by_t (- ) m2 r2) idx, safe_remove_row (res) idx, ridx - 1
475+ in
476+
477+ let rec lindisjunc_aux currentrowindex currentcolindex m1 m2 result =
478+ if M. tracing then M. trace " linear_disjunct" " result so far: \n %s, currentrowindex: %d, currentcolindex: %d, m1: \n %s, m2:\n %s "
479+ (show @@ rev_matrix result) currentrowindex currentcolindex (show m1) (show m2);
480+ if currentcolindex > = maxcols then result
481+ else
482+ let col1, rc1 = col_and_rc m1 currentcolindex currentrowindex in
483+ let col2, rc2 = col_and_rc m2 currentcolindex currentrowindex in
484+ match Z. to_int @@ A. get_num rc1, Z. to_int @@ A. get_num rc2 with
485+ | 1 , 1 -> lindisjunc_aux
486+ (currentrowindex + 1 ) (currentcolindex+ 1 )
487+ (del_col m1 currentrowindex) (del_col m2 currentrowindex)
488+ (List. mapi (fun idx row -> if idx = currentrowindex then V. push_first row currentcolindex A. one else row) result)
489+ | 1 , 0 -> let beta = get_col_upper_triangular m2 currentcolindex in
490+ if M. tracing then M. trace " linear_disjunct_cases" " case 1,0: currentrowindex: %d, currentcolindex: %d, m1: \n %s, m2:\n %s , beta %s" currentrowindex currentcolindex (show m1) (show m2) (V. show beta);
491+ lindisjunc_aux
492+ (currentrowindex) (currentcolindex+ 1 )
493+ (safe_remove_row (case_two m1 currentrowindex col2) currentrowindex) (safe_remove_row m2 currentrowindex)
494+ (safe_remove_row (push_col result currentcolindex beta) currentrowindex)
495+ | 0 , 1 -> let beta = get_col_upper_triangular m1 currentcolindex in
496+ if M. tracing then M. trace " linear_disjunct_cases" " case 0,1: currentrowindex: %d, currentcolindex: %d, m1: \n %s, m2:\n %s , beta %s" currentrowindex currentcolindex (show m1) (show m2) (V. show beta);
497+ lindisjunc_aux
498+ (currentrowindex) (currentcolindex+ 1 )
499+ (safe_remove_row m1 currentrowindex) (safe_remove_row (case_two m2 currentrowindex col1) currentrowindex)
500+ (safe_remove_row (push_col result currentcolindex beta) currentrowindex)
501+ | 0 , 0 -> let m1 , m2 , result , currentrowindex = case_three col1 col2 m1 m2 result currentrowindex currentcolindex in
502+ lindisjunc_aux currentrowindex (currentcolindex+ 1 ) m1 m2 result (* we need to process m1, m2 and result *)
503+ | a ,b -> failwith (" matrix not in rref m1: " ^ (string_of_int a) ^ (string_of_int b)^ (show m1) ^ " m2: " ^ (show m2))
504+ in
505+ (* create a totally empty intial result, with dimensions rows x cols *)
506+ let pseudoempty = BatList. make (max (num_rows m1) (num_rows m1)) (V. zero_vec (num_cols m1)) in
507+ let res = rev_matrix @@ lindisjunc_aux 0 0 m1 m2 pseudoempty in
508+ if M. tracing then M. tracel " linear_disjunct" " linear_disjunct between \n %s and \n %s =>\n %s" (show m1) (show m2) (show res);
509+ res
510+
375511 end
0 commit comments