Skip to content

Commit 499bce5

Browse files
committed
Prescient
1 parent 876b6e6 commit 499bce5

File tree

6 files changed

+101
-27
lines changed

6 files changed

+101
-27
lines changed

macros/src/track.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ fn create_variants(methods: &[Method]) -> TokenStream {
198198
};
199199

200200
quote! {
201-
#[derive(Clone, PartialEq, Hash)]
201+
#[derive(Debug, Clone, PartialEq, Hash)]
202202
pub struct __ComemoCall(__ComemoVariant);
203203

204204
impl ::comemo::internal::Call for __ComemoCall {
@@ -207,7 +207,7 @@ fn create_variants(methods: &[Method]) -> TokenStream {
207207
}
208208
}
209209

210-
#[derive(Clone, PartialEq, Hash)]
210+
#[derive(Debug, Clone, PartialEq, Hash)]
211211
#[allow(non_camel_case_types)]
212212
enum __ComemoVariant {
213213
#(#variants,)*
@@ -362,16 +362,18 @@ fn create_wrapper(method: &Method, tracked_mut: bool) -> TokenStream {
362362
#[track_caller]
363363
#[inline]
364364
#vis #sig {
365-
let __comemo_variant = __ComemoVariant::#name(#(#args.to_owned()),*);
366365
let (__comemo_value, __comemo_sink) = ::comemo::internal::#to_parts;
367-
let output = __comemo_value.#name(#(#args,)*);
368366
if let Some(__comemo_sink) = __comemo_sink {
367+
let __comemo_variant = __ComemoVariant::#name(#(#args.to_owned()),*);
368+
let output = __comemo_value.#name(#(#args,)*);
369369
__comemo_sink(
370370
__ComemoCall(__comemo_variant),
371371
::comemo::internal::hash(&output),
372372
);
373+
output
374+
} else {
375+
__comemo_value.#name(#(#args,)*)
373376
}
374-
output
375377
}
376378
}
377379
}

src/cache.rs

Lines changed: 89 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
use std::cell::UnsafeCell;
12
use std::collections::HashMap;
23
use std::hash::Hash;
4+
use std::marker::PhantomData;
5+
use std::sync::atomic::AtomicUsize;
36

47
use bumpalo::Bump;
58
use once_cell::sync::Lazy;
@@ -34,6 +37,64 @@ impl<C> Default for Recording<C> {
3437
}
3538
}
3639

