@@ -5,17 +5,17 @@ use std::{
5
5
6
6
use itertools:: Itertools ;
7
7
use powdr_ast:: {
8
- analyzed:: {
9
- AlgebraicExpression as Expression , AlgebraicReference , AlgebraicReferenceThin ,
10
- PolynomialType ,
11
- } ,
8
+ analyzed:: { AlgebraicExpression as Expression , AlgebraicReferenceThin , PolynomialType } ,
12
9
parsed:: visitor:: { AllChildren , Children } ,
13
10
} ;
14
11
use powdr_number:: FieldElement ;
15
12
16
13
use crate :: witgen:: { data_structures:: identity:: Identity , FixedData } ;
17
14
18
- use super :: variable:: Variable ;
15
+ use super :: {
16
+ variable:: Variable ,
17
+ witgen_inference:: { Assignment , VariableOrValue } ,
18
+ } ;
19
19
20
20
/// Keeps track of identities that still need to be processed and
21
21
/// updates this list based on the occurrence of updated variables
@@ -27,98 +27,120 @@ pub struct IdentityQueue<'a, T: FieldElement> {
27
27
}
28
28
29
29
impl < ' a , T : FieldElement > IdentityQueue < ' a , T > {
30
- pub fn new ( fixed_data : & ' a FixedData < ' a , T > , identities : & [ ( & ' a Identity < T > , i32 ) ] ) -> Self {
31
- let occurrences = compute_occurrences_map ( fixed_data, identities) . into ( ) ;
32
- Self {
33
- queue : identities
34
- . iter ( )
35
- . map ( |( id, row) | QueueItem ( id, * row) )
36
- . collect ( ) ,
37
- occurrences,
38
- }
30
+ pub fn new (
31
+ fixed_data : & ' a FixedData < ' a , T > ,
32
+ identities : & [ ( & ' a Identity < T > , i32 ) ] ,
33
+ assignments : & [ Assignment < ' a , T > ] ,
34
+ ) -> Self {
35
+ let queue: BTreeSet < _ > = identities
36
+ . iter ( )
37
+ . map ( |( id, row) | QueueItem :: Identity ( id, * row) )
38
+ . chain ( assignments. iter ( ) . map ( |a| QueueItem :: Assignment ( a. clone ( ) ) ) )
39
+ . collect ( ) ;
40
+ let occurrences = compute_occurrences_map ( fixed_data, & queue) . into ( ) ;
41
+ Self { queue, occurrences }
39
42
}
40
43
41
44
/// Returns the next identity to be processed and its row and
42
45
/// removes it from the queue.
43
- pub fn next ( & mut self ) -> Option < ( & ' a Identity < T > , i32 ) > {
44
- self . queue . pop_first ( ) . map ( | QueueItem ( id , row ) | ( id , row ) )
46
+ pub fn next ( & mut self ) -> Option < QueueItem < ' a , T > > {
47
+ self . queue . pop_first ( )
45
48
}
46
49
47
50
pub fn variables_updated (
48
51
& mut self ,
49
52
variables : impl IntoIterator < Item = Variable > ,
50
- skip_identity : Option < ( & ' a Identity < T > , i32 ) > ,
53
+ skip_item : Option < QueueItem < ' a , T > > ,
51
54
) {
52
55
self . queue . extend (
53
56
variables
54
57
. into_iter ( )
55
58
. flat_map ( |var| self . occurrences . get ( & var) )
56
59
. flatten ( )
57
- . filter ( |QueueItem ( id , row ) | match skip_identity {
58
- Some ( ( id2 , row2 ) ) => ( id . id ( ) , * row ) != ( id2 . id ( ) , row2 ) ,
60
+ . filter ( |item | match & skip_item {
61
+ Some ( it ) => * item != it ,
59
62
None => true ,
60
- } ) ,
63
+ } )
64
+ . cloned ( ) ,
61
65
)
62
66
}
63
67
}
64
68
65
- /// Sorts identities by row and then by ID.
66
- #[ derive( Clone , Copy ) ]
67
- struct QueueItem < ' a , T > ( & ' a Identity < T > , i32 ) ;
68
-
69
- impl < T > QueueItem < ' _ , T > {
70
- fn key ( & self ) -> ( i32 , u64 ) {
71
- let QueueItem ( id, row) = self ;
72
- ( * row, id. id ( ) )
73
- }
69
+ #[ derive( Clone ) ]
70
+ pub enum QueueItem < ' a , T : FieldElement > {
71
+ Identity ( & ' a Identity < T > , i32 ) ,
72
+ Assignment ( Assignment < ' a , T > ) ,
74
73
}
75
74
76
- impl < T > Ord for QueueItem < ' _ , T > {
75
+ /// Sorts identities by row and then by ID, preceded by assignments.
76
+ impl < T : FieldElement > Ord for QueueItem < ' _ , T > {
77
77
fn cmp ( & self , other : & Self ) -> std:: cmp:: Ordering {
78
- self . key ( ) . cmp ( & other. key ( ) )
78
+ match ( self , other) {
79
+ ( QueueItem :: Identity ( id1, row1) , QueueItem :: Identity ( id2, row2) ) => {
80
+ ( row1, id1. id ( ) ) . cmp ( & ( row2, id2. id ( ) ) )
81
+ }
82
+ ( QueueItem :: Assignment ( a1) , QueueItem :: Assignment ( a2) ) => a1. cmp ( a2) ,
83
+ ( QueueItem :: Assignment ( _) , QueueItem :: Identity ( _, _) ) => std:: cmp:: Ordering :: Less ,
84
+ ( QueueItem :: Identity ( _, _) , QueueItem :: Assignment ( _) ) => std:: cmp:: Ordering :: Greater ,
85
+ }
79
86
}
80
87
}
81
88
82
- impl < T > PartialOrd for QueueItem < ' _ , T > {
89
+ impl < T : FieldElement > PartialOrd for QueueItem < ' _ , T > {
83
90
fn partial_cmp ( & self , other : & Self ) -> Option < std:: cmp:: Ordering > {
84
91
Some ( self . cmp ( other) )
85
92
}
86
93
}
87
94
88
- impl < T > PartialEq for QueueItem < ' _ , T > {
95
+ impl < T : FieldElement > PartialEq for QueueItem < ' _ , T > {
89
96
fn eq ( & self , other : & Self ) -> bool {
90
- self . key ( ) == other . key ( )
97
+ self . cmp ( other ) == std :: cmp :: Ordering :: Equal
91
98
}
92
99
}
93
100
94
- impl < T > Eq for QueueItem < ' _ , T > { }
101
+ impl < T : FieldElement > Eq for QueueItem < ' _ , T > { }
95
102
96
- /// Computes a map from each variable to the identity-row-offset pairs it occurs in.
97
- fn compute_occurrences_map < ' a , T : FieldElement > (
103
+ /// Computes a map from each variable to the queue items it occurs in.
104
+ fn compute_occurrences_map < ' b , ' a : ' b , T : FieldElement > (
98
105
fixed_data : & ' a FixedData < ' a , T > ,
99
- identities : & [ ( & ' a Identity < T > , i32 ) ] ,
106
+ items : & BTreeSet < QueueItem < ' a , T > > ,
100
107
) -> HashMap < Variable , Vec < QueueItem < ' a , T > > > {
101
- let mut references_per_identity = HashMap :: new ( ) ;
102
108
let mut intermediate_cache = HashMap :: new ( ) ;
103
- for id in identities. iter ( ) . map ( |( id, _) | * id) . unique_by ( |id| id. id ( ) ) {
109
+
110
+ // Compute references only once per identity.
111
+ let mut references_per_identity = HashMap :: new ( ) ;
112
+ for id in items
113
+ . iter ( )
114
+ . filter_map ( |item| match item {
115
+ QueueItem :: Identity ( id, _) => Some ( id) ,
116
+ _ => None ,
117
+ } )
118
+ . unique_by ( |id| id. id ( ) )
119
+ {
104
120
references_per_identity. insert (
105
- id,
121
+ id. id ( ) ,
106
122
references_in_identity ( id, fixed_data, & mut intermediate_cache) ,
107
123
) ;
108
124
}
109
- identities
125
+
126
+ items
110
127
. iter ( )
111
- . flat_map ( |( id, row) | {
112
- references_per_identity[ id] . iter ( ) . map ( move |reference| {
113
- let name = fixed_data. column_name ( & reference. poly_id ) . to_string ( ) ;
114
- let fat_ref = AlgebraicReference {
115
- name,
116
- poly_id : reference. poly_id ,
117
- next : reference. next ,
118
- } ;
119
- let var = Variable :: from_reference ( & fat_ref, * row) ;
120
- ( var, QueueItem ( * id, * row) )
121
- } )
128
+ . flat_map ( |item| {
129
+ let variables = match item {
130
+ QueueItem :: Identity ( id, row) => {
131
+ references_in_identity ( id, fixed_data, & mut intermediate_cache)
132
+ . into_iter ( )
133
+ . map ( |r| {
134
+ let name = fixed_data. column_name ( & r. poly_id ) . to_string ( ) ;
135
+ Variable :: from_reference ( & r. with_name ( name) , * row)
136
+ } )
137
+ . collect_vec ( )
138
+ }
139
+ QueueItem :: Assignment ( a) => {
140
+ variables_in_assignment ( a, fixed_data, & mut intermediate_cache)
141
+ }
142
+ } ;
143
+ variables. into_iter ( ) . map ( move |v| ( v, item. clone ( ) ) )
122
144
} )
123
145
. into_group_map ( )
124
146
}
@@ -184,3 +206,22 @@ fn references_in_expression<'a, T: FieldElement>(
184
206
)
185
207
. unique ( )
186
208
}
209
+
210
+ /// Returns a vector of all variables that occur in the assignment.
211
+ fn variables_in_assignment < ' a , T : FieldElement > (
212
+ assignment : & Assignment < ' a , T > ,
213
+ fixed_data : & ' a FixedData < ' a , T > ,
214
+ intermediate_cache : & mut HashMap < AlgebraicReferenceThin , Vec < AlgebraicReferenceThin > > ,
215
+ ) -> Vec < Variable > {
216
+ let rhs_var = match & assignment. rhs {
217
+ VariableOrValue :: Variable ( v) => Some ( v. clone ( ) ) ,
218
+ VariableOrValue :: Value ( _) => None ,
219
+ } ;
220
+ references_in_expression ( assignment. lhs , fixed_data, intermediate_cache)
221
+ . map ( |r| {
222
+ let name = fixed_data. column_name ( & r. poly_id ) . to_string ( ) ;
223
+ Variable :: from_reference ( & r. with_name ( name) , assignment. row_offset )
224
+ } )
225
+ . chain ( rhs_var)
226
+ . collect ( )
227
+ }
0 commit comments