Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion engine/artifacts/config-schema.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

174 changes: 28 additions & 146 deletions engine/packages/cache/src/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,14 @@ use std::{
};

use moka::future::{Cache, CacheBuilder};
use serde::{Serialize, de::DeserializeOwned};
use tracing::Instrument;

use crate::{errors::Error, metrics};
use crate::{RawCacheKey, errors::Error};

/// Type alias for cache values stored as bytes
pub type CacheValue = Vec<u8>;

/// Enum wrapper for different cache driver implementations
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum Driver {
InMemory(InMemoryDriver),
Expand All @@ -22,84 +20,49 @@ pub enum Driver {
impl Driver {
/// Fetch multiple values from cache at once
#[tracing::instrument(skip_all, fields(driver=%self))]
pub async fn fetch_values<'a>(
pub async fn get<'a>(
&'a self,
base_key: &'a str,
keys: &[String],
keys: Vec<RawCacheKey>,
) -> Result<Vec<Option<CacheValue>>, Error> {
match self {
Driver::InMemory(d) => d.fetch_values(base_key, keys).await,
Driver::InMemory(d) => d.get(base_key, keys).await,
}
}

/// Set multiple values in cache at once
#[tracing::instrument(skip_all, fields(driver=%self))]
pub async fn set_values<'a>(
pub async fn set<'a>(
&'a self,
base_key: &'a str,
keys_values: Vec<(String, CacheValue, i64)>,
keys_values: Vec<(RawCacheKey, CacheValue, i64)>,
) -> Result<(), Error> {
match self {
Driver::InMemory(d) => d.set_values(base_key, keys_values).await,
Driver::InMemory(d) => d.set(base_key, keys_values).await,
}
}

/// Delete multiple keys from cache
#[tracing::instrument(skip_all, fields(driver=%self))]
pub async fn delete_keys<'a>(
pub async fn delete<'a>(
&'a self,
base_key: &'a str,
keys: Vec<String>,
keys: Vec<RawCacheKey>,
) -> Result<(), Error> {
match self {
Driver::InMemory(d) => d.delete_keys(base_key, keys).await,
Driver::InMemory(d) => d.delete(base_key, keys).await,
}
}

/// Process a raw key into a driver-specific format
///
/// Different implementations use different key formats:
/// - In-memory uses simpler keys
pub fn process_key(&self, base_key: &str, key: &impl crate::CacheKey) -> String {
pub fn process_key(&self, base_key: &str, key: &impl crate::CacheKey) -> RawCacheKey {
match self {
Driver::InMemory(d) => d.process_key(base_key, key),
}
}

/// Process a rate limit key into a driver-specific format
pub fn process_rate_limit_key(
&self,
key: &impl crate::CacheKey,
remote_address: impl AsRef<str>,
bucket: i64,
bucket_duration_ms: i64,
) -> String {
match self {
Driver::InMemory(d) => {
d.process_rate_limit_key(key, remote_address, bucket, bucket_duration_ms)
}
}
}

/// Increment a rate limit counter and return the new count
#[tracing::instrument(skip_all, fields(driver=%self))]
pub async fn rate_limit_increment<'a>(
&'a self,
key: &'a str,
ttl_ms: i64,
) -> Result<i64, Error> {
match self {
Driver::InMemory(d) => d.rate_limit_increment(key, ttl_ms).await,
}
}

pub fn encode_value<T: Serialize>(&self, value: &T) -> Result<CacheValue, Error> {
serde_json::to_vec(value).map_err(Error::SerdeEncode)
}

pub fn decode_value<T: DeserializeOwned>(&self, bytes: &[u8]) -> Result<T, Error> {
serde_json::from_slice(bytes).map_err(Error::SerdeDecode)
}
}

