Skip to content

Commit f5a885d

Browse files
committed
Further optimize DFS/BFS algorithms
1 parent ca7da87 commit f5a885d

File tree

2 files changed

+166
-68
lines changed

2 files changed

+166
-68
lines changed

crates/algos/src/bfs.rs

Lines changed: 82 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -6,61 +6,85 @@ use crate::prelude::*;
66

77
pub fn bfs_directed<NI, G>(
88
graph: &G,
9-
node_id: NI,
9+
node_ids: impl IntoIterator<Item = NI>,
1010
direction: Direction,
1111
) -> DirectedBreadthFirst<'_, G, NI>
1212
where
1313
NI: Idx + Hash,
1414
G: Graph<NI> + DirectedDegrees<NI> + DirectedNeighbors<NI> + Sync,
1515
{
16-
DirectedBreadthFirst::new(graph, node_id, direction)
16+
DirectedBreadthFirst::new(graph, node_ids, direction)
1717
}
1818

1919
pub struct DirectedBreadthFirst<'a, G, NI> {
2020
graph: &'a G,
21+
seen: BitVec<usize>,
2122
visited: BitVec<usize>,
2223
queue: VecDeque<NI>,
2324
direction: Direction,
2425
}
2526

2627
impl<'a, G, NI> DirectedBreadthFirst<'a, G, NI>
2728
where
28-
NI: Idx + Hash,
29+
NI: Idx + Hash + std::fmt::Debug,
2930
G: Graph<NI> + DirectedNeighbors<NI> + Sync,
3031
{
31-
pub fn new(graph: &'a G, node_id: NI, direction: Direction) -> Self {
32+
pub fn new(graph: &'a G, node_ids: impl IntoIterator<Item = NI>, direction: Direction) -> Self {
33+
let bitvec = BitVec::repeat(false, graph.node_count().index());
34+
let visited = bitvec.clone();
35+
36+
let mut seen = bitvec;
37+
let mut queue = VecDeque::new();
38+
Self::enqueue_into(&mut seen, &mut queue, node_ids);
39+
3240
Self {
3341
graph,
34-
visited: BitVec::repeat(false, graph.node_count().index()),
35-
queue: VecDeque::from_iter([node_id]),
42+
seen,
43+
visited,
44+
queue,
3645
direction,
3746
}
3847
}
3948

4049
fn dequeue(&mut self) -> Option<NI> {
4150
loop {
4251
let node_id = self.queue.pop_front()?;
43-
4452
if !self.visited.replace(node_id.index(), true) {
4553
return Some(node_id);
4654
}
4755
}
4856
}
4957

50-
fn enqueue_out_neighbors(&mut self, node_id: NI) {
51-
let neighbors = self
58+
fn enqueue_into(
59+
seen: &mut BitVec<usize>,
60+
queue: &mut VecDeque<NI>,
61+
node_ids: impl IntoIterator<Item = NI>,
62+
) {
63+
for node_id in node_ids {
64+
if !seen.replace(node_id.index(), true) {
65+
queue.push_back(node_id);
66+
}
67+
}
68+
}
69+
70+
fn enqueue_out_neighbors_of(&mut self, node_id: NI) {
71+
let node_ids = self
5272
.graph
5373
.out_neighbors(node_id)
54-
.filter(|&node_id| !self.visited[node_id.index()]);
55-
self.queue.extend(neighbors);
74+
.copied()
75+
.filter(|node_id| !self.visited[node_id.index()]);
76+
77+
Self::enqueue_into(&mut self.seen, &mut self.queue, node_ids);
5678
}
5779

58-
fn enqueue_in_neighbors(&mut self, node_id: NI) {
59-
let neighbors = self
80+
fn enqueue_in_neighbors_of(&mut self, node_id: NI) {
81+
let node_ids = self
6082
.graph
6183
.in_neighbors(node_id)
62-
.filter(|&node_id| !self.visited[node_id.index()]);
63-
self.queue.extend(neighbors);
84+
.copied()
85+
.filter(|node_id| !self.visited[node_id.index()]);
86+
87+
Self::enqueue_into(&mut self.seen, &mut self.queue, node_ids);
6488
}
6589
}
6690

@@ -75,28 +99,32 @@ where
7599
let node_id = self.dequeue()?;
76100

77101
match self.direction {
78-
Direction::Outgoing => self.enqueue_out_neighbors(node_id),
79-
Direction::Incoming => self.enqueue_in_neighbors(node_id),
102+
Direction::Outgoing => self.enqueue_out_neighbors_of(node_id),
103+
Direction::Incoming => self.enqueue_in_neighbors_of(node_id),
80104
Direction::Undirected => {
81-
self.enqueue_out_neighbors(node_id);
82-
self.enqueue_in_neighbors(node_id);
105+
self.enqueue_out_neighbors_of(node_id);
106+
self.enqueue_in_neighbors_of(node_id);
83107
}
84108
}
85109

86110
Some(node_id)
87111
}
88112
}
89113

90-
pub fn bfs_undirected<NI, G>(graph: &G, node_id: NI) -> UndirectedBreadthFirst<'_, G, NI>
114+
pub fn bfs_undirected<NI, G>(
115+
graph: &G,
116+
node_ids: impl IntoIterator<Item = NI>,
117+
) -> UndirectedBreadthFirst<'_, G, NI>
91118
where
92119
NI: Idx + Hash,
93120
G: Graph<NI> + UndirectedDegrees<NI> + UndirectedNeighbors<NI> + Sync,
94121
{
95-
UndirectedBreadthFirst::new(graph, node_id)
122+
UndirectedBreadthFirst::new(graph, node_ids)
96123
}
97124

98125
pub struct UndirectedBreadthFirst<'a, G, NI> {
99126
graph: &'a G,
127+
seen: BitVec<usize>,
100128
visited: BitVec<usize>,
101129
queue: VecDeque<NI>,
102130
}
@@ -106,11 +134,19 @@ where
106134
NI: Idx + Hash + std::fmt::Debug,
107135
G: Graph<NI> + UndirectedNeighbors<NI> + Sync,
108136
{
109-
pub fn new(graph: &'a G, node_id: NI) -> Self {
137+
pub fn new(graph: &'a G, node_ids: impl IntoIterator<Item = NI>) -> Self {
138+
let bitvec = BitVec::repeat(false, graph.node_count().index());
139+
let visited = bitvec.clone();
140+
141+
let mut seen = bitvec;
142+
let mut queue = VecDeque::new();
143+
Self::enqueue_into(&mut seen, &mut queue, node_ids);
144+
110145
Self {
111146
graph,
112-
visited: BitVec::repeat(false, graph.node_count().index()),
113-
queue: VecDeque::from_iter([node_id]),
147+
seen,
148+
visited,
149+
queue,
114150
}
115151
}
116152

@@ -124,15 +160,26 @@ where
124160
}
125161
}
126162

