Skip to content

Commit 538eaad

Browse files
authored
Merge pull request #682 from Veykril/veykril/push-vntqlkpyzmqm
Fix Disambiguator- and IdentityMap hashing
2 parents 54a1740 + e725854 commit 538eaad

File tree

3 files changed

+230
-12
lines changed

3 files changed

+230
-12
lines changed

src/tracked_struct.rs

+103-12
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ impl Identity {
187187
}
188188

189189
/// Stores the data that (almost) uniquely identifies a tracked struct.
190-
/// This includes the ingredient index of that struct type plus the hash of its id fields.
190+
/// This includes the ingredient index of that struct type plus the hash of its untracked fields.
191191
/// This is mapped to a disambiguator -- a value that starts as 0 but increments each round,
192192
/// allowing for multiple tracked structs with the same hash and ingredient_index
193193
/// created within the query to each have a unique id.
@@ -222,10 +222,7 @@ impl IdentityMap {
222222
pub(crate) fn insert(&mut self, key: Identity, id: Id) -> Option<Id> {
223223
use hashbrown::hash_map::RawEntryMut;
224224

225-
let eq_modulo_hash = |k: &Identity| {
226-
k.ingredient_index == key.ingredient_index && k.disambiguator == key.disambiguator
227-
};
228-
let entry = self.map.raw_entry_mut().from_hash(key.hash, eq_modulo_hash);
225+
let entry = self.map.raw_entry_mut().from_hash(key.hash, |k| *k == key);
229226
match entry {
230227
RawEntryMut::Occupied(mut occupied) => Some(occupied.insert(id)),
231228
RawEntryMut::Vacant(vacant) => {
@@ -236,12 +233,9 @@ impl IdentityMap {
236233
}
237234

238235
pub(crate) fn get(&self, key: &Identity) -> Option<Id> {
239-
let eq_modulo_hash = |k: &Identity| {
240-
k.ingredient_index == key.ingredient_index && k.disambiguator == key.disambiguator
241-
};
242236
self.map
243237
.raw_entry()
244-
.from_hash(key.hash, eq_modulo_hash)
238+
.from_hash(key.hash, |k| *k == *key)
245239
.map(|(_, &v)| v)
246240
}
247241

@@ -318,8 +312,7 @@ impl DisambiguatorMap {
318312
pub(crate) fn disambiguate(&mut self, key: IdentityHash) -> Disambiguator {
319313
use hashbrown::hash_map::RawEntryMut;
320314

321-
let eq_modulo_hash = |k: &IdentityHash| k.ingredient_index == key.ingredient_index;
322-
let entry = self.map.raw_entry_mut().from_hash(key.hash, eq_modulo_hash);
315+
let entry = self.map.raw_entry_mut().from_hash(key.hash, |k| *k == key);
323316
let disambiguator = match entry {
324317
RawEntryMut::Occupied(occupied) => occupied.into_mut(),
325318
RawEntryMut::Vacant(vacant) => {
@@ -388,7 +381,7 @@ where
388381

389382
let identity = Identity {
390383
hash: identity_hash.hash,
391-
ingredient_index: self.ingredient_index,
384+
ingredient_index: identity_hash.ingredient_index,
392385
disambiguator,
393386
};
394387

@@ -845,3 +838,101 @@ where
845838
&self.syncs
846839
}
847840
}
841+
842+
#[cfg(test)]
843+
mod tests {
844+
use super::*;
845+
846+
#[test]
847+
fn disambiguate_map_works() {
848+
let mut d = DisambiguatorMap::default();
849+
// set up all 4 permutations of differing field values
850+
let h1 = IdentityHash {
851+
ingredient_index: IngredientIndex::from(0),
852+
hash: 0,
853+
};
854+
let h2 = IdentityHash {
855+
ingredient_index: IngredientIndex::from(1),
856+
hash: 0,
857+
};
858+
let h3 = IdentityHash {
859+
ingredient_index: IngredientIndex::from(0),
860+
hash: 1,
861+
};
862+
let h4 = IdentityHash {
863+
ingredient_index: IngredientIndex::from(1),
864+
hash: 1,
865+
};
866+
assert_eq!(d.disambiguate(h1), Disambiguator(0));
867+
assert_eq!(d.disambiguate(h1), Disambiguator(1));
868+
assert_eq!(d.disambiguate(h2), Disambiguator(0));
869+
assert_eq!(d.disambiguate(h2), Disambiguator(1));
870+
assert_eq!(d.disambiguate(h3), Disambiguator(0));
871+
assert_eq!(d.disambiguate(h3), Disambiguator(1));
872+
assert_eq!(d.disambiguate(h4), Disambiguator(0));
873+
assert_eq!(d.disambiguate(h4), Disambiguator(1));
874+
}
875+
876+
#[test]
877+
fn identity_map_works() {
878+
let mut d = IdentityMap::default();
879+
// set up all 8 permutations of differing field values
880+
let i1 = Identity {
881+
ingredient_index: IngredientIndex::from(0),
882+
hash: 0,
883+
disambiguator: Disambiguator(0),
884+
};
885+
let i2 = Identity {
886+
ingredient_index: IngredientIndex::from(1),
887+
hash: 0,
888+
disambiguator: Disambiguator(0),
889+
};
890+
let i3 = Identity {
891+
ingredient_index: IngredientIndex::from(0),
892+
hash: 1,
893+
disambiguator: Disambiguator(0),
894+
};
895+
let i4 = Identity {
896+
ingredient_index: IngredientIndex::from(1),
897+
hash: 1,
898+
disambiguator: Disambiguator(0),
899+
};
900+
let i5 = Identity {
901+
ingredient_index: IngredientIndex::from(0),
902+
hash: 0,
903+
disambiguator: Disambiguator(1),
904+
};
905+
let i6 = Identity {
906+
ingredient_index: IngredientIndex::from(1),
907+
hash: 0,
908+
disambiguator: Disambiguator(1),
909+
};
910+
let i7 = Identity {
911+
ingredient_index: IngredientIndex::from(0),
912+
hash: 1,
913+
disambiguator: Disambiguator(1),
914+
};
915+
let i8 = Identity {
916+
ingredient_index: IngredientIndex::from(1),
917+
hash: 1,
918+
disambiguator: Disambiguator(1),
919+
};
920+
assert_eq!(d.insert(i1, Id::from_u32(0)), None);
921+
assert_eq!(d.insert(i2, Id::from_u32(1)), None);
922+
assert_eq!(d.insert(i3, Id::from_u32(2)), None);
923+
assert_eq!(d.insert(i4, Id::from_u32(3)), None);
924+
assert_eq!(d.insert(i5, Id::from_u32(4)), None);
925+
assert_eq!(d.insert(i6, Id::from_u32(5)), None);
926+
assert_eq!(d.insert(i7, Id::from_u32(6)), None);
927+
assert_eq!(d.insert(i8, Id::from_u32(7)), None);
928+
929+
assert_eq!(d.get(&i1), Some(Id::from_u32(0)));
930+
assert_eq!(d.get(&i2), Some(Id::from_u32(1)));
931+
assert_eq!(d.get(&i3), Some(Id::from_u32(2)));
932+
assert_eq!(d.get(&i4), Some(Id::from_u32(3)));
933+
assert_eq!(d.get(&i5), Some(Id::from_u32(4)));
934+
assert_eq!(d.get(&i6), Some(Id::from_u32(5)));
935+
assert_eq!(d.get(&i7), Some(Id::from_u32(6)));
936+
assert_eq!(d.get(&i8), Some(Id::from_u32(7)));
937+
}
938+
}

src/update.rs

+20
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ use std::{
55
sync::Arc,
66
};
77

8+
use rayon::iter::Either;
9+
810
use crate::Revision;
911

1012
/// This is used by the macro generated code.
@@ -348,6 +350,24 @@ where
348350
}
349351
}
350352

353+
unsafe impl<L, R> Update for Either<L, R>
354+
where
355+
L: Update,
356+
R: Update,
357+
{
358+
unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool {
359+
let old_value = unsafe { &mut *old_pointer };
360+
match (old_value, new_value) {
361+
(Either::Left(old), Either::Left(new)) => L::maybe_update(old, new),
362+
(Either::Right(old), Either::Right(new)) => R::maybe_update(old, new),
363+
(old_value, new_value) => {
364+
*old_value = new_value;
365+
true
366+
}
367+
}
368+
}
369+
}
370+
351371
macro_rules! fallback_impl {
352372
($($t:ty,)*) => {
353373
$(

tests/tracked_struct_disambiguates.rs

+107
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
//! Test that disambiguation works, that is when we have a revision where we track multiple structs
2+
//! that have the same hash, we can still differentiate between them.
3+
#![allow(warnings)]
4+
5+
use std::hash::Hash;
6+
7+
use rayon::iter::Either;
8+
use salsa::Setter;
9+
10+
#[salsa::input]
11+
struct MyInput {
12+
field: u32,
13+
}
14+
15+
#[salsa::input]
16+
struct MyInputs {
17+
field: Vec<MyInput>,
18+
}
19+
20+
#[salsa::tracked]
21+
struct TrackedStruct<'db> {
22+
field: DumbHashable,
23+
}
24+
25+
#[salsa::tracked]
26+
struct TrackedStruct2<'db> {
27+
field: DumbHashable,
28+
}
29+
30+
#[derive(Debug, Clone)]
31+
pub struct DumbHashable {
32+
field: u32,
33+
}
34+
35+
impl Eq for DumbHashable {}
36+
impl PartialEq for DumbHashable {
37+
fn eq(&self, other: &Self) -> bool {
38+
self.field == other.field
39+
}
40+
}
41+
42+
// Force collisions, note that this is still a correct implementation wrt. PartialEq / Eq above
43+
// as keep the property that k1 == k2 -> hash(k1) == hash(k2)
44+
impl Hash for DumbHashable {
45+
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
46+
(self.field % 3).hash(state);
47+
}
48+
}
49+
50+
fn alternate(
51+
db: &dyn salsa::Database,
52+
input: MyInput,
53+
) -> Either<TrackedStruct<'_>, TrackedStruct2<'_>> {
54+
if input.field(db) % 2 == 0 {
55+
Either::Left(TrackedStruct::new(
56+
db,
57+
DumbHashable {
58+
field: input.field(db),
59+
},
60+
))
61+
} else {
62+
Either::Right(TrackedStruct2::new(
63+
db,
64+
DumbHashable {
65+
field: input.field(db),
66+
},
67+
))
68+
}
69+
}
70+
71+
#[salsa::tracked]
72+
fn batch(
73+
db: &dyn salsa::Database,
74+
inputs: MyInputs,
75+
) -> Vec<Either<TrackedStruct<'_>, TrackedStruct2<'_>>> {
76+
inputs
77+
.field(db)
78+
.iter()
79+
.map(|input| alternate(db, input.clone()))
80+
.collect()
81+
}
82+
83+
#[test]
84+
fn execute() {
85+
let mut db = salsa::DatabaseImpl::new();
86+
let inputs = MyInputs::new(
87+
&db,
88+
(0..1024)
89+
.into_iter()
90+
.map(|i| MyInput::new(&db, i))
91+
.collect(),
92+
);
93+
let trackeds = batch(&db, inputs);
94+
for (id, tracked) in trackeds.into_iter().enumerate() {
95+
assert_eq!(id % 2 == 0, tracked.is_left());
96+
assert_eq!(id % 2 != 0, tracked.is_right());
97+
}
98+
for input in inputs.field(&db) {
99+
let prev = input.field(&db);
100+
input.set_field(&mut db).to(prev);
101+
}
102+
let trackeds = batch(&db, inputs);
103+
for (id, tracked) in trackeds.into_iter().enumerate() {
104+
assert_eq!(id % 2 == 0, tracked.is_left());
105+
assert_eq!(id % 2 != 0, tracked.is_right());
106+
}
107+
}

0 commit comments

Comments
 (0)