Skip to content

Commit a5028dd

Browse files
committed
Fold key into qtree
1 parent 6ad96fe commit a5028dd

File tree

2 files changed

+37
-34
lines changed

2 files changed

+37
-34
lines changed

src/cache.rs

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ impl<C: 'static, Out: 'static> Cache<C, Out> {
159159
/// The internal data for a cache.
160160
pub struct CacheData<C, Out> {
161161
/// Maps from hashes to memoized results.
162-
entries: HashMap<u128, QuestionTree<C, u128, (Out, Vec<C>)>>,
162+
tree: QuestionTree<C, u128, (Out, Vec<C>)>,
163163
}
164164

165165
impl<C: PartialEq, Out: 'static> CacheData<C, Out> {
@@ -169,7 +169,7 @@ impl<C: PartialEq, Out: 'static> CacheData<C, Out> {
169169
In: Input<Call = C>,
170170
C: Clone + Hash,
171171
{
172-
self.entries.get(&key)?.get(|c| input.call(c.clone()))
172+
self.tree.get(key, |c| input.call(c.clone()))
173173
}
174174

175175
/// Insert an entry into the cache.
@@ -184,15 +184,13 @@ impl<C: PartialEq, Out: 'static> CacheData<C, Out> {
184184
In: Input<Call = C>,
185185
C: Clone + Hash,
186186
{
187-
self.entries
188-
.entry(key)
189-
.or_default()
190-
.insert(recording.immutable, (output, recording.mutable))
187+
self.tree
188+
.insert(key, recording.immutable, (output, recording.mutable))
191189
}
192190
}
193191

194192
impl<C, Out> Default for CacheData<C, Out> {
195193
fn default() -> Self {
196-
Self { entries: HashMap::new() }
194+
Self { tree: QuestionTree::new() }
197195
}
198196
}

src/qtree.rs

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ pub struct QuestionTree<Q, A, T> {
8383
questions: Slab<Q>,
8484
results: Slab<T>,
8585
links: HashMap<(usize, A), State>,
86-
start: Option<State>,
86+
start: HashMap<u128, State>,
8787
}
8888