40+
pub fn write_prescience(mut sink: impl std::io::Write) {
41+
let locked = PRESCIENCE_WRITE.data.lock();
42+
let slice: &[u32] = &locked;
43+
let buf: &[u8] = unsafe {
44+
std::slice::from_raw_parts(
45+
slice.as_ptr().cast(),
46+
slice.len() * (u32::BITS / u8::BITS) as usize,
47+
)
48+
};
49+
sink.write_all(buf).unwrap();
50+
}
51+
52+
struct PrescienceWrite {
53+
data: Mutex<Vec<u32>>,
54+
}
55+
56+
impl PrescienceWrite {
57+
fn hit(&self, i: usize) {
58+
self.data.lock().push(i as u32);
59+
}
60+
61+
fn miss(&self) {
62+
self.data.lock().push(u32::MAX);
63+
}
64+
}
65+
66+
static PRESCIENCE_WRITE: PrescienceWrite =
67+
PrescienceWrite { data: Mutex::new(Vec::new()) };
68+
69+
pub fn put_prescience(data: &'static [u8]) {
70+
let slice: &'static [u32] =
71+
unsafe { std::slice::from_raw_parts(data.as_ptr().cast(), data.len() / 4) };
72+
unsafe {
73+
*PRESCIENCE_READ.data.get() = slice;
74+
}
75+
}
76+
77+
struct PrescienceRead {
78+
data: UnsafeCell<&'static [u32]>,
79+
i: AtomicUsize,
80+
}
81+
82+
impl PrescienceRead {
83+
fn get(&self) -> Option<u32> {
84+
let i = self.i.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
85+
let slice = unsafe { *self.data.get() };
86+
if slice.is_empty() {
87+
return None;
88+
}
89+
Some(slice[i])
90+
}
91+
}
92+
93+
unsafe impl Sync for PrescienceRead {}
94+
95+
static PRESCIENCE_READ: PrescienceRead =
96+
PrescienceRead { data: UnsafeCell::new(&[]), i: AtomicUsize::new(0) };
97+
3798
/// Execute a function or use a cached result for it.
3899
pub fn memoized<'c, In, Out, F>(
39100
mut input: In,
@@ -48,17 +109,13 @@ where
48109
Out: Clone + 'static,
49110
F: FnOnce(In::Tracked<'c>) -> Out,
50111
{
51-
// Early bypass if memoization is disabled.
52-
// Hopefully the compiler will optimize this away, if the condition is constant.
53-
if !enabled {
54-
// Execute the function with the new constraints hooked in.
55-
let output = func(input.retrack_noop());
56-
57-
// Ensure that the last call was a miss during testing.
58-
#[cfg(feature = "testing")]
59-
LAST_WAS_HIT.with(|cell| cell.set(false));
60-
61-
return output;
112+
if let Some(i) = PRESCIENCE_READ.get() {
113+
if i != u32::MAX {
114+
return cache.0.read().values[i as usize].clone();
115+
}
116+
let value = func(input.retrack_noop());
117+
cache.0.write().values.push(value.clone());
118+
return value;
62119
}
63120

64121
// Compute the hash of the input's key part.
@@ -69,7 +126,9 @@ where
69126
};
70127

71128
// Check if there is a cached output.
72-
if let Some((value, mutable)) = cache.0.read().lookup::<In>(key, &input) {
129+
if let Some((i, value, mutable)) = cache.0.read().lookup::<In>(key, &input) {
130+
PRESCIENCE_WRITE.hit(i);
131+
73132
#[cfg(feature = "testing")]
74133
LAST_WAS_HIT.with(|cell| cell.set(true));
75134

@@ -81,6 +140,8 @@ where
81140
return value.clone();
82141
}
83142

143+
PRESCIENCE_WRITE.miss();
144+
84145
// Execute the function with the new constraints hooked in.
85146
let sink = |call: In::Call, hash: u128| {
86147
if call.is_mutable() {
@@ -159,17 +220,21 @@ impl<C: 'static, Out: 'static> Cache<C, Out> {
159220
/// The internal data for a cache.
160221
pub struct CacheData<C, Out> {
161222
/// Maps from hashes to memoized results.
162-
entries: HashMap<u128, QuestionTree<C, u128, (Out, Vec<C>)>>,
223+
entries: HashMap<u128, QuestionTree<C, u128, (usize, Vec<C>)>>,
224+
values: Vec<Out>,
163225
}
164226

165227
impl<C: PartialEq, Out: 'static> CacheData<C, Out> {
166228
/// Look for a matching entry in the cache.
167-
fn lookup<In>(&self, key: u128, input: &In) -> Option<&(Out, Vec<C>)>
229+
fn lookup<In>(&self, key: u128, input: &In) -> Option<(usize, &Out, &Vec<C>)>
168230
where
169231
In: Input<Call = C>,
170232
C: Clone + Hash,
171233
{
172-
self.entries.get(&key)?.get(|c| input.call(c.clone()))
234+
self.entries
235+
.get(&key)?
236+
.get(|c| input.call(c.clone()))
237+
.map(|(i, c)| (*i, &self.values[*i], c))
173238
}
174239

175240
/// Insert an entry into the cache.
@@ -184,15 +249,21 @@ impl<C: PartialEq, Out: 'static> CacheData<C, Out> {
184249
In: Input<Call = C>,
185250
C: Clone + Hash,
186251
{
187-
self.entries
252+
let i = self.values.len();
253+
let res = self
254+
.entries
188255
.entry(key)
189256
.or_default()
190-
.insert(recording.immutable, (output, recording.mutable))
257+
.insert(recording.immutable, (i, recording.mutable));
258+
if res.is_ok() {
259+
self.values.push(output);
260+
}
261+
res
191262
}
192263
}
193264

194265
impl<C, Out> Default for CacheData<C, Out> {
195266
fn default() -> Self {
196-
Self { entries: HashMap::new() }
267+
Self { entries: HashMap::new(), values: Vec::new() }
197268
}
198269
}

src/constraint.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
use std::fmt::Debug;
12
use std::hash::Hash;
23

34
use siphasher::sip128::{Hasher128, SipHasher13};
45

56
/// A call to a tracked function.
6-
pub trait Call: Hash + PartialEq + Clone + Send + Sync {
7+
pub trait Call: Debug + Hash + PartialEq + Clone + Send + Sync {
78
/// Whether the call is mutable.
89
fn is_mutable(&self) -> bool;
910
}

src/input.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ macro_rules! args_input {
268268
}
269269
}
270270

271-
#[derive(PartialEq, Clone, Hash)]
271+
#[derive(Debug, PartialEq, Clone, Hash)]
272272
pub enum ArgsCall<$($param),*> {
273273
$($param($param),)*
274274
}

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ mod prehashed;
9292
mod qtree;
9393
mod track;
9494

95-
pub use crate::cache::evict;
95+
pub use crate::cache::{evict, put_prescience, write_prescience};
9696
pub use crate::prehashed::Prehashed;
9797
pub use crate::track::{Track, Tracked, TrackedMut, Validate};
9898
pub use comemo_macros::{memoize, track};

tests/tests.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,7 @@ impl Emitter {
529529
}
530530

531531
/// A non-copy struct that is passed by value to a tracked method.
532-
#[derive(Clone, PartialEq, Hash)]
532+
#[derive(Debug, Clone, PartialEq, Hash)]
533533
struct Heavy(String);
534534

535535
/// Test a tracked method that is impure.

0 commit comments

Comments
 (0)