Skip to content

Commit 498424f

Browse files
Merge pull request #1773 from goblint/optimize_sparse_affeq_lindisjunc
affeq: optimize sparse affeq join
2 parents 90d8705 + a78f03f commit 498424f

File tree

3 files changed

+213
-73
lines changed

3 files changed

+213
-73
lines changed

src/cdomains/affineEquality/sparseImplementation/listMatrix.ml

Lines changed: 159 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ open RatOps
44

55
open Batteries
66

7+
module M = Messages
8+
79
let timing_wrap = Vector.timing_wrap
810

911
module type SparseMatrix =
@@ -16,6 +18,10 @@ sig
1618
val rref_vec: t -> vec -> t Option.t
1719

1820
val rref_matrix: t -> t -> t Option.t
21+
22+
val linear_disjunct: t -> t -> t
23+
(** [linear_disjunct m1 m2] returns a matrix that contains the linear disjunct of [m1] and [m2].
24+
The result is in rref. If [m1] and [m2] are not linearly disjunct, an exception is raised. *)
1925
end
2026

2127
module type SparseMatrixFunctor =
@@ -296,29 +302,34 @@ module ListMatrix: SparseMatrixFunctor =
296302
@param v A vector with number of entries equal to the number of columns of [v].
297303
*)
298304
let rref_vec m v =
299-
if is_empty m then (* In this case, v is normalized and returned *)
300-
BatOption.map (fun (_, value) -> init_with_vec @@ div_row v value) (V.find_first_non_zero v)
301-
else (* We try to normalize v and check if a contradiction arises. If not, we insert v at the appropriate place in m (depending on the pivot) *)
302-
let pivot_positions = get_pivot_positions m in
303-
(* filtered_pivots are only the pivots which have a non-zero entry in the corresponding column of v. Only those are relevant to subtract from v *)
304-
let filtered_pivots = List.rev @@ fst @@ List.fold_left (fun (res, pivs_tail) (col_idx, value) ->
305-
let pivs_tail = List.drop_while (fun (_, piv_col, _) -> piv_col < col_idx) pivs_tail in (* Skipping until possible match of both cols *)
306-
match pivs_tail with
307-
| [] -> (res, [])
308-
| (row_idx, piv_col, row) :: ps when piv_col = col_idx -> ((row_idx, piv_col, row, value) :: res, ps)
309-
| _ -> (res, pivs_tail)
310-
) ([], pivot_positions) (V.to_sparse_list v) in
311-
let v_after_elim = List.fold_left (fun acc (row_idx, pivot_position, piv_row, v_at_piv) ->
312-
sub_scaled_row acc piv_row v_at_piv
313-
) v filtered_pivots in
314-
match V.find_first_non_zero v_after_elim with (* now we check for contradictions and finally insert v *)
315-
| None -> Some m (* v is zero vector and was therefore already covered by m *)
316-
| Some (idx, value) ->
317-
if idx = (num_cols m - 1) then
318-
None
319-
else
320-
let normalized_v = V.map_f_preserves_zero (fun x -> x /: value) v_after_elim in
321-
Some (insert_v_according_to_piv m normalized_v idx pivot_positions)
305+
let res =
306+
if is_empty m then (* In this case, v is normalized and returned *)
307+
BatOption.map (fun (_, value) -> init_with_vec @@ div_row v value) (V.find_first_non_zero v)
308+
else (* We try to normalize v and check if a contradiction arises. If not, we insert v at the appropriate place in m (depending on the pivot) *)
309+
let pivot_positions = get_pivot_positions m in
310+
(* filtered_pivots are only the pivots which have a non-zero entry in the corresponding column of v. Only those are relevant to subtract from v *)
311+
let filtered_pivots = List.rev @@ fst @@ List.fold_left (fun (res, pivs_tail) (col_idx, value) ->
312+
let pivs_tail = List.drop_while (fun (_, piv_col, _) -> piv_col < col_idx) pivs_tail in (* Skipping until possible match of both cols *)
313+
match pivs_tail with
314+
| [] -> (res, [])
315+
| (row_idx, piv_col, row) :: ps when piv_col = col_idx -> ((row_idx, piv_col, row, value) :: res, ps)
316+
| _ -> (res, pivs_tail)
317+
) ([], pivot_positions) (V.to_sparse_list v)
318+
in
319+
let v_after_elim = List.fold_left (fun acc (row_idx, pivot_position, piv_row, v_at_piv) ->
320+
sub_scaled_row acc piv_row v_at_piv
321+
) v filtered_pivots in
322+
match V.find_first_non_zero v_after_elim with (* now we check for contradictions and finally insert v *)
323+
| None -> Some m (* v is zero vector and was therefore already covered by m *)
324+
| Some (idx, value) ->
325+
if idx = (num_cols m - 1) then
326+
None
327+
else
328+
let normalized_v = V.map_f_preserves_zero (fun x -> x /: value) v_after_elim in
329+
Some (insert_v_according_to_piv m normalized_v idx pivot_positions)
330+
in
331+
if M.tracing then M.trace "rref_vec" "rref_vec: m:\n%s, v: %s => res:\n%s" (show m) (V.show v) (match res with None -> "None" | Some r -> show r);
332+
res
322333

