Skip to content

Commit 8fceee0

Browse files
authored
Add remove_if method on SantaCache (#868)
New method needed to support efficient, single pass cache algorithms in upcoming changes. Part of SNT-353
1 parent 4270dea commit 8fceee0

2 files changed

Lines changed: 291 additions & 0 deletions

File tree

Source/common/SantaCache.h

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,57 @@ class SantaCache {
215215
*/
216216
bool contains(const KeyT &key) const { return contains(key, nullptr); }
217217

218+
/**
219+
Remove entries matching a predicate. The predicate receives each key
220+
and a mutable reference to the value. Return true to remove the entry.
221+
222+
Buckets are locked one at a time (unlike foreach which locks all),
223+
so the predicate may safely call methods on other SantaCache instances.
224+
225+
@warning The predicate MUST NOT call methods on this same SantaCache
226+
instance. The bucket lock is held during the predicate call, and
227+
re-entering the same cache would attempt to re-lock the same
228+
os_unfair_lock, which is undefined behavior.
229+
230+
@param predicate Called for each entry under its bucket lock.
231+
Return true to remove the entry.
232+
233+
@return The number of entries removed.
234+
*/
235+
uint64_t remove_if(std::function<bool(const KeyT &, ValueT &)> predicate) {
236+
assert(predicate != nullptr);
237+
238+
uint64_t removed = 0;
239+
240+
for (uint32_t i = 0; i < bucket_count_; ++i) {
241+
struct bucket *bucket = &buckets_[i];
242+
lock(bucket);
243+
244+
struct entry *entry = bucket->head;
245+
struct entry *prev = nullptr;
246+
while (entry != nullptr) {
247+
struct entry *next = entry->next;
248+
if (predicate(entry->key, entry->value)) {
249+
if (prev) {
250+
prev->next = next;
251+
} else {
252+
bucket->head = next;
253+
}
254+
delete entry;
255+
count_.fetch_sub(1, std::memory_order_relaxed);
256+
++removed;
257+
} else {
258+
prev = entry;
259+
}
260+
entry = next;
261+
}
262+
263+
unlock(bucket);
264+
}
265+
266+
return removed;
267+
}
268+
218269
/**
219270
Remove all entries and free bucket memory.
220271

Source/common/SantaCacheTest.mm

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -810,6 +810,246 @@ - (void)testConcurrentRemoveAndGet {
810810
delete sut;
811811
}
812812

813+
- (void)testRemoveIfSelectiveRemoval {
  // Five entries: key k maps to value k * 10.
  SantaCache<uint64_t, uint64_t> cache;
  for (uint64_t k = 1; k <= 5; ++k) {
    cache.set(k, k * 10);
  }

  // Only values strictly above 30 (keys 4 and 5) match the predicate.
  uint64_t dropped = cache.remove_if(^bool(const uint64_t &k, uint64_t &v) {
    return 30 < v;
  });

  XCTAssertEqual(dropped, 2);
  XCTAssertEqual(cache.count(), 3);

  // Surviving keys keep their values.
  XCTAssertEqual(cache.get(1), 10);
  XCTAssertEqual(cache.get(2), 20);
  XCTAssertEqual(cache.get(3), 30);

  // Removed keys are expected to read back as 0.
  XCTAssertEqual(cache.get(4), 0);
  XCTAssertEqual(cache.get(5), 0);
}
833+
834+
- (void)testRemoveIfMutatesThenDecides {
  SantaCache<uint64_t, uint64_t> cache;
  cache.set(1, 1);
  cache.set(2, 1);
  cache.set(3, 2);

  // The predicate bumps every value by one through the mutable reference and
  // then removes the entry only if the bumped value has reached 3.
  uint64_t dropped = cache.remove_if(^bool(const uint64_t &k, uint64_t &v) {
    return ++v >= 3;
  });

  // Only key 3 (2 -> 3) is removed; keys 1 and 2 keep the incremented value.
  XCTAssertEqual(dropped, 1);
  XCTAssertEqual(cache.count(), 2);
  XCTAssertEqual(cache.get(1), 2);
  XCTAssertEqual(cache.get(2), 2);
  XCTAssertEqual(cache.get(3), 0);
}
851+
852+
- (void)testRemoveIfNoMatches {
  SantaCache<uint64_t, uint64_t> cache;
  cache.set(1, 10);
  cache.set(2, 20);

  // A predicate that rejects everything must leave the cache untouched.
  uint64_t dropped = cache.remove_if(^bool(const uint64_t &k, uint64_t &v) {
    return false;
  });

  XCTAssertEqual(dropped, 0);
  XCTAssertEqual(cache.count(), 2);
  XCTAssertEqual(cache.get(1), 10);
  XCTAssertEqual(cache.get(2), 20);
}
866+
867+
- (void)testRemoveIfAllMatch {
  // Three entries: key k maps to value k * 10.
  SantaCache<uint64_t, uint64_t> cache;
  for (uint64_t k = 1; k <= 3; ++k) {
    cache.set(k, k * 10);
  }

  // A predicate that accepts everything must empty the cache.
  uint64_t dropped = cache.remove_if(^bool(const uint64_t &k, uint64_t &v) {
    return true;
  });

  XCTAssertEqual(dropped, 3);
  XCTAssertEqual(cache.count(), 0);
}
880+
881+
- (void)testRemoveIfSharedPtr {
  using ValuePtr = std::shared_ptr<uint64_t>;

  SantaCache<uint64_t, ValuePtr> cache;
  cache.set(1, std::make_shared<uint64_t>(11));
  cache.set(2, std::make_shared<uint64_t>(22));
  cache.set(3, std::make_shared<uint64_t>(33));

  // Watch the value stored under key 2 without keeping it alive ourselves.
  std::weak_ptr<uint64_t> observer = cache.get(2);
  XCTAssertFalse(observer.expired());

  uint64_t dropped = cache.remove_if(^bool(const uint64_t &k, ValuePtr &held) {
    return *held == 22;
  });

  XCTAssertEqual(dropped, 1);
  XCTAssertEqual(cache.count(), 2);

  // Removing the entry must have released the last owning reference.
  XCTAssertTrue(observer.expired());

  XCTAssertEqual(*cache.get(1), 11);
  XCTAssertEqual(cache.get(2), nullptr);
  XCTAssertEqual(*cache.get(3), 33);
}
901+
902+
- (void)testRemoveIfEmptyCache {
  // Nothing is ever inserted; sweeping must be a harmless no-op.
  SantaCache<uint64_t, uint64_t> cache;

  uint64_t dropped = cache.remove_if(^bool(const uint64_t &k, uint64_t &v) {
    return true;
  });

  XCTAssertEqual(dropped, 0);
  XCTAssertEqual(cache.count(), 0);
}
912+
913+
- (void)testRemoveIfMultipleEntriesPerBucket {
  // NOTE(review): going by the test name, the (10, 10) constructor arguments
  // are presumably chosen so that several entries share a bucket - confirm
  // against the SantaCache constructor.
  SantaCache<uint64_t, uint64_t> cache(10, 10);
  for (uint64_t k = 1; k <= 5; ++k) {
    cache.set(k, k * 10);
  }

  // Values that are multiples of 20 (keys 2 and 4) are removed.
  uint64_t dropped = cache.remove_if(^bool(const uint64_t &k, uint64_t &v) {
    return (v % 20) == 0;
  });

  XCTAssertEqual(dropped, 2);
  XCTAssertEqual(cache.count(), 3);

  // Removed and surviving keys alternate, so entries on either side of a
  // removed one must still be reachable.
  XCTAssertEqual(cache.get(1), 10);
  XCTAssertEqual(cache.get(2), 0);
  XCTAssertEqual(cache.get(3), 30);
  XCTAssertEqual(cache.get(4), 0);
  XCTAssertEqual(cache.get(5), 50);
}
933+
934+
// Stress test: remove_if sweeping while other threads call set() and get().
// Two writers keep re-inserting the whole key range, two readers check that
// every value they observe is either 0 or key + 1, and a single thread
// repeatedly removes all even keys. The workers run for 500ms and must all
// finish within 10 seconds of being told to stop.
- (void)testConcurrentRemoveIfWithSetAndGet {
  const int kKeyRange = 1000;
  auto cache = new SantaCache<uint64_t, uint64_t>(20000);
  auto halt = new std::atomic<bool>{false};

  // Seed every key up front so the readers have entries to find.
  for (int key = 0; key < kKeyRange; ++key) {
    cache->set(key, key + 1);
  }

  dispatch_group_t workers = dispatch_group_create();
  dispatch_queue_t defaultQueue = dispatch_get_global_queue(QOS_CLASS_DEFAULT, 0);
  dispatch_queue_t interactiveQueue = dispatch_get_global_queue(QOS_CLASS_USER_INTERACTIVE, 0);

  // Writer: rewrite key -> key + 1 for the whole range until told to stop.
  dispatch_block_t writer = ^{
    while (!halt->load(std::memory_order_relaxed)) {
      for (int key = 0; key < kKeyRange; ++key) {
        cache->set(key, key + 1);
      }
    }
    dispatch_group_leave(workers);
  };

  // Reader: any observed value must be 0 (removed) or key + 1.
  dispatch_block_t reader = ^{
    while (!halt->load(std::memory_order_relaxed)) {
      for (int key = 0; key < kKeyRange; ++key) {
        uint64_t seen = cache->get(key);
        XCTAssertTrue(seen == 0 || seen == (uint64_t)(key + 1), @"Corrupted value %llu for key %d",
                      seen, key);
      }
    }
    dispatch_group_leave(workers);
  };

  // Sweeper: keep removing every even-keyed entry.
  dispatch_block_t sweeper = ^{
    while (!halt->load(std::memory_order_relaxed)) {
      cache->remove_if(^bool(const uint64_t &k, uint64_t &v) {
        return (k % 2) == 0;
      });
    }
    dispatch_group_leave(workers);
  };

  // Launch 2 writers, then 2 readers, then the single sweeper.
  for (int n = 0; n < 2; ++n) {
    dispatch_group_enter(workers);
    dispatch_async(defaultQueue, writer);
  }
  for (int n = 0; n < 2; ++n) {
    dispatch_group_enter(workers);
    dispatch_async(interactiveQueue, reader);
  }
  dispatch_group_enter(workers);
  dispatch_async(defaultQueue, sweeper);

  usleep(500000);  // let the workers contend for 500ms
  halt->store(true, std::memory_order_relaxed);

  XCTAssertFalse(dispatch_group_wait(workers, dispatch_time(DISPATCH_TIME_NOW, 10 * NSEC_PER_SEC)),
                 @"Timed out");

  // The cache must still accept and return values after the stress run.
  cache->set(42, 99);
  XCTAssertEqual(cache->get(42), 99);

  delete halt;
  delete cache;
}
1001+
1002+
// Stress test: remove_if racing against the clear() that set() triggers when
// the cache overflows. The cache is deliberately tiny (max 100 entries) and
// two writers insert ever-increasing keys, so overflow happens frequently
// while a third thread sweeps with remove_if. This exercises the interaction
// between remove_if's per-bucket locking and clear()'s lock-all-buckets
// strategy.
- (void)testConcurrentRemoveIfWithAutoOverflowClear {
  const uint64_t kMaxSize = 100;
  auto cache = new SantaCache<uint64_t, uint64_t>(kMaxSize);
  auto halt = new std::atomic<bool>{false};

  dispatch_group_t workers = dispatch_group_create();
  dispatch_queue_t queue = dispatch_get_global_queue(QOS_CLASS_DEFAULT, 0);

  // 2 writers, each walking its own key range (starting at t * 10000) so the
  // cache keeps filling up and overflowing.
  for (int t = 0; t < 2; ++t) {
    dispatch_group_enter(workers);
    dispatch_async(queue, ^{
      for (uint64_t nextKey = t * 10000; !halt->load(std::memory_order_relaxed); ++nextKey) {
        cache->set(nextKey, 42);
      }
      dispatch_group_leave(workers);
    });
  }

  // 1 sweeper: keep removing every key divisible by 3.
  dispatch_group_enter(workers);
  dispatch_async(queue, ^{
    while (!halt->load(std::memory_order_relaxed)) {
      cache->remove_if(^bool(const uint64_t &k, uint64_t &v) {
        return (k % 3) == 0;
      });
    }
    dispatch_group_leave(workers);
  });

  usleep(500000);  // let the workers contend for 500ms
  halt->store(true, std::memory_order_relaxed);

  XCTAssertFalse(dispatch_group_wait(workers, dispatch_time(DISPATCH_TIME_NOW, 10 * NSEC_PER_SEC)),
                 @"Timed out");

  // Once everything is quiescent the entry count must be within the limit.
  XCTAssertLessThanOrEqual(cache->count(), kMaxSize);

  // The cache must still accept and return values after the stress run.
  cache->set(999999, 77);
  XCTAssertEqual(cache->get(999999), 77);

  delete halt;
  delete cache;
}
1052+
8131053
// Tests that count() stays consistent when multiple threads are concurrently
8141054
// adding and removing entries.
8151055
- (void)testConcurrentCountConsistency {

0 commit comments

Comments
 (0)