impl std::fmt::Display for Driver {
Expand All @@ -120,7 +83,7 @@ struct ExpiringValue {
}

/// Cache expiry implementation for Moka
#[derive(Clone, Debug)]
#[derive(Debug)]
struct ValueExpiry;

impl moka::Expiry<String, ExpiringValue> for ValueExpiry {
Expand Down Expand Up @@ -170,55 +133,37 @@ impl moka::Expiry<String, ExpiringValue> for ValueExpiry {
}

/// In-memory cache driver implementation using the moka crate
#[derive(Clone)]
pub struct InMemoryDriver {
service_name: String,
cache: Cache<String, ExpiringValue>,
/// In-memory rate limiting store - maps keys to hit counts with expiration
rate_limits: Cache<String, ExpiringValue>,
}

impl Debug for InMemoryDriver {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("InMemoryDriver")
.field("service_name", &self.service_name)
.finish()
f.debug_struct("InMemoryDriver").finish()
}
}

impl InMemoryDriver {
pub fn new(service_name: String, max_capacity: u64) -> Self {
pub fn new(max_capacity: u64) -> Self {
// Create a cache with ValueExpiry implementation for custom expiration times
let cache = CacheBuilder::new(max_capacity)
.expire_after(ValueExpiry)
.build();

// Create a separate cache for rate limiting with the same expiry mechanism
let rate_limits = CacheBuilder::new(max_capacity)
.expire_after(ValueExpiry)
.build();

Self {
service_name,
cache,
rate_limits,
}
Self { cache }
}

pub async fn fetch_values<'a>(
pub async fn get<'a>(
&'a self,
_base_key: &'a str,
keys: &[String],
keys: Vec<RawCacheKey>,
) -> Result<Vec<Option<CacheValue>>, Error> {
let keys = keys.to_vec();
let cache = self.cache.clone();

let mut result = Vec::with_capacity(keys.len());

// Async block for metrics
async {
for key in keys {
result.push(cache.get(&key).await.map(|x| x.value.clone()));
result.push(self.cache.get(&*key).await.map(|x| x.value.clone()));
}
}
.instrument(tracing::info_span!("get"))
Expand All @@ -233,13 +178,11 @@ impl InMemoryDriver {
Ok(result)
}

pub async fn set_values<'a>(
pub async fn set<'a>(
&'a self,
_base_key: &'a str,
keys_values: Vec<(String, CacheValue, i64)>,
keys_values: Vec<(RawCacheKey, CacheValue, i64)>,
) -> Result<(), Error> {
let cache = self.cache.clone();

// Async block for metrics
async {
for (key, value, expire_at) in keys_values {
Expand All @@ -250,7 +193,7 @@ impl InMemoryDriver {
};

// Store in cache - expiry will be handled by ValueExpiry
cache.insert(key, entry).await;
self.cache.insert(key.into(), entry).await;
}
}
.instrument(tracing::info_span!("set"))
Expand All @@ -260,87 +203,26 @@ impl InMemoryDriver {
Ok(())
}

pub async fn delete_keys<'a>(
pub async fn delete<'a>(
&'a self,
base_key: &'a str,
keys: Vec<String>,
_base_key: &'a str,
keys: Vec<RawCacheKey>,
) -> Result<(), Error> {
let cache = self.cache.clone();

metrics::CACHE_PURGE_REQUEST_TOTAL
.with_label_values(&[base_key])
.inc();
metrics::CACHE_PURGE_VALUE_TOTAL
.with_label_values(&[base_key])
.inc_by(keys.len() as u64);

// Async block for metrics
async {
for key in keys {
// Use remove instead of invalidate to ensure it's actually removed
cache.remove(&key).await;
self.cache.remove(&*key).await;
}
}
.instrument(tracing::info_span!("remove"))
.instrument(tracing::info_span!("delete"))
.await;

tracing::trace!("successfully deleted keys from in-memory cache");
Ok(())
}

pub fn process_key(&self, base_key: &str, key: &impl crate::CacheKey) -> String {
format!("{}:{}", base_key, key.cache_key())
}

pub fn process_rate_limit_key(
&self,
key: &impl crate::CacheKey,
remote_address: impl AsRef<str>,
bucket: i64,
bucket_duration_ms: i64,
) -> String {
format!(
"rate_limit:{}:{}:{}:{}",
key.cache_key(),
remote_address.as_ref(),
bucket_duration_ms,
bucket,
)
}

/// Increment a rate limit counter for in-memory storage
pub async fn rate_limit_increment<'a>(
&'a self,
key: &'a str,
ttl_ms: i64,
) -> Result<i64, Error> {
let rate_limits = self.rate_limits.clone();

// Get current value or default to 0
let current_value = match rate_limits.get(key).await {
Some(value) => {
// Try to decode the value as an integer
match serde_json::from_slice::<i64>(&value.value).map_err(Error::SerdeDecode) {
Ok(count) => count,
Err(_) => 0, // If we can't decode, reset to 0
}
}
None => 0,
};

// Increment the value
let new_value = current_value + 1;
let encoded = serde_json::to_vec(&new_value).map_err(Error::SerdeEncode)?;

// Store with expiration
let entry = ExpiringValue {
value: encoded,
expiry_time: rivet_util::timestamp::now() + ttl_ms,
};

// Update the rate limit cache
rate_limits.insert(key.to_string(), entry).await;

Ok(new_value)
pub fn process_key(&self, base_key: &str, key: &impl crate::CacheKey) -> RawCacheKey {
RawCacheKey::from(format!("{}:{}", base_key, key.cache_key()))
}
}
25 changes: 5 additions & 20 deletions engine/packages/cache/src/inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,13 @@ pub type Cache = Arc<CacheInner>;

/// Utility type used to hold information relating to caching.
pub struct CacheInner {
service_name: String,
pub(crate) driver: Driver,
pub(crate) ups: Option<universalpubsub::PubSub>,
}

impl Debug for CacheInner {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CacheInner")
.field("service_name", &self.service_name)
.finish()
f.debug_struct("CacheInner").finish()
}
}