323334
let rref_vec m v = timing_wrap "rref_vec" (rref_vec m) v
324335

@@ -372,4 +383,129 @@ module ListMatrix: SparseMatrixFunctor =
372383

373384
let is_covered_by m1 m2 = timing_wrap "is_covered_by" (is_covered_by m1) m2
374385

386+
(** Direct implementation of https://doi.org/10.1007/BF00268497 , chapter 5.2 Calculation of Linear Disjunction
387+
also available at https://www-apr.lip6.fr/~mine/enseignement/mpri/attic/2014-2015/exos/karr.pdf
388+
only difference is the implementation being optimized for sparse matrices in row representation
389+
*)
390+
let linear_disjunct m1 m2 =
391+
let maxcols = num_cols m1 in
392+
let inverse_termorder = fun x y -> y - x in
393+
let rev_matrix = List.map (fun x -> V.of_sparse_list (V.length x) (List.rev @@ V.to_sparse_list x) ) in
394+
let del_col m i = List.map (fun v -> V.tail_afterindex v i) m in
395+
let safe_get_row m i =
396+
try List.nth m i with
397+
| Invalid_argument _ -> V.zero_vec (num_cols m) (* if row is empty, we return zero *)
398+
in
399+
let safe_remove_row m i =
400+
try remove_row m i with (* remove_row can fail for sparse representations *)
401+
| Invalid_argument _ -> m (* if row is empty, we return the original matrix *)
402+
in
403+
404+
let col_and_rc m colidx rowidx =
405+
let col = get_col_upper_triangular m colidx in
406+
let rc = try V.nth col rowidx with (* V.nth could be integrated into get_col for the last few bits of performance... *)
407+
| Invalid_argument _ -> A.zero (* if col is empty, we return zero *) in
408+
col, rc
409+
in
410+
411+
let push_col m colidx col =
412+
List.mapi (fun idx row ->
413+
match V.nth col idx with
414+
| valu when A.equal A.zero valu -> row (* if the value is zero, we do not change the row *)
415+
| valu -> V.push_first row colidx valu
416+
| exception _ -> row
417+
) m
418+
in
419+
420+
let case_two a r col_b =
421+
let a_r = get_row a r in
422+
let res = map2i (fun i x y -> if i < r then
423+
V.map2_f_preserves_zero (fun u j -> u +: y *: j) x a_r
424+
else x) a col_b in
425+
if M.tracing then M.trace "linear_disjunct_cases" "case_two: \na:\n%s, r:%d,\n col_b: %s, a_r: %s, => res:\n%s" (show a) r (V.show col_b) (V.show a_r) (show res);
426+
res
427+
in
428+
429+
let case_three col1 col2 m1 m2 result ridx cidx = (* no new pivots at ridx/cidx *)
430+
let sub_and_lastterm c1 c2 = (* return last element/idx pair that differs*)
431+
let len = V.length c1 in
432+
let c1 = V.to_sparse_list c1 in
433+
let c2 = V.to_sparse_list c2 in
434+
let rec sub_and_last_aux (acclist,acc) c1 c2 =
435+
match c1, c2 with
436+
| (i,_)::_,_ when i >= ridx -> (acclist,acc) (* we are done, no more entries in c1 that are relevant *)
437+
| (i1, v1) :: xs1, (i2, v2) :: xs2 when i1 = i2 ->
438+
let res = A.sub v1 v2 in
439+
let acc = if A.equal res A.zero then acc else Some (i1, v1, v2) in
440+
sub_and_last_aux ((i1,res)::acclist,acc) xs1 xs2
441+
| (i1, v1) :: xs1, (i2, v2) :: xs2 when i1 < i2 -> sub_and_last_aux ((i1,v1)::acclist,Some (i1,v1,A.zero)) xs1 ((i2, v2)::xs2)
442+
| (i1, v1) :: xs1, (i2, v2) :: xs2 (* when i1 > i2 *)-> sub_and_last_aux ((i2,A.neg v2)::acclist,Some (i2,A.zero,v2)) ((i1, v1)::xs1) xs2
443+
| (i,v)::xs ,[] -> sub_and_last_aux ((i,v)::acclist,Some (i,v,A.zero)) xs []
444+
| [], (i,v)::xs -> sub_and_last_aux ((i,v)::acclist,Some (i,A.zero,v)) [] xs
445+
| [], [] -> (acclist,acc)
446+
in
447+
let resl,rest = sub_and_last_aux ([],None) c1 c2 in
448+
if M.tracing then M.trace "linear_disjunct_cases" "sub_and_last: ridx: %d c1: %s, c2: %s, resultlist: %s, result_pivot: %s" ridx (V.show col1) (V.show col2) (String.concat "," (List.map (fun (i,v) -> Printf.sprintf "(%d,%s)" i (A.to_string v)) resl)) (match rest with None -> "None" | Some (i,v1,v2) -> Printf.sprintf "(%d,%s,%s)" i (A.to_string v1) (A.to_string v2));
449+
V.of_sparse_list len (List.rev resl), rest
450+
in
451+
let coldiff,lastdiff = sub_and_lastterm col1 col2 in
452+
match lastdiff with
453+
| None ->
454+
let sameinboth=get_col_upper_triangular m1 cidx in
455+
if M.tracing then M.trace "linear_disjunct_cases" "case_three: no difference found, cidx: %d, ridx: %d, coldiff: %s, sameinboth: %s" cidx ridx (V.show coldiff) (V.show sameinboth);
456+
(del_col m1 cidx, del_col m2 cidx, push_col result cidx sameinboth, ridx) (* No difference found -> (del_col m1 cidx, del_col m2 cidx, push hd to result, ridx)*)
457+
| Some (idx,x,y) ->
458+
let r1 = safe_get_row m1 idx in
459+
let r2 = safe_get_row m2 idx in
460+
let resrow = safe_get_row result idx in
461+
let diff = x -: y in
462+
let multiply_by_t termorder m t =
463+
map2i (fun i x c -> if i <= ridx then
464+
let beta = c /: diff in
465+
V.map2_f_preserves_zero_helper (termorder) (fun u j -> u -: (beta *: j)) x t
466+
else x) m coldiff
467+
in
468+
let transformed_res = multiply_by_t (inverse_termorder) result resrow in
469+
let transformed_a = multiply_by_t (-) m1 r1 in
470+
let alpha = get_col_upper_triangular transformed_a cidx in
471+
let res = push_col transformed_res cidx alpha in
472+
if M.tracing then M.trace "linear_disjunct_cases" "case_three: found difference at ridx: %d idx: %d, x: %s, y: %s, diff: %s, m1: \n%s, m2:\n%s, res:\n%s"
473+
ridx idx (A.to_string x) (A.to_string y) (A.to_string diff) (show m1) (show m2) (show @@ rev_matrix res);
474+
safe_remove_row (transformed_a) idx, safe_remove_row (multiply_by_t (-) m2 r2) idx, safe_remove_row (res) idx, ridx - 1
475+
in
476+
477+
let rec lindisjunc_aux currentrowindex currentcolindex m1 m2 result =
478+
if M.tracing then M.trace "linear_disjunct" "result so far: \n%s, currentrowindex: %d, currentcolindex: %d, m1: \n%s, m2:\n%s "
479+
(show @@ rev_matrix result) currentrowindex currentcolindex (show m1) (show m2);
480+
if currentcolindex >= maxcols then result
481+
else
482+
let col1, rc1 = col_and_rc m1 currentcolindex currentrowindex in
483+
let col2, rc2 = col_and_rc m2 currentcolindex currentrowindex in
484+
match Z.to_int @@ A.get_num rc1, Z.to_int @@ A.get_num rc2 with
485+
| 1, 1 -> lindisjunc_aux
486+
(currentrowindex + 1) (currentcolindex+1)
487+
(del_col m1 currentrowindex) (del_col m2 currentrowindex)
488+
(List.mapi (fun idx row -> if idx = currentrowindex then V.push_first row currentcolindex A.one else row) result)
489+
| 1, 0 -> let beta = get_col_upper_triangular m2 currentcolindex in
490+
if M.tracing then M.trace "linear_disjunct_cases" "case 1,0: currentrowindex: %d, currentcolindex: %d, m1: \n%s, m2:\n%s , beta %s" currentrowindex currentcolindex (show m1) (show m2) (V.show beta);
491+
lindisjunc_aux
492+
(currentrowindex) (currentcolindex+1)
493+
(safe_remove_row (case_two m1 currentrowindex col2) currentrowindex) (safe_remove_row m2 currentrowindex)
494+
(safe_remove_row (push_col result currentcolindex beta) currentrowindex)
495+
| 0, 1 -> let beta = get_col_upper_triangular m1 currentcolindex in
496+
if M.tracing then M.trace "linear_disjunct_cases" "case 0,1: currentrowindex: %d, currentcolindex: %d, m1: \n%s, m2:\n%s , beta %s" currentrowindex currentcolindex (show m1) (show m2) (V.show beta);
497+
lindisjunc_aux
498+
(currentrowindex) (currentcolindex+1)
499+
(safe_remove_row m1 currentrowindex) (safe_remove_row (case_two m2 currentrowindex col1) currentrowindex)
500+
(safe_remove_row (push_col result currentcolindex beta) currentrowindex)
501+
| 0, 0 -> let m1 , m2, result, currentrowindex = case_three col1 col2 m1 m2 result currentrowindex currentcolindex in
502+
lindisjunc_aux currentrowindex (currentcolindex+1) m1 m2 result (* we need to process m1, m2 and result *)
503+
| a,b -> failwith ("matrix not in rref m1: " ^ (string_of_int a) ^ (string_of_int b)^(show m1) ^ " m2: " ^ (show m2))
504+
in
505+
(* create a totally empty intial result, with dimensions rows x cols *)
506+
let pseudoempty = BatList.make (max (num_rows m1) (num_rows m1)) (V.zero_vec (num_cols m1)) in
507+
let res = rev_matrix @@ lindisjunc_aux 0 0 m1 m2 pseudoempty in
508+
if M.tracing then M.tracel "linear_disjunct" "linear_disjunct between \n%s and \n%s =>\n%s" (show m1) (show m2) (show res);
509+
res
510+
375511
end