127-
fn enqueue_neighbors(&mut self, node_id: NI) {
128-
let neighbors = self
163+
fn enqueue_into(
164+
seen: &mut BitVec<usize>,
165+
queue: &mut VecDeque<NI>,
166+
node_ids: impl IntoIterator<Item = NI>,
167+
) {
168+
for node_id in node_ids {
169+
if !seen.replace(node_id.index(), true) {
170+
queue.push_back(node_id);
171+
}
172+
}
173+
}
174+
175+
fn enqueue_neighbors_of(&mut self, node_id: NI) {
176+
let node_ids = self
129177
.graph
130178
.neighbors(node_id)
179+
.copied()
131180
.filter(|&node_id| !self.visited[node_id.index()]);
132181

133-
let neighbors: Vec<_> = neighbors.collect();
134-
135-
self.queue.extend(neighbors);
182+
Self::enqueue_into(&mut self.seen, &mut self.queue, node_ids);
136183
}
137184
}
138185

@@ -146,18 +193,18 @@ where
146193
fn next(&mut self) -> Option<Self::Item> {
147194
let node_id = self.dequeue()?;
148195

149-
self.enqueue_neighbors(node_id);
196+
self.enqueue_neighbors_of(node_id);
150197

151198
Some(node_id)
152199
}
153200
}
154201

155202
#[cfg(test)]
156203
mod tests {
157-
use super::*;
158-
use crate::prelude::{CsrLayout, GraphBuilder};
204+
use graph::prelude::{CsrLayout, GraphBuilder};
159205

160206
use super::*;
207+
161208
mod directed {
162209
use super::*;
163210

@@ -168,7 +215,7 @@ mod tests {
168215
.edges(vec![(0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (2, 1), (3, 1)])
169216
.build();
170217

171-
let actual: Vec<usize> = bfs_directed(&graph, 0, Direction::Outgoing).collect();
218+
let actual: Vec<usize> = bfs_directed(&graph, [0], Direction::Outgoing).collect();
172219
let expected: Vec<usize> = vec![0, 1, 2, 3];
173220

174221
assert_eq!(actual, expected);
@@ -181,7 +228,7 @@ mod tests {
181228
.edges(vec![(0, 1), (0, 2), (1, 2), (1, 3), (2, 1), (2, 1), (3, 1)])
182229
.build();
183230

184-
let actual: Vec<usize> = bfs_directed(&graph, 0, Direction::Outgoing).collect();
231+
let actual: Vec<usize> = bfs_directed(&graph, [0], Direction::Outgoing).collect();
185232
let expected: Vec<usize> = vec![0, 1, 2, 3];
186233

187234
assert_eq!(actual, expected);
@@ -195,7 +242,7 @@ mod tests {
195242
.edges(vec![(0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (2, 1), (3, 1)])
196243
.build();
197244

198-
let actual: Vec<usize> = bfs_undirected(&graph, 0).collect();
245+
let actual: Vec<usize> = bfs_undirected(&graph, [0]).collect();
199246
let expected: Vec<usize> = vec![0, 1, 2, 3];
200247

201248
assert_eq!(actual, expected);

0 commit comments

Comments
 (0)