-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmap_fold.ml
149 lines (133 loc) · 4.77 KB
/
map_fold.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
(**************************************************************************)
(* *)
(* Functory: a distributed computing library for Ocaml *)
(* Copyright (C) 2010 Jean-Christophe Filliatre and Kalyan Krishnamani *)
(* *)
(* This software is free software; you can redistribute it and/or *)
(* modify it under the terms of the GNU Library General Public *)
(* License version 2.1, with the special exception on linking *)
(* described in file LICENSE. *)
(* *)
(* This software is distributed in the hope that it will be useful, *)
(* but WITHOUT ANY WARRANTY; without even the implied warranty of *)
(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. *)
(* *)
(**************************************************************************)
(** Tag used to mark a task (or a worker result) as either a map step,
    [Map x], or a fold step, [Fold y]. *)
type ('m, 'f) map_or_fold = Map of 'm | Fold of 'f
(** [map_fold_wrapper map fold task] runs the step described by [task] and
    keeps its tag on the result: [Map arg] gives [Map (map arg)] and
    [Fold (left, right)] gives [Fold (fold left right)]. *)
let map_fold_wrapper map fold task =
  match task with
  | Map arg -> Map (map arg)
  | Fold (left, right) -> Fold (fold left right)
(** [map_fold_wrapper2 map fold task] runs the step described by [task] and
    returns the untagged result: [fold left right] for [Fold (left, right)]
    and [map arg] for [Map arg]. Both functions must therefore return the
    same type. *)
let map_fold_wrapper2 map fold task =
  match task with
  | Fold (left, right) -> fold left right
  | Map arg -> map arg
(** Map and map/fold skeletons built on a generic task scheduler.

    [X.compute ~worker ~master tasks] is assumed (its implementations are
    not shown here, so confirm) to run [worker] on the first component of
    each task, possibly remotely, and to call [master task result] for each
    completed task. The tasks returned by [master] are then scheduled in
    turn, and [compute] returns once no task is left. The second component
    of each task is local bookkeeping that never reaches [worker]. *)