src/cdomains/affineEquality/sparseImplementation/sparseVector.ml

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
11
open Vector
22
open RatOps
33

4+
module M = Messages
5+
46
open Batteries
57

68
module type SparseVector =
79
sig
810
include Vector
11+
val push_first: t -> int -> num -> t
12+
913
val is_zero_vec: t -> bool
1014

15+
val tail_afterindex: t -> int -> t
16+
1117
val insert_zero_at_indices: t -> (int * int) list -> int -> t
1218

1319
val remove_at_indices: t -> int list -> t
@@ -23,6 +29,8 @@ sig
2329

2430
val map2_f_preserves_zero: (num -> num -> num) -> t -> t -> t
2531

32+
val map2_f_preserves_zero_helper: (int -> int -> int) -> (num -> num -> num) -> t -> t -> t
33+
2634
val find2i_f_false_at_zero: (num -> num -> bool) -> t -> t -> int
2735

2836
val apply_with_c_f_preserves_zero: (num -> num -> num) -> num -> t -> t
@@ -86,7 +94,8 @@ module SparseVector: SparseVectorFunctor =
8694
let show v =
8795
let rec sparse_list_str i l =
8896
if i >= v.len then "]"
89-
else match l with
97+
else
98+
match l with
9099
| [] -> (A.to_string A.zero) ^" "^ (sparse_list_str (i + 1) l)
91100
| (idx, value) :: xs ->
92101
if i = idx then (A.to_string value) ^" "^ sparse_list_str (i + 1) xs
@@ -126,6 +135,13 @@ module SparseVector: SparseVectorFunctor =
126135
| Some (idx, value) when idx = n -> value
127136
| _ -> A.zero
128137