Expand All @@ -26,29 +23,17 @@ impl CacheInner {
config: &rivet_config::Config,
pools: rivet_pools::Pools,
) -> Result<Cache, Error> {
let service_name = rivet_env::service_name();
let ups = pools.ups().ok();

match &config.cache().driver {
rivet_config::config::CacheDriver::Redis => todo!(),
rivet_config::config::CacheDriver::InMemory => {
Ok(Self::new_in_memory(service_name.to_string(), 1000, ups))
}
rivet_config::config::CacheDriver::InMemory => Ok(Self::new_in_memory(10000, ups)),
}
}

#[tracing::instrument(skip(ups))]
pub fn new_in_memory(
service_name: String,
max_capacity: u64,
ups: Option<universalpubsub::PubSub>,
) -> Cache {
let driver = Driver::InMemory(InMemoryDriver::new(service_name.clone(), max_capacity));
Arc::new(CacheInner {
service_name,
driver,
ups,
})
pub fn new_in_memory(max_capacity: u64, ups: Option<universalpubsub::PubSub>) -> Cache {
let driver = Driver::InMemory(InMemoryDriver::new(max_capacity));
Arc::new(CacheInner { driver, ups })
}
}

Expand Down
16 changes: 15 additions & 1 deletion engine/packages/cache/src/key.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
use std::{fmt::Debug, ops::Deref};

/// A type that can be serialized in to a key that can be used in the cache.
pub trait CacheKey: Clone + Debug + PartialEq {
Expand Down Expand Up @@ -99,3 +99,17 @@ impl From<String> for RawCacheKey {
RawCacheKey(value)
}
}

impl From<RawCacheKey> for String {
fn from(value: RawCacheKey) -> Self {
value.0
}
}

impl Deref for RawCacheKey {
type Target = String;

fn deref(&self) -> &Self::Target {
&self.0
}
}
Loading
Loading