Skip to content

Commit 935b785

Browse files
authored
Merge pull request #51 from oflatt/oflatt-nits
Fix nits and add toolchain file
2 parents cf9a2c3 + df97830 commit 935b785

7 files changed

Lines changed: 89 additions & 69 deletions

File tree

rust-toolchain.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[toolchain]
2+
channel = "1.87.0"

src/extract/faster_bottom_up.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ Notably, insert/pop operations have O(1) expected amortized runtime complexity.
6969
Thanks @Bastacyclop for the implementation!
7070
*/
7171
#[derive(Clone)]
72-
#[cfg_attr(feature = "serde-1", derive(Serialize, Deserialize))]
7372
pub(crate) struct UniqueQueue<T>
7473
where
7574
T: Eq + std::hash::Hash + Clone,

src/extract/faster_greedy_dag.rs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ impl FasterGreedyDagExtractor {
3636
let mut childrens_classes = node
3737
.children
3838
.iter()
39-
.map(|c| egraph.nid_to_cid(&c).clone())
39+
.map(|c| egraph.nid_to_cid(c).clone())
4040
.collect::<Vec<ClassId>>();
4141
childrens_classes.sort();
4242
childrens_classes.dedup();
@@ -59,19 +59,19 @@ impl FasterGreedyDagExtractor {
5959
.iter()
6060
.max_by_key(|s| costs.get(s).unwrap().costs.len())
6161
.unwrap();
62-
let mut result = costs.get(&id_of_biggest).unwrap().costs.clone();
62+
let mut result = costs.get(id_of_biggest).unwrap().costs.clone();
6363
for child_cid in &childrens_classes {
6464
if child_cid == id_of_biggest {
6565
continue;
6666
}
6767

6868
let next_cost = &costs.get(child_cid).unwrap().costs;
6969
for (key, value) in next_cost.iter() {
70-
result.insert(key.clone(), value.clone());
70+
result.insert(key.clone(), *value);
7171
}
7272
}
7373

74-
let contains = result.contains_key(&cid);
74+
let contains = result.contains_key(cid);
7575
result.insert(cid.clone(), node.cost);
7676

7777
let result_cost = if contains {
@@ -80,11 +80,11 @@ impl FasterGreedyDagExtractor {
8080
result.values().sum()
8181
};
8282

83-
return CostSet {
83+
CostSet {
8484
costs: result,
8585
total: result_cost,
8686
choice: node_id.clone(),
87-
};
87+
}
8888
}
8989
}
9090

@@ -151,7 +151,6 @@ Notably, insert/pop operations have O(1) expected amortized runtime complexity.
151151
Thanks @Bastacyclop for the implementation!
152152
*/
153153
#[derive(Clone)]
154-
#[cfg_attr(feature = "serde-1", derive(Serialize, Deserialize))]
155154
pub(crate) struct UniqueQueue<T>
156155
where
157156
T: Eq + std::hash::Hash + Clone,

src/extract/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ impl ExtractionResult {
8282
}
8383

8484
// No cycles
85-
assert!(self.find_cycles(&egraph, &egraph.root_eclasses).is_empty());
85+
assert!(self.find_cycles(egraph, &egraph.root_eclasses).is_empty());
8686

8787
// Nodes should match the class they are selected into.
8888
for (cid, nid) in &self.choices {

src/extract/prio_queue.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ impl Extractor for PrioQueueExtractor {
5454
result.choose(class_id.clone(), node_id.clone());
5555
costs.insert(class_id.clone(), cost);
5656
for p in parents[class_id].iter() {
57-
if costs.contains_key(&n2c(p)) {
57+
if costs.contains_key(n2c(p)) {
5858
continue;
5959
}
6060

@@ -79,15 +79,16 @@ mod prio {
7979
#[derive(PartialEq, Eq, Debug)]
8080
struct WithOrdRev<T: Eq, U: Ord>(pub T, pub U);
8181

82-
impl<T: Eq, U: Ord> PartialOrd for WithOrdRev<T, U> {
83-
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
82+
impl<T: Eq, U: Ord> Ord for WithOrdRev<T, U> {
83+
fn cmp(&self, other: &Self) -> Ordering {
8484
// It's the other way around, because we want a min-heap!
85-
other.1.partial_cmp(&self.1)
85+
other.1.cmp(&self.1)
8686
}
8787
}
88-
impl<T: Eq, U: Ord> Ord for WithOrdRev<T, U> {
89-
fn cmp(&self, other: &Self) -> Ordering {
90-
self.partial_cmp(&other).unwrap()
88+
89+
impl<T: Eq, U: Ord> PartialOrd for WithOrdRev<T, U> {
90+
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
91+
Some(self.cmp(other))
9192
}
9293
}
9394

src/main.rs

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,19 @@ use std::io::Write;
1313
use std::path::PathBuf;
1414

1515
pub type Cost = NotNan<f64>;
16-
pub const INFINITY: Cost = unsafe { NotNan::new_unchecked(std::f64::INFINITY) };
16+
pub const INFINITY: Cost = unsafe { NotNan::new_unchecked(f64::INFINITY) };
1717

1818
#[derive(PartialEq, Eq)]
1919
enum Optimal {
2020
Tree,
21-
DAG,
21+
#[cfg(feature = "ilp-cbc")]
22+
Dag,
2223
Neither,
2324
}
2425

2526
struct ExtractorDetail {
2627
extractor: Box<dyn Extractor>,
28+
#[cfg_attr(not(test), allow(dead_code))]
2729
optimal: Optimal,
2830
use_for_bench: bool,
2931
}
@@ -75,7 +77,7 @@ fn extractors() -> IndexMap<&'static str, ExtractorDetail> {
7577
"ilp-cbc-timeout",
7678
ExtractorDetail {
7779
extractor: extract::ilp_cbc::CbcExtractorWithTimeout::<10>.boxed(),
78-
optimal: Optimal::DAG,
80+
optimal: Optimal::Dag,
7981
use_for_bench: true,
8082
},
8183
),
@@ -84,7 +86,7 @@ fn extractors() -> IndexMap<&'static str, ExtractorDetail> {
8486
"ilp-cbc",
8587
ExtractorDetail {
8688
extractor: extract::ilp_cbc::CbcExtractor.boxed(),
87-
optimal: Optimal::DAG,
89+
optimal: Optimal::Dag,
8890
use_for_bench: false, // takes >10 hours sometimes
8991
},
9092
),
@@ -93,7 +95,7 @@ fn extractors() -> IndexMap<&'static str, ExtractorDetail> {
9395
"faster-ilp-cbc-timeout",
9496
ExtractorDetail {
9597
extractor: extract::faster_ilp_cbc::FasterCbcExtractorWithTimeout::<10>.boxed(),
96-
optimal: Optimal::DAG,
98+
optimal: Optimal::Dag,
9799
use_for_bench: true,
98100
},
99101
),
@@ -102,14 +104,14 @@ fn extractors() -> IndexMap<&'static str, ExtractorDetail> {
102104
"faster-ilp-cbc",
103105
ExtractorDetail {
104106
extractor: extract::faster_ilp_cbc::FasterCbcExtractor.boxed(),
105-
optimal: Optimal::DAG,
107+
optimal: Optimal::Dag,
106108
use_for_bench: true,
107109
},
108110
),
109111
]
110112
.into_iter()
111113
.collect();
112-
return extractors;
114+
extractors
113115
}
114116

115117
fn main() {

src/test.rs

Lines changed: 63 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@ use rand::Rng;
1010
pub const ELABORATE_TESTING: bool = false;
1111

1212
pub fn test_save_path(name: &str) -> String {
13-
return if ELABORATE_TESTING {
13+
if ELABORATE_TESTING {
1414
format!("/dev/shm/{}_egraph.json", name)
1515
} else {
1616
"".to_string()
17-
};
17+
}
1818
}
1919

2020
// generates a float between 0 and 1
@@ -38,12 +38,12 @@ pub fn generate_random_egraph() -> EGraph {
3838
let get_semi_random_cost = |nodes: &Vec<Node>| -> Cost {
3939
let mut rng = rand::thread_rng();
4040

41-
if nodes.len() > 0 && rng.gen_bool(0.1) {
42-
return nodes[rng.gen_range(0..nodes.len())].cost;
41+
if !nodes.is_empty() && rng.gen_bool(0.1) {
42+
nodes[rng.gen_range(0..nodes.len())].cost
4343
} else if rng.gen_bool(0.05) {
44-
return Cost::default();
44+
Cost::default()
4545
} else {
46-
return generate_random_not_nan() * 100.0;
46+
generate_random_not_nan() * 100.0
4747
}
4848
};
4949

@@ -56,7 +56,7 @@ pub fn generate_random_egraph() -> EGraph {
5656

5757
nodes.push(Node {
5858
op: "operation".to_string(),
59-
children: children,
59+
children,
6060
eclass: eclass.to_string().clone().into(),
6161
cost: get_semi_random_cost(&nodes),
6262
});
@@ -83,8 +83,8 @@ pub fn generate_random_egraph() -> EGraph {
8383

8484
let mut egraph = EGraph::default();
8585

86-
for i in 0..nodes.len() {
87-
egraph.add_node(id2nid(i), nodes[i].clone());
86+
for (i, node) in nodes.iter().enumerate() {
87+
egraph.add_node(id2nid(i), node.clone());
8888
}
8989

9090
// Set roots
@@ -106,42 +106,65 @@ pub fn generate_random_egraph() -> EGraph {
106106
* Checks that the extractions are valid.
107107
*/
108108

109+
#[cfg(feature = "ilp-cbc")]
110+
fn check_dag_optimal(
111+
egraph: &EGraph,
112+
optimal_dag: &[Box<dyn Extractor>],
113+
others: &[Box<dyn Extractor>],
114+
optimal_tree_cost: Option<Cost>,
115+
) {
116+
let mut optimal_dag_cost: Option<Cost> = None;
117+
118+
for e in optimal_dag {
119+
let extract = e.extract(egraph, &egraph.root_eclasses);
120+
extract.check(egraph);
121+
let dag_cost = extract.dag_cost(egraph, &egraph.root_eclasses);
122+
let tree_cost = extract.tree_cost(egraph, &egraph.root_eclasses);
123+
if optimal_dag_cost.is_none() {
124+
optimal_dag_cost = Some(dag_cost);
125+
continue;
126+
}
127+
128+
assert!(
129+
(dag_cost.into_inner() - optimal_dag_cost.unwrap().into_inner()).abs()
130+
< EPSILON_ALLOWANCE
131+
);
132+
133+
assert!(
134+
tree_cost.into_inner() + EPSILON_ALLOWANCE > optimal_dag_cost.unwrap().into_inner()
135+
);
136+
}
137+
138+
if let (Some(dag_cost), Some(tree_cost)) = (optimal_dag_cost, optimal_tree_cost) {
139+
assert!(dag_cost < tree_cost + EPSILON_ALLOWANCE);
140+
}
141+
142+
if let Some(optimal_dag_cost) = optimal_dag_cost {
143+
for e in others {
144+
let extract = e.extract(egraph, &egraph.root_eclasses);
145+
let dag_cost = extract.dag_cost(egraph, &egraph.root_eclasses);
146+
// The optimal dag should be <= any extractor's dag cost
147+
assert!(optimal_dag_cost <= dag_cost + EPSILON_ALLOWANCE);
148+
}
149+
}
150+
}
151+
109152
fn check_optimal_results<I: Iterator<Item = EGraph>>(egraphs: I) {
153+
#[cfg(feature = "ilp-cbc")]
110154
let mut optimal_dag: Vec<Box<dyn Extractor>> = Default::default();
111155
let mut optimal_tree: Vec<Box<dyn Extractor>> = Default::default();
112156
let mut others: Vec<Box<dyn Extractor>> = Default::default();
113157

114158
for (_, ed) in extractors().into_iter() {
115159
match ed.optimal {
116-
Optimal::DAG => optimal_dag.push(ed.extractor),
160+
#[cfg(feature = "ilp-cbc")]
161+
Optimal::Dag => optimal_dag.push(ed.extractor),
117162
Optimal::Tree => optimal_tree.push(ed.extractor),
118163
Optimal::Neither => others.push(ed.extractor),
119164
}
120165
}
121166

122167
for egraph in egraphs {
123-
let mut optimal_dag_cost: Option<Cost> = None;
124-
125-
for e in &optimal_dag {
126-
let extract = e.extract(&egraph, &egraph.root_eclasses);
127-
extract.check(&egraph);
128-
let dag_cost = extract.dag_cost(&egraph, &egraph.root_eclasses);
129-
let tree_cost = extract.tree_cost(&egraph, &egraph.root_eclasses);
130-
if optimal_dag_cost.is_none() {
131-
optimal_dag_cost = Some(dag_cost);
132-
continue;
133-
}
134-
135-
assert!(
136-
(dag_cost.into_inner() - optimal_dag_cost.unwrap().into_inner()).abs()
137-
< EPSILON_ALLOWANCE
138-
);
139-
140-
assert!(
141-
tree_cost.into_inner() + EPSILON_ALLOWANCE > optimal_dag_cost.unwrap().into_inner()
142-
);
143-
}
144-
145168
let mut optimal_tree_cost: Option<Cost> = None;
146169

147170
for e in &optimal_tree {
@@ -159,26 +182,19 @@ fn check_optimal_results<I: Iterator<Item = EGraph>>(egraphs: I) {
159182
);
160183
}
161184

162-
if optimal_dag_cost.is_some() && optimal_tree_cost.is_some() {
163-
assert!(optimal_dag_cost.unwrap() < optimal_tree_cost.unwrap() + EPSILON_ALLOWANCE);
164-
}
165-
166185
for e in &others {
167186
let extract = e.extract(&egraph, &egraph.root_eclasses);
168187
extract.check(&egraph);
169188
let tree_cost = extract.tree_cost(&egraph, &egraph.root_eclasses);
170-
let dag_cost = extract.dag_cost(&egraph, &egraph.root_eclasses);
171189

172190
// The optimal tree cost should be <= any extractor's tree cost.
173-
if optimal_tree_cost.is_some() {
174-
assert!(optimal_tree_cost.unwrap() <= tree_cost + EPSILON_ALLOWANCE);
175-
}
176-
177-
if optimal_dag_cost.is_some() {
178-
// The optimal dag should be less <= any extractor's dag cost
179-
assert!(optimal_dag_cost.unwrap() <= dag_cost + EPSILON_ALLOWANCE);
191+
if let Some(optimal_tree_cost) = optimal_tree_cost {
192+
assert!(optimal_tree_cost <= tree_cost + EPSILON_ALLOWANCE);
180193
}
181194
}
195+
196+
#[cfg(feature = "ilp-cbc")]
197+
check_dag_optimal(&egraph, &optimal_dag, &others, optimal_tree_cost);
182198
}
183199
}
184200

@@ -201,6 +217,7 @@ fn run_on_test_egraphs() {
201217

202218
#[test]
203219
#[should_panic]
220+
#[allow(clippy::assertions_on_constants)]
204221
fn check_assert_enabled() {
205222
assert!(false);
206223
}
@@ -210,8 +227,8 @@ macro_rules! create_optimal_check_tests {
210227
$(
211228
#[test]
212229
fn $name() {
213-
let optimal_dag_found = extractors().into_iter().any(|(_, ed)| ed.optimal == Optimal::DAG);
214-
let iterations = if optimal_dag_found { 100 } else { 10000 };
230+
// Fewer iterations when ilp-cbc is enabled since it's slow
231+
let iterations = if cfg!(feature = "ilp-cbc") { 100 } else { 10000 };
215232
let egraphs = (0..iterations).map(|_| generate_random_egraph());
216233
check_optimal_results(egraphs);
217234
}

0 commit comments

Comments
 (0)