138+
let push_first v n num =
139+
if n >= v.len then raise (Invalid_argument "Index out of bounds")
140+
else let res =
141+
{v with entries = (n,num)::v.entries} in
142+
if M.tracing then M.trace "push_first" "pushed %s at index %d, new length: %d, resulting in %s" (A.to_string num) n res.len (res.entries |> List.map (fun (i, x) -> Printf.sprintf "(%d, %s)" i (A.to_string x)) |> String.concat ", ");
143+
res
144+
129145
(**
130146
[set_nth v n num] returns [v] where the [n]-th entry has been set to [num].
131147
@raise Invalid_argument if [n] is out of bounds.
@@ -156,6 +172,22 @@ module SparseVector: SparseVectorFunctor =
156172
in
157173
{entries = add_indices_helper v.entries indices 0 []; len = v.len + num_zeros}
158174

175+
(**
176+
[tail_afterindex v n] returns the vector starting after the [n]-th entry, i.e. all entries with index > [n].
177+
@raise Invalid_argument if [n] is out of bounds.
178+
*)
179+
let tail_afterindex v n =
180+
if n >= v.len then raise (Invalid_argument "Index out of bounds")
181+
else
182+
match v.entries with
183+
| [] -> v (* If the vector is empty, return it as is *)
184+
| (headidx,headval) :: _ ->
185+
if M.tracing then M.trace "tail_afterindex" "headidx: %d, n: %d, v.len: %d" headidx n v.len;
186+
if headidx > n then v
187+
else
188+
let entries = List.tl v.entries in
189+
{entries; len = v.len }
190+
159191
(**
160192
[remove_nth v n] returns [v] where the [n]-th entry is removed, decreasing the length of the vector by one.
161193
@raise Invalid_argument if [n] is out of bounds
@@ -257,13 +289,15 @@ module SparseVector: SparseVectorFunctor =
257289
{v with entries = entries'}
258290

259291
(**
260-
[map2_f_preserves_zero f v v'] returns the mapping of [v] and [v'] specified by [f].
292+
[map2_f_preserves_zero termorder f v v'] returns the mapping of [v] and [v'] specified by [f].
261293
262294
Note that [f] {b must} be such that [f 0 0 = 0]!
263295
296+
[termorder] is a function specifying, if the entries of [v] and [v'] are ordered in increasing or decreasing index order.
297+
264298
@raise Invalid_argument if [v] and [v'] have unequal lengths
265299
*)
266-
let map2_f_preserves_zero f v v' =
300+
let map2_f_preserves_zero_helper termorder f v v' =
267301
let f_rem_zero acc idx e1 e2 =
268302
let r = f e1 e2 in
269303
if r =: A.zero then acc else (idx, r) :: acc
@@ -274,14 +308,25 @@ module SparseVector: SparseVectorFunctor =
274308
| [], (yidx, yval) :: ys -> aux (f_rem_zero acc yidx A.zero yval) [] ys
275309
| (xidx, xval) :: xs, [] -> aux (f_rem_zero acc xidx xval A.zero) xs []
276310
| (xidx, xval) :: xs, (yidx, yval) :: ys ->
277-
match xidx - yidx with
311+
match termorder xidx yidx with
278312
| d when d < 0 -> aux (f_rem_zero acc xidx xval A.zero) xs v2
279313
| d when d > 0 -> aux (f_rem_zero acc yidx A.zero yval) v1 ys
280314
| _ -> aux (f_rem_zero acc xidx xval yval) xs ys
281315
in
282316
if v.len <> v'.len then raise (Invalid_argument "Unequal lengths") else
283317
{v with entries = List.rev (aux [] v.entries v'.entries)}
284318

319+
(**
320+
[map2_f_preserves_zero f v v'] returns the mapping of [v] and [v'] specified by [f].
321+
322+
Note that [f] {b must} be such that [f 0 0 = 0]!
323+
324+
The entries of [v] and [v'] are assumed to be ordered in increasing index order.
325+
326+
@raise Invalid_argument if [v] and [v'] have unequal lengths
327+
*)
328+
let map2_f_preserves_zero f v v'= map2_f_preserves_zero_helper (-) f v v'
329+
285330
let map2_f_preserves_zero f v1 v2 = timing_wrap "map2_f_preserves_zero" (map2_f_preserves_zero f v1) v2
286331

287332
(**
@@ -321,6 +366,5 @@ module SparseVector: SparseVectorFunctor =
321366

322367
let rev v =
323368
let entries = List.rev_map (fun (idx, value) -> (v.len - 1 - idx, value)) v.entries in
324-
{entries; len = v.len}
325-
369+
{entries; len = v.len}
326370
end

0 commit comments

Comments
 (0)