module Make
  (X : sig
     val compute :
       worker:('a -> 'b) ->
       master:('a * 'c -> 'b -> ('a * 'c) list) ->
       ('a * 'c) list ->
       unit
   end) :
sig
  (** [map ~f [x1; ...; xn]] returns [[f x1; ...; f xn]], in the order of
      the input list, whatever order the results arrive in. *)
  val map : f:('a -> 'b) -> 'a list -> 'b list

  (** [map_local_fold ~f ~fold acc l] runs [f] as worker tasks. It folds
      each result into [acc] with [fold] on the master side, in the order
      the results are delivered. *)
  val map_local_fold :
    f:('a -> 'b) -> fold:('c -> 'b -> 'c) -> 'c -> 'a list -> 'c

  (** Like [map_local_fold], except that [fold] also runs as a worker task.
      At most one fold task is in flight at a time, and it carries the
      accumulator. Map results that arrive in the meantime are queued. *)
  val map_remote_fold :
    f:('a -> 'b) -> fold:('c -> 'b -> 'c) -> 'c -> 'a list -> 'c

  (** [map_fold_ac ~f ~fold acc l] combines [acc] and the images of [l]
      pairwise with [fold], as worker tasks, in delivery order. Because the
      combination order is arbitrary, [fold] presumably has to be
      associative and commutative (hence the name). *)
  val map_fold_ac :
    f:('a -> 'b) -> fold:('b -> 'b -> 'b) -> 'b -> 'a list -> 'b

  (** [map_fold_a ~f ~fold acc l] combines the images of [l] with [fold],
      as worker tasks. It only ever merges adjacent intervals of [l], so
      left-to-right order is preserved; [fold] presumably has to be
      associative. [acc] is returned unchanged when [l] is empty and is not
      folded in otherwise. Presumably it is meant to be a neutral element
      of [fold] (TODO confirm against the public documentation). *)
  val map_fold_a :
    f:('a -> 'b) -> fold:('b -> 'b -> 'b) -> 'b -> 'a list -> 'b
end = struct

  (* Order-preserving, tail-recursive [List.map]. The stdlib [List.map]
     is not tail-recursive before OCaml 5.1, and task lists may be long. *)
  let map_list f l = List.rev (List.rev_map f l)

  (* [number [x1; ...; xn]] is [[(x1, 1); ...; (xn, n)]]. The numbering
     is explicit, rather than relying on the order in which [List.map]
     applies a side-effecting function. Tail-recursive. *)
  let number l =
    let rec loop i acc = function
      | [] -> List.rev acc
      | x :: rest -> loop (i + 1) ((x, i) :: acc) rest
    in
    loop 1 [] l

  let map ~f l =
    let tasks = number l in
    (* position -> image of the element at that position *)
    let results = Hashtbl.create 17 in
    X.compute
      ~worker:f
      ~master:(fun (_, i) r -> Hashtbl.replace results i r; [])
      tasks;
    (* reassemble in input order, independent of completion order *)
    map_list (fun (_, i) -> Hashtbl.find results i) tasks

  let map_local_fold ~(f : 'a -> 'b) ~(fold : 'c -> 'b -> 'c) acc l =
    let acc = ref acc in
    X.compute
      ~worker:f
      ~master:(fun _ r -> acc := fold !acc r; [])
      (map_list (fun x -> x, ()) l);
    !acc

  let map_remote_fold ~(f : 'a -> 'b) ~(fold : 'c -> 'b -> 'c) acc l =
    (* [Some v]: the accumulator [v] is idle in the master.
       [None]: it is inside the single in-flight fold task. *)
    let acc = ref (Some acc) in
    (* map results waiting for the accumulator to come back *)
    let pending = Stack.create () in
    X.compute
      ~worker:(map_fold_wrapper f fold)
      ~master:(fun _ r -> match r with
        | Map r ->
          begin match !acc with
          | None -> Stack.push r pending; []
          | Some v -> acc := None; [Fold (v, r), ()]
          end
        | Fold r ->
          (* only a fold task can hand the accumulator back, and only one
             is ever in flight, so the master cannot be holding it now *)
          assert (match !acc with None -> true | Some _ -> false);
          if Stack.is_empty pending then begin
            acc := Some r;
            []
          end else
            [Fold (r, Stack.pop pending), ()])
      (map_list (fun x -> Map x, ()) l);
    (* every task has completed, so no fold is in flight and the
       accumulator is back in the master *)
    match !acc with
    | Some r -> r
    | None -> assert false

  let map_fold_ac ~(f : 'a -> 'b) ~(fold : 'b -> 'b -> 'b) acc l =
    (* [Some v]: an idle partial result waiting for a partner.
       [None]: no idle partial result. *)
    let acc = ref (Some acc) in
    X.compute
      ~worker:(map_fold_wrapper2 f fold)
      ~master:(fun _ r ->
        match !acc with
        | None -> acc := Some r; []
        | Some v -> acc := None; [Fold (v, r), ()])
      (map_list (fun x -> Map x, ()) l);
    (* Each task yields one value and each fold turns two values into one.
       Once no task is in flight, the n + 1 values have collapsed into the
       single idle one. *)
    match !acc with
    | Some r -> r
    | None -> assert false

  let map_fold_a ~(f : 'a -> 'b) ~(fold : 'b -> 'b -> 'b) acc l =
    (* element at position i becomes a map task over the interval i..i *)
    let tasks = map_list (fun (x, i) -> Map x, (i, i)) (number l) in
    (* For each completed, not-yet-merged interval i..j with result r,
       [results] maps both endpoints i and j to (i, j, r), one binding
       per key. Intervals are pairwise disjoint. *)
    let results = Hashtbl.create 17 in
    let lookup k = try Some (Hashtbl.find results k) with Not_found -> None in
    (* Record the result [r] of interval i..j. If an adjacent interval is
       already available, schedule a fold of the two instead. *)
    let merge i j r =
      match lookup (i - 1) with
      | Some (l, h, x) ->
        (* left neighbour l..i-1: fold it with r to cover l..j *)
        assert (h = i - 1);
        Hashtbl.remove results l;
        Hashtbl.remove results h;
        [Fold (x, r), (l, j)]
      | None ->
        begin match lookup (j + 1) with
        | Some (l, h, x) ->
          (* right neighbour j+1..h: fold r with it to cover i..h *)
          assert (l = j + 1);
          Hashtbl.remove results l;
          Hashtbl.remove results h;
          [Fold (r, x), (i, h)]
        | None ->
          Hashtbl.replace results i (i, j, r);
          Hashtbl.replace results j (i, j, r);
          []
        end
    in
    X.compute
      ~worker:(map_fold_wrapper2 f fold)
      (* map tasks carry (i, i) and fold tasks carry (i, j), so both are
         handled uniformly *)
      ~master:(fun (_, (i, j)) r -> merge i j r)
      tasks;
    (* All tasks are done. Unless l was empty, the only interval left is
       1..n, so its result is found under key 1. *)
    match lookup 1 with
    | Some (_, _, r) -> r
    | None -> acc
end