diff --git a/Cargo.lock b/Cargo.lock index 072a480138..8aa837c0c0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4273,6 +4273,7 @@ version = "2.1.7" dependencies = [ "anyhow", "futures-util", + "itertools 0.14.0", "lazy_static", "moka", "rand 0.8.5", @@ -4282,6 +4283,7 @@ dependencies = [ "rivet-metrics", "rivet-pools", "rivet-util", + "scc", "serde", "serde_json", "thiserror 1.0.69", diff --git a/engine/packages/cache/Cargo.toml b/engine/packages/cache/Cargo.toml index 0bde7eedb3..01e931cb9b 100644 --- a/engine/packages/cache/Cargo.toml +++ b/engine/packages/cache/Cargo.toml @@ -8,6 +8,7 @@ edition.workspace = true [dependencies] anyhow.workspace = true futures-util.workspace = true +itertools.workspace = true lazy_static.workspace = true moka.workspace = true rivet-cache-result.workspace = true @@ -16,8 +17,9 @@ rivet-env.workspace = true rivet-metrics.workspace = true rivet-pools.workspace = true rivet-util.workspace = true -serde.workspace = true +scc.workspace = true serde_json.workspace = true +serde.workspace = true thiserror.workspace = true tokio.workspace = true tracing.workspace = true diff --git a/engine/packages/cache/src/driver.rs b/engine/packages/cache/src/driver.rs index dfb75653b2..915dc8eb35 100644 --- a/engine/packages/cache/src/driver.rs +++ b/engine/packages/cache/src/driver.rs @@ -23,7 +23,7 @@ impl Driver { pub async fn get<'a>( &'a self, base_key: &'a str, - keys: Vec, + keys: &[RawCacheKey], ) -> Result>, Error> { match self { Driver::InMemory(d) => d.get(base_key, keys).await, @@ -156,14 +156,14 @@ impl InMemoryDriver { pub async fn get<'a>( &'a self, _base_key: &'a str, - keys: Vec, + keys: &[RawCacheKey], ) -> Result>, Error> { let mut result = Vec::with_capacity(keys.len()); // Async block for metrics async { for key in keys { - result.push(self.cache.get(&*key).await.map(|x| x.value.clone())); + result.push(self.cache.get(&**key).await.map(|x| x.value.clone())); } } .instrument(tracing::info_span!("get")) diff --git a/engine/packages/cache/src/getter_ctx.rs b/engine/packages/cache/src/getter_ctx.rs index cbfd0b8a53..6ef21a0c07 100644 --- a/engine/packages/cache/src/getter_ctx.rs +++ b/engine/packages/cache/src/getter_ctx.rs @@ -1,13 +1,10 @@ -use std::fmt::Debug; +use std::{collections::HashMap, fmt::Debug}; use super::*; /// Entry for a single value that is going to be read/written to the cache. #[derive(Debug)] -pub(super) struct GetterCtxKey { - /// `CacheKey` that will be used to build Redis keys. - pub(super) key: K, - +pub(super) struct GetterCtxEntry { /// The value that was read from the cache or getter. value: Option, @@ -23,81 +20,68 @@ pub struct GetterCtx where K: CacheKey, { - /// The name of the service-specific key to write this cached value to. For - /// example, a team get service would use the "team_profile" key to store - /// the profile a "team_members" to store a cache of members. - /// - /// This is local to the service & source hash that caches this value. - #[allow(unused)] - base_key: String, - - /// The keys to get/populate from the cache. - keys: Vec>, + /// The entries to get/populate from the cache. + entries: HashMap>, } impl GetterCtx where K: CacheKey, { - pub(super) fn new(base_key: String, keys: Vec) -> Self { + pub(super) fn new(keys: Vec) -> Self { GetterCtx { - base_key, - keys: { - // Create deduplicated ctx keys - let mut ctx_keys = Vec::>::new(); - for key in keys { - if !ctx_keys.iter().any(|x| x.key == key) { - ctx_keys.push(GetterCtxKey { - key, + entries: keys + .into_iter() + .map(|k| { + ( + k, + GetterCtxEntry { value: None, from_cache: false, - }); - } - } - ctx_keys - }, + }, + ) + }) + .collect(), } } + pub(super) fn merge(&mut self, other: GetterCtx) { + self.entries.extend(other.entries); + } + pub(super) fn into_values(self) -> Vec<(K, V)> { - self.keys + self.entries .into_iter() - .filter_map(|k| { - if let Some(v) = k.value { - Some((k.key, v)) - } else { - None - } - }) + .filter_map(|(k, x)| x.value.map(|v| (k, v))) .collect() } - /// All keys. - pub(super) fn keys(&self) -> &[GetterCtxKey] { - &self.keys[..] + /// All entries. + pub(super) fn entries(&self) -> impl Iterator)> { + self.entries.iter() } - /// If all keys have an associated value. - pub(super) fn all_keys_have_value(&self) -> bool { - self.keys.iter().all(|x| x.value.is_some()) + /// If all entries have an associated value. + pub(super) fn all_entries_have_value(&self) -> bool { + self.entries.iter().all(|(_, x)| x.value.is_some()) } /// Keys that do not have a value yet. pub(super) fn unresolved_keys(&self) -> Vec { - self.keys + self.entries .iter() - .filter(|x| x.value.is_none()) - .map(|x| x.key.clone()) + .filter(|(_, x)| x.value.is_none()) + .map(|(k, _)| k.clone()) .collect() } - /// Keys that have been resolved in a getter and need to be written to the + /// Entries that have been resolved in a getter and need to be written to the /// cache. - pub(super) fn values_needing_cache_write(&self) -> Vec<(&GetterCtxKey, &V)> { - self.keys + pub(super) fn entries_needing_cache_write(&self) -> Vec<(&K, &V)> { + self.entries .iter() - .filter(|x| !x.from_cache) - .filter_map(|k| k.value.as_ref().map(|v| (k, v))) + .filter(|(_, x)| !x.from_cache) + .filter_map(|(k, x)| x.value.as_ref().map(|v| (k, v))) .collect() } } @@ -107,32 +91,26 @@ where K: CacheKey, V: Debug, { - /// Sets a value with the value provided from the cache. - pub(super) fn resolve_from_cache(&mut self, idx: usize, value: V) { - if let Some(key) = self.keys.get_mut(idx) { - key.value = Some(value); - key.from_cache = true; + /// Sets an entry with the value provided from the cache. + pub(super) fn resolve_from_cache(&mut self, key: &K, value: V) { + if let Some(entry) = self.entries.get_mut(key) { + entry.value = Some(value); + entry.from_cache = true; } else { - tracing::warn!(?idx, ?value, "resolving cache key index out of range"); + tracing::warn!(?key, ?value, "resolving nonexistent cache entry"); } } - /// Calls the callback with a mutable reference to a given key. Validates - /// that the key does not already have a value. - fn get_key_for_resolve(&mut self, key: &K, cb: impl FnOnce(&mut GetterCtxKey)) { - if let Some(key) = self.keys.iter_mut().find(|x| x.key == *key) { - if key.value.is_some() { - tracing::warn!(?key, "cache key already has value"); + /// Sets a value with the value provided from the getter function. + pub fn resolve(&mut self, key: &K, value: V) { + if let Some(entry) = self.entries.get_mut(key) { + if entry.value.is_some() { + tracing::warn!(?entry, "cache entry already has value"); } else { - cb(key); + entry.value = Some(value); } } else { - tracing::warn!(?key, "resolved value for nonexistent cache key"); + tracing::warn!(?key, "resolved value for nonexistent cache entry"); } } - - /// Sets a value with the value provided from the getter function. - pub fn resolve(&mut self, key: &K, value: V) { - self.get_key_for_resolve(key, |key| key.value = Some(value)); - } } diff --git a/engine/packages/cache/src/inner.rs b/engine/packages/cache/src/inner.rs index aa6f62c417..cb65782dea 100644 --- a/engine/packages/cache/src/inner.rs +++ b/engine/packages/cache/src/inner.rs @@ -1,5 +1,7 @@ use std::{fmt::Debug, sync::Arc}; +use tokio::sync::broadcast; + use super::*; use crate::driver::{Driver, InMemoryDriver}; @@ -8,6 +10,7 @@ pub type Cache = Arc; /// Utility type used to hold information relating to caching. pub struct CacheInner { pub(crate) driver: Driver, + pub(crate) in_flight: scc::HashMap>, pub(crate) ups: Option, } @@ -33,7 +36,11 @@ impl CacheInner { #[tracing::instrument(skip(ups))] pub fn new_in_memory(max_capacity: u64, ups: Option) -> Cache { let driver = Driver::InMemory(InMemoryDriver::new(max_capacity)); - Arc::new(CacheInner { driver, ups }) + Arc::new(CacheInner { + driver, + in_flight: scc::HashMap::new(), + ups, + }) } } diff --git a/engine/packages/cache/src/key.rs b/engine/packages/cache/src/key.rs index 6714a2762d..2cf46cdacb 100644 --- a/engine/packages/cache/src/key.rs +++ b/engine/packages/cache/src/key.rs @@ -1,8 +1,8 @@ use serde::{Deserialize, Serialize}; -use std::{fmt::Debug, ops::Deref}; +use std::{fmt::Debug, hash::Hash, ops::Deref}; /// A type that can be serialized in to a key that can be used in the cache. -pub trait CacheKey: Clone + Debug + PartialEq { +pub trait CacheKey: Clone + Debug + Eq + PartialEq + Hash { fn cache_key(&self) -> String; } @@ -85,7 +85,7 @@ impl_to_string!(isize); /// Unlike other types that implement `CacheKey` (which escape special characters like `:` and `\`), /// `RawCacheKey` uses the provided string as-is. This is useful when you already have a properly /// formatted cache key string or need to preserve the exact format without transformations. -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +#[derive(Clone, Debug, Serialize, Deserialize, Hash, Eq, PartialEq)] pub struct RawCacheKey(String); impl CacheKey for RawCacheKey { diff --git a/engine/packages/cache/src/req_config.rs b/engine/packages/cache/src/req_config.rs index cf0ede8073..8baacd55cb 100644 --- a/engine/packages/cache/src/req_config.rs +++ b/engine/packages/cache/src/req_config.rs @@ -1,12 +1,17 @@ -use std::{fmt::Debug, future::Future, result::Result::Ok}; +use std::{fmt::Debug, future::Future, time::Duration}; -use anyhow::*; +use anyhow::Result; +use futures_util::StreamExt; +use itertools::{Either, Itertools}; use serde::{Serialize, de::DeserializeOwned}; -use tracing::Instrument; +use tokio::sync::broadcast; use super::*; use crate::{errors::Error, metrics}; +/// How long to wait for an in flight cache req before proceeding to execute the same req anyway. +const IN_FLIGHT_TIMEOUT: Duration = Duration::from_secs(5); + /// Config specifying how cached values will behave. #[derive(Clone)] pub struct RequestConfig { @@ -74,22 +79,17 @@ impl RequestConfig { .with_label_values(&[base_key.as_str()]) .inc_by(keys.len() as u64); - // Build context. - // - // Drop `keys` bc this is not the same as the keys list in `ctx`, so it should not be used - // again. - let mut ctx = GetterCtx::new(base_key.clone(), keys); + let mut ctx = GetterCtx::new(keys); // Build driver-specific cache keys - let cache_keys = ctx - .keys() - .iter() - .map(|key| self.cache.driver.process_key(&base_key, &key.key)) - .collect::>(); + let (keys, cache_keys): (Vec<_>, Vec<_>) = ctx + .entries() + .map(|(key, _)| (key.clone(), self.cache.driver.process_key(&base_key, key))) + .unzip(); let cache_keys_len = cache_keys.len(); // Attempt to fetch value from cache, fall back to getter - match self.cache.driver.get(&base_key, cache_keys).await { + match self.cache.driver.get(&base_key, &cache_keys).await { Ok(cached_values) => { debug_assert_eq!( cache_keys_len, @@ -97,13 +97,13 @@ impl RequestConfig { "cache returned wrong number of values" ); - // Create the getter ctx and resolve the cached values - for (i, value) in cached_values.into_iter().enumerate() { + // Resolve the cached values + for (key, value) in keys.iter().zip(cached_values.into_iter()) { if let Some(value_bytes) = value { // Try to decode the value using the driver match decoder(&value_bytes) { Ok(value) => { - ctx.resolve_from_cache(i, value); + ctx.resolve_from_cache(key, value); } Err(err) => { tracing::error!(?err, "Failed to decode value"); @@ -113,7 +113,7 @@ impl RequestConfig { } // Fetch remaining values and add to the cached list - if !ctx.all_keys_have_value() { + if !ctx.all_entries_have_value() { // Call the getter let remaining_keys = ctx.unresolved_keys(); let unresolved_len = remaining_keys.len(); @@ -122,27 +122,137 @@ impl RequestConfig { .with_label_values(&[base_key.as_str()]) .inc_by(unresolved_len as u64); - ctx = getter(ctx, remaining_keys).await.map_err(Error::Getter)?; + let mut waiting_keys = Vec::new(); + let mut leased_keys = Vec::new(); + let (broadcast_tx, _) = broadcast::channel::<()>(16); + + // Determine which keys are currently being fetched and not + for key in remaining_keys { + let cache_key = self.cache.driver.process_key(&base_key, &key); + match self.cache.in_flight.entry_async(cache_key).await { + scc::hash_map::Entry::Occupied(broadcast) => { + waiting_keys.push((key, broadcast.subscribe())); + } + scc::hash_map::Entry::Vacant(entry) => { + entry.insert_entry(broadcast_tx.clone()); + leased_keys.push(key); + } + } + } + + let getter2 = getter.clone(); + let cache = self.cache.clone(); + let ctx2 = GetterCtx::new(leased_keys.clone()); + let base_key2 = base_key.clone(); + let leased_keys2 = leased_keys.clone(); + let (ctx2, ctx3) = tokio::try_join!( + async move { + if leased_keys2.is_empty() { + Ok(ctx2) + } else { + getter2(ctx2, leased_keys2).await.map_err(Error::Getter) + } + }, + async move { + let ctx3 = GetterCtx::new( + waiting_keys.iter().map(|(key, _)| key.clone()).collect(), + ); + + // Wait on keys that are being fetched by another cache req + let (succeeded_keys, failed_keys): (Vec<_>, Vec<_>) = + futures_util::stream::iter(waiting_keys) + .map(|(key, mut rx)| async move { + ( + key, + tokio::time::timeout(IN_FLIGHT_TIMEOUT, rx.recv()) + .await + .ok() + .map(|x| x.ok()) + .flatten() + .is_some(), + ) + }) + .buffer_unordered(1024) + .collect::>() + .await + .into_iter() + .partition_map(|(key, succeeded)| { + if succeeded { + let cache_key = + cache.driver.process_key(&base_key2, &key); + Either::Left((key, cache_key)) + } else { + Either::Right(key) + } + }); + let (succeeded_keys, succeeded_cache_keys): (Vec<_>, Vec<_>) = + succeeded_keys.into_iter().unzip(); + + let (cached_values_res, ctx3_res) = tokio::join!( + cache.driver.get(&base_key2, &succeeded_cache_keys), + async { + if failed_keys.is_empty() { + Ok(ctx3) + } else { + getter(ctx3, failed_keys).await.map_err(Error::Getter) + } + }, + ); + let mut ctx3 = ctx3_res?; + + match cached_values_res { + Ok(cached_values) => { + for (key, value) in + succeeded_keys.iter().zip(cached_values.into_iter()) + { + if let Some(value_bytes) = value { + // Try to decode the value using the driver + match decoder(&value_bytes) { + Ok(value) => { + ctx3.resolve_from_cache(key, value); + } + Err(err) => { + tracing::error!(?err, "Failed to decode value"); + } + } + } + } + } + Err(err) => { + tracing::error!(?err, "failed to read batch keys from cache"); + + metrics::CACHE_REQUEST_ERRORS + .with_label_values(&[&base_key2]) + .inc(); + } + } + + Ok(ctx3) + } + )?; + + ctx.merge(ctx2); + ctx.merge(ctx3); // Write the values to cache let expire_at = rivet_util::timestamp::now() + self.ttl; - let values_needing_cache_write = ctx.values_needing_cache_write(); + let entries_needing_cache_write = ctx.entries_needing_cache_write(); tracing::trace!( unresolved_len, - fetched_len = values_needing_cache_write.len(), + fetched_len = entries_needing_cache_write.len(), "writing new values to cache" ); // Convert values to cache bytes - let keys_values = values_needing_cache_write + let entries_values = entries_needing_cache_write .into_iter() .filter_map(|(key, value)| { // Process the key with the appropriate driver - let driver_key = self.cache.driver.process_key(&base_key, &key.key); + let cache_key = self.cache.driver.process_key(&base_key, key); // Try to decode the value using the driver match encoder(value) { - Ok(value_bytes) => Some((driver_key, value_bytes, expire_at)), + Ok(value_bytes) => Some((cache_key, value_bytes, expire_at)), Err(err) => { tracing::error!(?err, "Failed to encode value"); @@ -152,28 +262,26 @@ impl RequestConfig { }) .collect::>(); - if !keys_values.is_empty() { + if !entries_values.is_empty() { let cache = self.cache.clone(); let base_key_clone = base_key.clone(); - let spawn_res = tokio::task::Builder::new().name("cache::write").spawn( - async move { - if let Err(err) = - cache.driver.set(&base_key_clone, keys_values).await - { - tracing::error!(?err, "failed to write to cache"); - } - } - .in_current_span(), - ); - if let Err(err) = spawn_res { - tracing::error!(?err, "failed to spawn cache::write task"); + if let Err(err) = cache.driver.set(&base_key_clone, entries_values).await { + tracing::error!(?err, "failed to write to cache"); } + + let _ = broadcast_tx.send(()); + } + + // Release leases + for key in leased_keys { + let cache_key = self.cache.driver.process_key(&base_key, &key); + self.cache.in_flight.remove_async(&cache_key).await; } } metrics::CACHE_VALUE_EMPTY_TOTAL - .with_label_values(&[base_key]) + .with_label_values(&[&base_key]) .inc_by(ctx.unresolved_keys().len() as u64); Ok(ctx.into_values()) @@ -185,7 +293,7 @@ impl RequestConfig { ); metrics::CACHE_REQUEST_ERRORS - .with_label_values(&[base_key]) + .with_label_values(&[&base_key]) .inc(); // Fall back to the getter since we can't fetch the value from @@ -193,6 +301,10 @@ impl RequestConfig { let keys = ctx.unresolved_keys(); let ctx = getter(ctx, keys).await.map_err(Error::Getter)?; + metrics::CACHE_VALUE_EMPTY_TOTAL + .with_label_values(&[&base_key]) + .inc_by(ctx.unresolved_keys().len() as u64); + Ok(ctx.into_values()) } } diff --git a/engine/packages/cache/tests/fetch.rs b/engine/packages/cache/tests/fetch.rs new file mode 100644 index 0000000000..d5d07cd546 --- /dev/null +++ b/engine/packages/cache/tests/fetch.rs @@ -0,0 +1,91 @@ +use std::{collections::HashSet, sync::Arc, time::Duration}; + +use rand::{Rng, seq::IteratorRandom, thread_rng}; + +fn build_cache() -> rivet_cache::Cache { + rivet_cache::CacheInner::new_in_memory(1000, None) +} + +#[tokio::test(flavor = "multi_thread")] +async fn multiple_keys() { + let cache = build_cache(); + let values = cache + .clone() + .request() + .fetch_all_json_with_keys( + "multiple_keys", + vec!["a", "b", "c"], + |mut cache, keys| async move { + for key in &keys { + cache.resolve(key, format!("{0}{0}{0}", key)); + } + Ok(cache) + }, + ) + .await + .unwrap(); + assert_eq!(3, values.len(), "missing values"); + for (k, v) in values { + let expected_v = match k { + "a" => "aaa", + "b" => "bbb", + "c" => "ccc", + _ => panic!("unexpected key {}", k), + }; + assert_eq!(expected_v, v, "unexpected value"); + } +} + +#[tokio::test(flavor = "multi_thread")] +async fn smoke_test() { + let cache = build_cache(); + + // Generate random entries for the cache + let mut entries = std::collections::HashMap::new(); + for i in 0..16usize { + entries.insert(i.to_string(), format!("{0}{0}{0}", i)); + } + let entries = Arc::new(entries); + + let parallel_count = 32; // Reduced for faster tests + let barrier = Arc::new(tokio::sync::Barrier::new(parallel_count)); + let mut handles = Vec::new(); + for _ in 0..parallel_count { + let keys = + std::iter::repeat_with(|| entries.keys().choose(&mut thread_rng()).unwrap().clone()) + .take(thread_rng().gen_range(0..8)) + .collect::>(); + let deduplicated_keys = keys.clone().into_iter().collect::>(); + + let entries = entries.clone(); + let cache = cache.clone(); + let barrier = barrier.clone(); + let handle = tokio::spawn(async move { + barrier.wait().await; + let values = cache + .request() + .fetch_all_json_with_keys("smoke_test", keys, move |mut cache, keys| { + let entries = entries.clone(); + async move { + // Reduced sleep for faster tests + tokio::time::sleep(Duration::from_millis(100)).await; + for key in &keys { + cache.resolve(key, entries.get(key).expect("invalid key").clone()); + } + Ok(cache) + } + }) + .await + .unwrap(); + assert_eq!( + deduplicated_keys, + values + .iter() + .map(|x| x.0.clone()) + .collect::>() + ); + }); + handles.push(handle); + } + futures_util::future::try_join_all(handles).await.unwrap(); +} diff --git a/engine/packages/cache/tests/in_flight.rs b/engine/packages/cache/tests/in_flight.rs new file mode 100644 index 0000000000..5710ed6c26 --- /dev/null +++ b/engine/packages/cache/tests/in_flight.rs @@ -0,0 +1,361 @@ +use std::{ + sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, + }, + time::Duration, +}; + +fn build_cache() -> rivet_cache::Cache { + rivet_cache::CacheInner::new_in_memory(1000, None) +} + +/// Two concurrent requests for the same key. The second should wait for the +/// first's getter to complete, then read from cache without calling its own getter. +#[tokio::test(flavor = "multi_thread")] +async fn dedup_two_requests() { + let cache = build_cache(); + let call_count = Arc::new(AtomicUsize::new(0)); + let getter_started = Arc::new(tokio::sync::Notify::new()); + + let cache1 = cache.clone(); + let count1 = call_count.clone(); + let started1 = getter_started.clone(); + let task1 = tokio::spawn(async move { + cache1 + .request() + .fetch_one_json( + "in_flight_dedup", + "key1", + move |mut ctx: rivet_cache::GetterCtx<&str, String>, key| { + let count = count1.clone(); + let started = started1.clone(); + async move { + count.fetch_add(1, Ordering::SeqCst); + started.notify_one(); + // Hold the lease long enough for task2 to discover the in-flight entry. + tokio::time::sleep(Duration::from_millis(100)).await; + ctx.resolve(&key, "from_getter".to_string()); + Ok(ctx) + } + }, + ) + .await + }); + + // Wait until task1's getter is running before launching task2. + getter_started.notified().await; + + let count2 = call_count.clone(); + let task2 = tokio::spawn(async move { + cache + .request() + .fetch_one_json( + "in_flight_dedup", + "key1", + move |mut ctx: rivet_cache::GetterCtx<&str, String>, key| { + let count = count2.clone(); + async move { + // Should not be reached: key1 is in-flight. + count.fetch_add(1, Ordering::SeqCst); + ctx.resolve(&key, "wrong_value".to_string()); + Ok(ctx) + } + }, + ) + .await + }); + + let (r1, r2) = tokio::join!(task1, task2); + + assert_eq!(r1.unwrap().unwrap(), Some("from_getter".to_string())); + assert_eq!( + r2.unwrap().unwrap(), + Some("from_getter".to_string()), + "waiting request should read the value written to cache by the leased request" + ); + assert_eq!( + call_count.load(Ordering::SeqCst), + 1, + "getter called exactly once despite two concurrent requests for the same key" + ); +} + +/// Three concurrent requests for the same key. Only the first should call the +/// getter; the other two wait and then read from cache. +#[tokio::test(flavor = "multi_thread")] +async fn dedup_multiple_waiters() { + let cache = build_cache(); + let call_count = Arc::new(AtomicUsize::new(0)); + let getter_started = Arc::new(tokio::sync::Notify::new()); + + let cache1 = cache.clone(); + let count1 = call_count.clone(); + let started1 = getter_started.clone(); + let task1 = tokio::spawn(async move { + cache1 + .request() + .fetch_one_json( + "in_flight_multi", + "shared_key", + move |mut ctx: rivet_cache::GetterCtx<&str, String>, key| { + let count = count1.clone(); + let started = started1.clone(); + async move { + count.fetch_add(1, Ordering::SeqCst); + started.notify_one(); + tokio::time::sleep(Duration::from_millis(150)).await; + ctx.resolve(&key, "shared_value".to_string()); + Ok(ctx) + } + }, + ) + .await + }); + + getter_started.notified().await; + + let make_waiter = |cache: rivet_cache::Cache, count: Arc| { + tokio::spawn(async move { + cache + .request() + .fetch_one_json( + "in_flight_multi", + "shared_key", + move |mut ctx: rivet_cache::GetterCtx<&str, String>, key| { + let count = count.clone(); + async move { + // Should not be reached. + count.fetch_add(1, Ordering::SeqCst); + ctx.resolve(&key, "wrong_value".to_string()); + Ok(ctx) + } + }, + ) + .await + }) + }; + + let task2 = make_waiter(cache.clone(), call_count.clone()); + let task3 = make_waiter(cache.clone(), call_count.clone()); + + let (r1, r2, r3) = tokio::join!(task1, task2, task3); + + assert_eq!(r1.unwrap().unwrap(), Some("shared_value".to_string())); + assert_eq!(r2.unwrap().unwrap(), Some("shared_value".to_string())); + assert_eq!(r3.unwrap().unwrap(), Some("shared_value".to_string())); + assert_eq!( + call_count.load(Ordering::SeqCst), + 1, + "getter called exactly once for three concurrent requests" + ); +} + +/// Concurrent requests for different keys should not share in-flight state. +/// Each key's getter must be called independently. +#[tokio::test(flavor = "multi_thread")] +async fn independent_keys() { + let cache = build_cache(); + let call_count = Arc::new(AtomicUsize::new(0)); + + let make_task = |cache: rivet_cache::Cache, count: Arc, key: &'static str| { + tokio::spawn(async move { + cache + .request() + .fetch_one_json( + "in_flight_independent", + key, + move |mut ctx: rivet_cache::GetterCtx<&str, String>, k| { + let count = count.clone(); + async move { + count.fetch_add(1, Ordering::SeqCst); + tokio::time::sleep(Duration::from_millis(50)).await; + ctx.resolve(&k, format!("val_{k}")); + Ok(ctx) + } + }, + ) + .await + }) + }; + + let t1 = make_task(cache.clone(), call_count.clone(), "key_a"); + let t2 = make_task(cache.clone(), call_count.clone(), "key_b"); + + let (r1, r2) = tokio::join!(t1, t2); + + assert_eq!(r1.unwrap().unwrap(), Some("val_key_a".to_string())); + assert_eq!(r2.unwrap().unwrap(), Some("val_key_b".to_string())); + assert_eq!( + call_count.load(Ordering::SeqCst), + 2, + "getter called once per distinct key" + ); +} + +/// A batch request that mixes an already-cached key with an in-flight key. +/// The cached key resolves immediately; the in-flight key waits without the +/// getter being invoked again. +#[tokio::test(flavor = "multi_thread")] +async fn mixed_cached_and_in_flight() { + let cache = build_cache(); + + // Pre-populate "cached_key". + cache + .clone() + .request() + .fetch_one_json( + "in_flight_mixed", + "cached_key", + |mut ctx: rivet_cache::GetterCtx<&str, String>, key| async move { + ctx.resolve(&key, "cached_value".to_string()); + Ok(ctx) + }, + ) + .await + .unwrap(); + + let call_count = Arc::new(AtomicUsize::new(0)); + let getter_started = Arc::new(tokio::sync::Notify::new()); + + // Task 1: fetches only the slow key, holds the in-flight lease. + let cache1 = cache.clone(); + let count1 = call_count.clone(); + let started1 = getter_started.clone(); + let t1 = tokio::spawn(async move { + cache1 + .request() + .fetch_one_json( + "in_flight_mixed", + "slow_key", + move |mut ctx: rivet_cache::GetterCtx<&str, String>, key| { + let count = count1.clone(); + let started = started1.clone(); + async move { + count.fetch_add(1, Ordering::SeqCst); + started.notify_one(); + tokio::time::sleep(Duration::from_millis(100)).await; + ctx.resolve(&key, "slow_value".to_string()); + Ok(ctx) + } + }, + ) + .await + }); + + getter_started.notified().await; + + // Task 2: fetches both keys at once. "cached_key" is a cache hit and + // "slow_key" is in-flight, so neither should trigger the getter. + let count2 = call_count.clone(); + let t2 = tokio::spawn(async move { + cache + .request() + .fetch_all_json( + "in_flight_mixed", + vec!["cached_key", "slow_key"], + move |mut ctx: rivet_cache::GetterCtx<&str, String>, keys| { + let count = count2.clone(); + async move { + if !keys.is_empty() { + // Should not be reached for either key. + count.fetch_add(1, Ordering::SeqCst); + for k in &keys { + ctx.resolve(k, "wrong_value".to_string()); + } + } + Ok(ctx) + } + }, + ) + .await + }); + + let (r1, r2) = tokio::join!(t1, t2); + + assert_eq!(r1.unwrap().unwrap(), Some("slow_value".to_string())); + + let mut r2_vals = r2.unwrap().unwrap(); + r2_vals.sort(); + assert_eq!( + r2_vals, + vec!["cached_value".to_string(), "slow_value".to_string()] + ); + assert_eq!( + call_count.load(Ordering::SeqCst), + 1, + "only the leasing task's getter should be called" + ); +} + +/// When the leasing task's getter takes longer than IN_FLIGHT_TIMEOUT (5s), +/// the waiting task should stop waiting and fall back to calling its own getter. +#[tokio::test(flavor = "multi_thread")] +async fn timeout_falls_back_to_getter() { + let cache = build_cache(); + let call_count = Arc::new(AtomicUsize::new(0)); + let getter_started = Arc::new(tokio::sync::Notify::new()); + let getter_release = Arc::new(tokio::sync::Notify::new()); + + // Task 1: holds the in-flight lease for longer than IN_FLIGHT_TIMEOUT. + let cache1 = cache.clone(); + let count1 = call_count.clone(); + let started1 = getter_started.clone(); + let release1 = getter_release.clone(); + let task1 = tokio::spawn(async move { + cache1 + .request() + .fetch_one_json( + "timeout_ns", + "key1", + move |mut ctx: rivet_cache::GetterCtx<&str, String>, key| { + let count = count1.clone(); + let started = started1.clone(); + let release = release1.clone(); + async move { + count.fetch_add(1, Ordering::SeqCst); + started.notify_one(); + // Block until told to proceed, simulating a very slow getter. + release.notified().await; + ctx.resolve(&key, "task1_value".to_string()); + Ok(ctx) + } + }, + ) + .await + }); + + getter_started.notified().await; + + // Task 2: subscribes as a waiter and will time out after IN_FLIGHT_TIMEOUT. + let count2 = call_count.clone(); + let task2 = tokio::spawn(async move { + cache + .request() + .fetch_one_json( + "timeout_ns", + "key1", + move |mut ctx: rivet_cache::GetterCtx<&str, String>, key| { + let count = count2.clone(); + async move { + count.fetch_add(1, Ordering::SeqCst); + ctx.resolve(&key, "task2_value".to_string()); + Ok(ctx) + } + }, + ) + .await + }); + + // Wait for task2 to time out (IN_FLIGHT_TIMEOUT = 5s), then release task1. + // task2 should have already fallen back to its own getter by the time task1 finishes. + task2.await.unwrap().unwrap(); + getter_release.notify_one(); + task1.await.unwrap().unwrap(); + + assert_eq!( + call_count.load(Ordering::SeqCst), + 2, + "both getters should be called: task1 held the lease, task2 timed out and fetched itself" + ); +} diff --git a/engine/packages/cache/tests/integration.rs b/engine/packages/cache/tests/integration.rs deleted file mode 100644 index e2c11e0ed8..0000000000 --- a/engine/packages/cache/tests/integration.rs +++ /dev/null @@ -1,582 +0,0 @@ -use std::{ - collections::{HashMap, HashSet}, - sync::Arc, - time::Duration, -}; - -use rand::{Rng, seq::IteratorRandom, thread_rng}; - -async fn build_in_memory_cache() -> rivet_cache::Cache { - rivet_cache::CacheInner::new_in_memory("cache-test".to_owned(), 1000, None) -} - -async fn test_multiple_keys(cache: rivet_cache::Cache) { - let values = cache - .clone() - .request() - .fetch_all( - "multiple_keys", - vec!["a", "b", "c"], - |mut cache, keys| async move { - for key in &keys { - cache.resolve(key, format!("{0}{0}{0}", key)); - } - Ok(cache) - }, - ) - .await - .unwrap(); - assert_eq!(3, values.len(), "missing values"); - for (k, v) in values { - let expected_v = match k { - "a" => "aaa", - "b" => "bbb", - "c" => "ccc", - _ => panic!("unexpected key {}", k), - }; - assert_eq!(expected_v, v, "unexpected value"); - } -} - -async fn test_smoke_test(cache: rivet_cache::Cache) { - // Generate random entries for the cache - let mut entries = HashMap::new(); - for i in 0..16usize { - entries.insert(i.to_string(), format!("{0}{0}{0}", i)); - } - let entries = Arc::new(entries); - - let parallel_count = 32; // Reduced for faster tests - let barrier = Arc::new(tokio::sync::Barrier::new(parallel_count)); - let mut handles = Vec::new(); - for _ in 0..parallel_count { - let keys = - std::iter::repeat_with(|| entries.keys().choose(&mut thread_rng()).unwrap().clone()) - .take(thread_rng().gen_range(0..8)) - .collect::>(); - let deduplicated_keys = keys.clone().into_iter().collect::>(); - - let entries = entries.clone(); - let cache = cache.clone(); - let barrier = barrier.clone(); - let handle = tokio::spawn(async move { - barrier.wait().await; - let values = cache - .request() - .fetch_all("smoke_test", keys, move |mut cache, keys| { - let entries = entries.clone(); - async move { - // Reduced sleep for faster tests - tokio::time::sleep(Duration::from_millis(100)).await; - for key in &keys { - cache.resolve(key, entries.get(key).expect("invalid key").clone()); - } - Ok(cache) - } - }) - .await - .unwrap(); - assert_eq!( - deduplicated_keys, - values - .iter() - .map(|x| x.0.clone()) - .collect::>() - ); - }); - handles.push(handle); - } - futures_util::future::try_join_all(handles).await.unwrap(); -} - -/// Tests that a custom TTL is properly respected when setting and accessing items -async fn test_custom_ttl(cache: rivet_cache::Cache) { - let test_key = "ttl-test-key"; - let test_value = "test-value"; - let short_ttl_ms = 500i64; // 500ms TTL - - // Store with a custom short TTL - let _ = cache - .clone() - .request() - .ttl(short_ttl_ms) - .fetch_one("ttl_test", test_key, |mut cache, key| async move { - cache.resolve(&key, test_value.to_string()); - Ok(cache) - }) - .await - .unwrap(); - - // Verify value exists immediately after storing - let value = cache - .clone() - .request() - .fetch_one( - "ttl_test", - test_key, - |mut cache: rivet_cache::GetterCtx<&str, String>, key| async move { - // If not found in cache, we need to return the same value - cache.resolve(&key, test_value.to_string()); - Ok(cache) - }, - ) - .await - .unwrap(); - - assert_eq!( - Some(test_value.to_string()), - value, - "Value should be available before TTL expiration" - ); - - // Wait for the TTL to expire - use a longer wait to ensure it expires - tokio::time::sleep(Duration::from_millis((short_ttl_ms * 3) as u64)).await; - - // Since we want to test value expiration, manually purge for consistency across implementations - cache - .clone() - .request() - .purge("ttl_test", [test_key]) - .await - .unwrap(); - - // Verify value no longer exists after TTL expiration - let value = cache - .clone() - .request() - .fetch_one( - "ttl_test", - test_key, - |cache: rivet_cache::GetterCtx<&str, String>, _| async move { - // Don't resolve anything - we want to verify the key is gone - Ok(cache) - }, - ) - .await - .unwrap(); - - assert_eq!( - None, value, - "Value should not be available after TTL expiration" - ); -} - -/// Tests that default TTL is applied correctly when not explicitly specified -async fn test_default_ttl(cache: rivet_cache::Cache) { - let test_key = "default-ttl-key"; - let test_value = "default-value"; - - // Store with default TTL (should use 2 hours) - let _ = cache - .clone() - .request() - .fetch_one("default_ttl_test", test_key, |mut cache, key| async move { - cache.resolve(&key, test_value.to_string()); - Ok(cache) - }) - .await - .unwrap(); - - // Verify value exists after storing - let value = cache - .clone() - .request() - .fetch_one( - "default_ttl_test", - test_key, - |mut cache: rivet_cache::GetterCtx<&str, String>, key| async move { - // If not found in cache, we need to return the same value - cache.resolve(&key, test_value.to_string()); - Ok(cache) - }, - ) - .await - .unwrap(); - - assert_eq!( - Some(test_value.to_string()), - value, - "Value should be available with default TTL" - ); -} - -/// Tests that purging a key removes it regardless of TTL -async fn test_purge_with_ttl(cache: rivet_cache::Cache) { - let test_key = "purge-key"; - let test_value = "purge-value"; - let long_ttl_ms = 3600000i64; // 1 hour TTL - - // Store with a long TTL - let _ = cache - .clone() - .request() - .ttl(long_ttl_ms) - .fetch_one("purge_test", test_key, |mut cache, key| async move { - cache.resolve(&key, test_value.to_string()); - Ok(cache) - }) - .await - .unwrap(); - - // Verify value exists after storing - let value = cache - .clone() - .request() - .fetch_one( - "purge_test", - test_key, - |mut cache: rivet_cache::GetterCtx<&str, String>, key| async move { - // If not found in cache, we need to return the same value - cache.resolve(&key, test_value.to_string()); - Ok(cache) - }, - ) - .await - .unwrap(); - - assert_eq!( - Some(test_value.to_string()), - value, - "Value should be available after storing" - ); - - // Purge the key - cache - .clone() - .request() - .purge("purge_test", [test_key]) - .await - .unwrap(); - - // Verify value no longer exists after purging - let value = cache - .clone() - .request() - .fetch_one( - "purge_test", - test_key, - |cache: rivet_cache::GetterCtx<&str, String>, _| async move { Ok(cache) }, - ) - .await - .unwrap(); - - assert_eq!(None, value, "Value should not be available after purging"); -} - -/// Tests multiple TTLs for different keys in the same batch -async fn test_multi_key_ttl(cache: rivet_cache::Cache) { - let short_ttl_key = "short-ttl"; - let long_ttl_key = "long-ttl"; - let short_ttl_ms = 500i64; // 500ms TTL - - // First, purge any existing keys to ensure clean state - cache - .clone() - .request() - .purge("multi_ttl_test", [short_ttl_key, long_ttl_key]) - .await - .unwrap(); - - // Create separate cache handlers with different TTLs - let short_ttl_cache = cache.clone().request().ttl(short_ttl_ms); - let long_ttl_cache = cache.clone().request().ttl(short_ttl_ms * 10); // 5 seconds - - // Store key with short TTL - let _ = short_ttl_cache - .clone() - .fetch_one( - "multi_ttl_test", - short_ttl_key, - |mut cache, key| async move { - cache.resolve(&key, "short".to_string()); - Ok(cache) - }, - ) - .await - .unwrap(); - - // Store key with long TTL - let _ = long_ttl_cache - .clone() - .fetch_one( - "multi_ttl_test", - long_ttl_key, - |mut cache, key| async move { - cache.resolve(&key, "long".to_string()); - Ok(cache) - }, - ) - .await - .unwrap(); - - // Verify both values exist immediately - let values = cache - .clone() - .request() - .fetch_all( - "multi_ttl_test", - vec![short_ttl_key, long_ttl_key], - |mut cache: rivet_cache::GetterCtx<&str, String>, keys| async move { - // If not found in cache, we need to return the values - for key in &keys { - if *key == short_ttl_key { - cache.resolve(key, "short".to_string()); - } else if *key == long_ttl_key { - cache.resolve(key, "long".to_string()); - } - } - Ok(cache) - }, - ) - .await - .unwrap(); - - assert_eq!(2, values.len(), "Both values should be available initially"); - - // Wait for short TTL to expire - tokio::time::sleep(Duration::from_millis((short_ttl_ms + 200) as u64)).await; - - // Or manually purge it to ensure test consistency - cache - .clone() - .request() - .purge("multi_ttl_test", [short_ttl_key]) - .await - .unwrap(); - - let short_value = cache - .clone() - .request() - .fetch_one( - "multi_ttl_test", - short_ttl_key, - |cache: rivet_cache::GetterCtx<&str, String>, _| async move { Ok(cache) }, - ) - .await - .unwrap(); - - assert_eq!(None, short_value, "Short TTL value should have expired"); - - // Check values after short TTL expiration - let values = cache - .clone() - .request() - .fetch_all( - "multi_ttl_test", - vec![short_ttl_key, long_ttl_key], - |mut cache: rivet_cache::GetterCtx<&str, String>, keys| async move { - // The short TTL key should have expired, so we regenerate it - // The long TTL key should still be in the cache - for key in &keys { - if *key == short_ttl_key { - cache.resolve(key, "regenerated".to_string()); - } - // For the long key, we still may need to resolve if not found in cache - else if *key == long_ttl_key { - cache.resolve(key, "long".to_string()); - } - } - Ok(cache) - }, - ) - .await - .unwrap(); - - // Convert to a map for easier assertion - let values_map: HashMap<_, _> = values.into_iter().collect(); - - assert_eq!(2, values_map.len(), "Both keys should be in result"); - assert_eq!( - Some(&"regenerated".to_string()), - values_map.get(short_ttl_key), - "Short TTL key should have regenerated value" - ); - assert_eq!( - Some(&"long".to_string()), - values_map.get(long_ttl_key), - "Long TTL key should still have original value" - ); -} - -/// Tests basic rate limiting functionality -async fn test_rate_limit_basic(cache: rivet_cache::Cache) { - // Define a simple cache key for testing - #[derive(Debug, Clone, PartialEq)] - struct TestKey; - - impl rivet_cache::CacheKey for TestKey { - fn cache_key(&self) -> String { - "rate-limit-test".to_string() - } - } - - for _i in 0..5 { - let config = rivet_cache::RateLimitConfig { - key: "test_rate_limit".to_string(), - buckets: vec![rivet_cache::RateLimitBucketConfig { - count: 5, // Allow 5 requests per minute - bucket_duration_ms: 60000, // 1 minute (60,000 ms) - }], - }; - - let result = cache.rate_limit(&TestKey, "127.0.0.1", config).await; - assert_eq!(1, result.len()); - assert!(result[0].is_valid, "Request should be valid"); - } - - // Sixth request should not be valid (exceeds the limit of 5) - let config = rivet_cache::RateLimitConfig { - key: "test_rate_limit".to_string(), - buckets: vec![rivet_cache::RateLimitBucketConfig { - count: 5, // Allow 5 requests per minute - bucket_duration_ms: 60000, // 1 minute (60,000 ms) - }], - }; - - let result = cache.rate_limit(&TestKey, "127.0.0.1", config).await; - assert_eq!(1, result.len()); - assert!(!result[0].is_valid, "Sixth request should not be valid"); -} - -/// Tests that rate limits are properly isolated by IP address -async fn test_rate_limit_ip_isolation(cache: rivet_cache::Cache) { - // Define a simple cache key for testing - #[derive(Debug, Clone, PartialEq)] - struct TestKey; - - impl rivet_cache::CacheKey for TestKey { - fn cache_key(&self) -> String { - "ip-isolation-test".to_string() - } - } - - // IP addresses to test - let ip1 = "192.168.0.1"; - let ip2 = "10.0.0.1"; - - // Make multiple requests from IP1 - for _i in 0..3 { - let config = rivet_cache::RateLimitConfig { - key: "test_ip_isolation".to_string(), - buckets: vec![rivet_cache::RateLimitBucketConfig { - count: 3, // Allow 3 requests per minute - bucket_duration_ms: 60000, // 1 minute (60,000 ms) - }], - }; - - let result = cache.rate_limit(&TestKey, ip1, config).await; - assert_eq!(1, result.len()); - assert!(result[0].is_valid, "Request from IP1 should be valid"); - } - - // Next request from IP1 should exceed the limit - let config = rivet_cache::RateLimitConfig { - key: "test_ip_isolation".to_string(), - buckets: vec![rivet_cache::RateLimitBucketConfig { - count: 3, // Allow 3 requests per minute - bucket_duration_ms: 60000, // 1 minute (60,000 ms) - }], - }; - - let result = cache.rate_limit(&TestKey, ip1, config).await; - assert_eq!(1, result.len()); - assert!( - !result[0].is_valid, - "Fourth request from IP1 should not be valid" - ); - - // Requests from IP2 should still be valid even though IP1 is blocked - for _i in 0..3 { - let config = rivet_cache::RateLimitConfig { - key: "test_ip_isolation".to_string(), - buckets: vec![rivet_cache::RateLimitBucketConfig { - count: 3, // Allow 3 requests per minute - bucket_duration_ms: 60000, // 1 minute (60,000 ms) - }], - }; - - let result = cache.rate_limit(&TestKey, ip2, config).await; - assert_eq!(1, result.len()); - assert!(result[0].is_valid, "Request from IP2 should be valid"); - } - - // Next request from IP2 should exceed the limit - let config = rivet_cache::RateLimitConfig { - key: "test_ip_isolation".to_string(), - buckets: vec![rivet_cache::RateLimitBucketConfig { - count: 3, // Allow 3 requests per minute - bucket_duration_ms: 60000, // 1 minute (60,000 ms) - }], - }; - - let result = cache.rate_limit(&TestKey, ip2, config).await; - assert_eq!(1, result.len()); - assert!( - !result[0].is_valid, - "Fourth request from IP2 should not be valid" - ); - - // IP1 should still be blocked (testing key persistence) - let config = rivet_cache::RateLimitConfig { - key: "test_ip_isolation".to_string(), - buckets: vec![rivet_cache::RateLimitBucketConfig { - count: 3, // Allow 3 requests per minute - bucket_duration_ms: 60000, // 1 minute (60,000 ms) - }], - }; - - let result = cache.rate_limit(&TestKey, ip1, config).await; - assert_eq!(1, result.len()); - assert!( - !result[0].is_valid, - "IP1 should still be blocked after IP2 requests" - ); -} - -#[tokio::test(flavor = "multi_thread")] -async fn in_memory_multiple_keys() { - let cache = build_in_memory_cache().await; - test_multiple_keys(cache).await; -} - -#[tokio::test(flavor = "multi_thread")] -async fn in_memory_smoke_test() { - let cache = build_in_memory_cache().await; - test_smoke_test(cache).await; -} - -#[tokio::test(flavor = "multi_thread")] -async fn in_memory_custom_ttl() { - let cache = build_in_memory_cache().await; - test_custom_ttl(cache).await; -} - -#[tokio::test(flavor = "multi_thread")] -async fn in_memory_default_ttl() { - let cache = build_in_memory_cache().await; - test_default_ttl(cache).await; -} - -#[tokio::test(flavor = "multi_thread")] -async fn in_memory_purge_with_ttl() { - let cache = build_in_memory_cache().await; - test_purge_with_ttl(cache).await; -} - -#[tokio::test(flavor = "multi_thread")] -async fn in_memory_multi_key_ttl() { - let cache = build_in_memory_cache().await; - test_multi_key_ttl(cache).await; -} - -#[tokio::test(flavor = "multi_thread")] -async fn in_memory_rate_limit_basic() { - let cache = build_in_memory_cache().await; - test_rate_limit_basic(cache).await; -} - -#[tokio::test(flavor = "multi_thread")] -async fn in_memory_rate_limit_ip_isolation() { - let cache = build_in_memory_cache().await; - test_rate_limit_ip_isolation(cache).await; -} diff --git a/engine/packages/cache/tests/ttl.rs b/engine/packages/cache/tests/ttl.rs new file mode 100644 index 0000000000..b1bc3438f1 --- /dev/null +++ b/engine/packages/cache/tests/ttl.rs @@ -0,0 +1,314 @@ +use std::{collections::HashMap, time::Duration}; + +fn build_cache() -> rivet_cache::Cache { + rivet_cache::CacheInner::new_in_memory(1000, None) +} + +/// Tests that a custom TTL is properly respected when setting and accessing items +#[tokio::test(flavor = "multi_thread")] +async fn custom_ttl() { + let cache = build_cache(); + let test_key = "ttl-test-key"; + let test_value = "test-value"; + let short_ttl_ms = 500i64; // 500ms TTL + + // Store with a custom short TTL + let _ = cache + .clone() + .request() + .ttl(short_ttl_ms) + .fetch_one_json("ttl_test", test_key, |mut cache, key| async move { + cache.resolve(&key, test_value.to_string()); + Ok(cache) + }) + .await + .unwrap(); + + // Verify value exists immediately after storing + let value = cache + .clone() + .request() + .fetch_one_json( + "ttl_test", + test_key, + |mut cache: rivet_cache::GetterCtx<&str, String>, key| async move { + // If not found in cache, we need to return the same value + cache.resolve(&key, test_value.to_string()); + Ok(cache) + }, + ) + .await + .unwrap(); + assert_eq!( + Some(test_value.to_string()), + value, + "value should be available before TTL expiration" + ); + + // Wait for the TTL to expire - use a longer wait to ensure it expires + tokio::time::sleep(Duration::from_millis((short_ttl_ms * 3) as u64)).await; + + // Since we want to test value expiration, manually purge for consistency across implementations + cache + .clone() + .request() + .purge("ttl_test", [test_key]) + .await + .unwrap(); + + // Verify value no longer exists after TTL expiration + let value = cache + .clone() + .request() + .fetch_one_json( + "ttl_test", + test_key, + |cache: rivet_cache::GetterCtx<&str, String>, _| async move { + // Don't resolve anything - we want to verify the key is gone + Ok(cache) + }, + ) + .await + .unwrap(); + assert_eq!( + None, value, + "value should not be available after TTL expiration" + ); +} + +/// Tests that default TTL is applied correctly when not explicitly specified +#[tokio::test(flavor = "multi_thread")] +async fn default_ttl() { + let cache = build_cache(); + let test_key = "default-ttl-key"; + let test_value = "default-value"; + + // Store with default TTL (should use 2 hours) + let _ = cache + .clone() + .request() + .fetch_one_json("default_ttl_test", test_key, |mut cache, key| async move { + cache.resolve(&key, test_value.to_string()); + Ok(cache) + }) + .await + .unwrap(); + + // Verify value exists after storing + let value = cache + .clone() + .request() + .fetch_one_json( + "default_ttl_test", + test_key, + |mut cache: rivet_cache::GetterCtx<&str, String>, key| async move { + // If not found in cache, we need to return the same value + cache.resolve(&key, test_value.to_string()); + Ok(cache) + }, + ) + .await + .unwrap(); + assert_eq!( + Some(test_value.to_string()), + value, + "value should be available with default TTL" + ); +} + +/// Tests that purging a key removes it regardless of TTL +#[tokio::test(flavor = "multi_thread")] +async fn purge_with_ttl() { + let cache = build_cache(); + let test_key = "purge-key"; + let test_value = "purge-value"; + let long_ttl_ms = 3600000i64; // 1 hour TTL + + // Store with a long TTL + let _ = cache + .clone() + .request() + .ttl(long_ttl_ms) + .fetch_one_json("purge_test", test_key, |mut cache, key| async move { + cache.resolve(&key, test_value.to_string()); + Ok(cache) + }) + .await + .unwrap(); + + // Verify value exists after storing + let value = cache + .clone() + .request() + .fetch_one_json( + "purge_test", + test_key, + |mut cache: rivet_cache::GetterCtx<&str, String>, key| async move { + // If not found in cache, we need to return the same value + cache.resolve(&key, test_value.to_string()); + Ok(cache) + }, + ) + .await + .unwrap(); + assert_eq!( + Some(test_value.to_string()), + value, + "value should be available after storing" + ); + + // Purge the key + cache + .clone() + .request() + .purge("purge_test", [test_key]) + .await + .unwrap(); + + // Verify value no longer exists after purging + let value = cache + .clone() + .request() + .fetch_one_json( + "purge_test", + test_key, + |cache: rivet_cache::GetterCtx<&str, String>, _| async move { Ok(cache) }, + ) + .await + .unwrap(); + assert_eq!(None, value, "value should not be available after purging"); +} + +/// Tests multiple TTLs for different keys in the same batch +#[tokio::test(flavor = "multi_thread")] +async fn multi_key_ttl() { + let cache = build_cache(); + let short_ttl_key = "short-ttl"; + let long_ttl_key = "long-ttl"; + let short_ttl_ms = 500i64; // 500ms TTL + + // First, purge any existing keys to ensure clean state + cache + .clone() + .request() + .purge("multi_ttl_test", [short_ttl_key, long_ttl_key]) + .await + .unwrap(); + + // Store key with short TTL + let _ = cache + .clone() + .request() + .ttl(short_ttl_ms) + .fetch_one_json( + "multi_ttl_test", + short_ttl_key, + |mut cache, key| async move { + cache.resolve(&key, "short".to_string()); + Ok(cache) + }, + ) + .await + .unwrap(); + + // Store key with long TTL + let _ = cache + .clone() + .request() + .ttl(short_ttl_ms * 10) // 5 seconds + .fetch_one_json( + "multi_ttl_test", + long_ttl_key, + |mut cache, key| async move { + cache.resolve(&key, "long".to_string()); + Ok(cache) + }, + ) + .await + .unwrap(); + + // Verify both values exist immediately + let values = cache + .clone() + .request() + .fetch_all_json_with_keys( + "multi_ttl_test", + vec![short_ttl_key, long_ttl_key], + |mut cache: rivet_cache::GetterCtx<&str, String>, keys| async move { + // If not found in cache, we need to return the values + for key in &keys { + if *key == short_ttl_key { + cache.resolve(key, "short".to_string()); + } else if *key == long_ttl_key { + cache.resolve(key, "long".to_string()); + } + } + Ok(cache) + }, + ) + .await + .unwrap(); + assert_eq!(2, values.len(), "both values should be available initially"); + + // Wait for short TTL to expire + tokio::time::sleep(Duration::from_millis((short_ttl_ms + 200) as u64)).await; + + // Or manually purge it to ensure test consistency + cache + .clone() + .request() + .purge("multi_ttl_test", [short_ttl_key]) + .await + .unwrap(); + + let short_value = cache + .clone() + .request() + .fetch_one_json( + "multi_ttl_test", + short_ttl_key, + |cache: rivet_cache::GetterCtx<&str, String>, _| async move { Ok(cache) }, + ) + .await + .unwrap(); + assert_eq!(None, short_value, "short TTL value should have expired"); + + // Check values after short TTL expiration + let values = cache + .clone() + .request() + .fetch_all_json_with_keys( + "multi_ttl_test", + vec![short_ttl_key, long_ttl_key], + |mut cache: rivet_cache::GetterCtx<&str, String>, keys| async move { + // The short TTL key should have expired, so we regenerate it. + // The long TTL key should still be in the cache. + for key in &keys { + if *key == short_ttl_key { + cache.resolve(key, "regenerated".to_string()); + } else if *key == long_ttl_key { + // For the long key, we still may need to resolve if not found in cache. + cache.resolve(key, "long".to_string()); + } + } + Ok(cache) + }, + ) + .await + .unwrap(); + + // Convert to a map for easier assertion + let values_map: HashMap<_, _> = values.into_iter().collect(); + + assert_eq!(2, values_map.len(), "both keys should be in result"); + assert_eq!( + Some(&"regenerated".to_string()), + values_map.get(short_ttl_key), + "short TTL key should have regenerated value" + ); + assert_eq!( + Some(&"long".to_string()), + values_map.get(long_ttl_key), + "long TTL key should still have original value" + ); +}