8989
impl<Q, A, T> QuestionTree<Q, A, T> {
@@ -92,7 +92,7 @@ impl<Q, A, T> QuestionTree<Q, A, T> {
9292
questions: Slab::new(),
9393
results: Slab::new(),
9494
links: HashMap::new(),
95-
start: None,
95+
start: HashMap::new(),
9696
}
9797
}
9898
}
@@ -102,8 +102,8 @@ where
102102
Q: Hash + Clone,
103103
A: Hash + Eq + Clone,
104104
{
105-
pub fn get(&self, mut oracle: impl FnMut(&Q) -> A) -> Option<&T> {
106-
let mut state = self.start?;
105+
pub fn get(&self, key: u128, mut oracle: impl FnMut(&Q) -> A) -> Option<&T> {
106+
let mut state = *self.start.get(&key)?;
107107
loop {
108108
match state.kind() {
109109
StateKind::Result(r) => return Some(self.results.get(r).unwrap()),
@@ -118,18 +118,22 @@ where
118118

119119
pub fn insert(
120120
&mut self,
121+
key: u128,
121122
mut sequence: LookaheadSequence<Q, A>,
122123
value: T,
123124
) -> Result<(), InsertError> {
124-
let mut state = self.start;
125+
let mut state = self.start.get(&key).copied();
125126
let mut predecessor = None;
126127

127128
loop {
128129
let pair = if state.is_none() || predecessor.is_some() {
129130
let Some((q, a)) = sequence.next() else { break };
130131
let qi = self.questions.alloc(q);
131132
let new = State::question(qi);
132-
self.link(predecessor.take(), new);
133+
if state.is_none() {
134+
self.start.insert(key, new);
135+
}
136+
self.link(&state, predecessor.take(), new);
133137
state = Some(new);
134138
(qi, a)
135139
} else {
@@ -150,21 +154,21 @@ where
150154
}
151155
}
152156

153-
if predecessor.is_none() && self.start.is_some() {
157+
if predecessor.is_none() && state.is_some() {
154158
return Err(InsertError::AlreadyExists);
155159
}
156160

157161
let ri = self.results.alloc(value);
158-
self.link(predecessor, State::result(ri));
162+
let target = State::result(ri);
163+
if state.is_none() {
164+
self.start.insert(key, target);
165+
}
166+
self.link(&state, predecessor, target);
159167

160168
Ok(())
161169
}
162170

163-
fn link(&mut self, from: Option<(usize, A)>, target: State) {
164-
if self.start.is_none() {
165-
self.start = Some(target);
166-
}
167-
171+
fn link(&mut self, state: &Option<State>, from: Option<(usize, A)>, target: State) {
168172
if let Some(pair) = from {
169173
self.links.insert(pair, target);
170174
}
@@ -263,36 +267,37 @@ mod tests {
263267
#[test]
264268
fn test_question_tree() {
265269
let mut tree = QuestionTree::<char, u128, &'static str>::new();
266-
tree.insert(s([('a', 10), ('b', 15)]), "first").unwrap();
267-
tree.insert(s([('a', 10), ('b', 20)]), "second").unwrap();
268-
tree.insert(s([('a', 15), ('c', 15)]), "third").unwrap();
270+
tree.insert(0, s([('a', 10), ('b', 15)]), "first").unwrap();
271+
tree.insert(0, s([('a', 10), ('b', 20)]), "second").unwrap();
272+
tree.insert(0, s([('a', 15), ('c', 15)]), "third").unwrap();
269273
assert_eq!(
270-
tree.get(|&c| match c {
274+
tree.get(0, |&c| match c {
271275
'a' => 10,
272276
'b' => 15,
273277
_ => 20,
274278
}),
275279
Some(&"first")
276280
);
277281
assert_eq!(
278-
tree.get(|&c| match c {
282+
tree.get(0, |&c| match c {
279283
'a' => 10,
280284
_ => 20,
281285
}),
282286
Some(&"second")
283287
);
284-
assert_eq!(tree.get(|_| 15), Some(&"third"));
285-
assert_eq!(tree.get(|_| 10), None);
288+
assert_eq!(tree.get(0, |_| 15), Some(&"third"));
289+
assert_eq!(tree.get(0, |_| 10), None);
286290
}
287291

288292
#[test]
289293
fn test_question_tree_pull_forward() {
290294
let mut tree = QuestionTree::<char, u128, &'static str>::new();
291-
tree.insert(s([('a', 10), ('b', 15)]), "first").unwrap();
292-
tree.insert(s([('a', 10), ('c', 15), ('b', 20)]), "second").unwrap();
293-
tree.insert(s([('a', 15), ('b', 30), ('c', 15)]), "third").unwrap();
295+
tree.insert(0, s([('a', 10), ('b', 15)]), "first").unwrap();
296+
tree.insert(0, s([('a', 10), ('c', 15), ('b', 20)]), "second")
297+
.unwrap();
298+
tree.insert(0, s([('a', 15), ('b', 30), ('c', 15)]), "third").unwrap();
294299
assert_eq!(
295-
tree.get(|&c| match c {
300+
tree.get(0, |&c| match c {
296301
'a' => 10,
297302
'b' => 20,
298303
'c' => 15,
@@ -301,7 +306,7 @@ mod tests {
301306
Some(&"second")
302307
);
303308
assert_eq!(
304-
tree.get(|&c| match c {
309+
tree.get(0, |&c| match c {
305310
'a' => 15,
306311
'b' => 30,
307312
'c' => 15,
@@ -327,15 +332,15 @@ mod tests {
327332
let mut kept = Vec::new();
328333
for case in cases.iter() {
329334
let &(ref numbers, value) = case;
330-
match tree.insert(s(sequence(numbers)), value) {
335+
match tree.insert(0, s(sequence(numbers)), value) {
331336
Ok(()) => kept.push(case),
332337
Err(InsertError::AlreadyExists) => {}
333338
Err(InsertError::WrongQuestion) => {} // Err(error) => panic!("{error:?}"),
334339
}
335340
}
336341
for (numbers, value) in kept {
337342
let map: HashMap<u64, u16> = sequence(numbers).collect();
338-
assert_eq!(tree.get(|s| map[s]), Some(value));
343+
assert_eq!(tree.get(0, |s| map[s]), Some(value));
339344
}
340345
}
341346

0 commit comments

Comments
 (0)