Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

No-op LRU impl for non-lru tracked functions #664

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions components/salsa-macro-rules/src/setup_tracked_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ macro_rules! setup_tracked_fn {
needs_interner: $needs_interner:tt,

// LRU capacity (a literal, maybe 0)
lru: $lru:tt,
lru_capacity: $lru_capacity:tt,
has_lru: $has_lru:tt,

// True if we `return_ref` flag was given to the function
return_ref: $return_ref:tt,
Expand Down Expand Up @@ -156,6 +157,14 @@ macro_rules! setup_tracked_fn {

type Output<$db_lt> = $output_ty;

type Lru = $zalsa::macro_if! {
if $has_lru {
$zalsa::lru::Lru
} else {
$zalsa::lru::NoLru
}
};

const CYCLE_STRATEGY: $zalsa::CycleRecoveryStrategy = $zalsa::CycleRecoveryStrategy::$cycle_recovery_strategy;

fn should_backdate_value(
Expand Down Expand Up @@ -218,7 +227,12 @@ macro_rules! setup_tracked_fn {
first_index,
aux,
);
fn_ingredient.set_capacity($lru);
$zalsa::macro_if! {
if $has_lru {
fn_ingredient.set_capacity($lru_capacity);
} else {
}
}
$zalsa::macro_if! {
if $needs_interner {
vec![
Expand Down Expand Up @@ -273,12 +287,12 @@ macro_rules! setup_tracked_fn {
}
}

$zalsa::macro_if! { if0 $lru { } else {
$zalsa::macro_if! { if $has_lru {
#[allow(dead_code)]
fn set_lru_capacity(db: &dyn $Db, value: usize) {
$Configuration::fn_ingredient(db).set_capacity(value);
}
} }
} else {} }
}

$zalsa::attach($db, || {
Expand Down
9 changes: 7 additions & 2 deletions components/salsa-macros/src/tracked_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,11 @@ impl Macro {
FunctionType::SalsaStruct => false,
};

let lru = Literal::usize_unsuffixed(self.args.lru.unwrap_or(0));
let (has_lru, lru_capacity) = match self.args.lru {
Some(cap) => (true, cap),
None => (false, 0),
};
let lru_capacity = Literal::usize_unsuffixed(lru_capacity);

let return_ref: bool = self.args.return_ref.is_some();

Expand All @@ -131,7 +135,8 @@ impl Macro {
is_specifiable: #is_specifiable,
no_eq: #no_eq,
needs_interner: #needs_interner,
lru: #lru,
lru_capacity: #lru_capacity,
has_lru: #has_lru,
return_ref: #return_ref,
unused_names: [
#zalsa,
Expand Down
3 changes: 2 additions & 1 deletion src/accumulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::{
cycle::CycleRecoveryStrategy,
ingredient::{fmt_index, Ingredient, Jar, MaybeChangedAfter},
plumbing::JarAux,
table::memo::MemoTable,
zalsa::IngredientIndex,
zalsa_local::QueryOrigin,
Database, DatabaseKeyIndex, Id, Revision,
Expand Down Expand Up @@ -137,7 +138,7 @@ impl<A: Accumulator> Ingredient for IngredientImpl<A> {
false
}

fn reset_for_new_revision(&mut self) {
fn reset_for_new_revision<'a>(&mut self, _: &'a dyn Fn(crate::Id) -> &'a MemoTable) {
panic!("unexpected reset on accumulator")
}

Expand Down
25 changes: 18 additions & 7 deletions src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ use std::{any::Any, fmt, sync::Arc};
use crate::{
accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues},
cycle::CycleRecoveryStrategy,
function::lru::LruChoice,
ingredient::{fmt_index, MaybeChangedAfter},
key::DatabaseKeyIndex,
plumbing::JarAux,
salsa_struct::SalsaStructInDb,
table::memo::MemoTable,
zalsa::{IngredientIndex, MemoIngredientIndex, Zalsa},
zalsa_local::QueryOrigin,
Cycle, Database, Id, Revision,
Expand All @@ -23,7 +25,7 @@ mod diff_outputs;
mod execute;
mod fetch;
mod inputs;
mod lru;
pub mod lru;
mod maybe_changed_after;
mod memo;
mod specify;
Expand All @@ -45,6 +47,9 @@ pub trait Configuration: Any {
/// The value computed by the function.
type Output<'db>: fmt::Debug + Send + Sync;

/// The singleton state for this input if any.
type Lru: LruChoice + Send + Sync;

/// Determines whether this function can recover from being a participant in a cycle
/// (and, if so, how).
const CYCLE_STRATEGY: CycleRecoveryStrategy;
Expand Down Expand Up @@ -98,7 +103,7 @@ pub struct IngredientImpl<C: Configuration> {
memo_ingredient_index: MemoIngredientIndex,

/// Used to find memos to throw out when we have too many memoized values.
lru: lru::Lru,
lru: C::Lru,

/// When `fetch` and friends executes, they return a reference to the
/// value stored in the memo that is extended to live as long as the `&self`
Expand Down Expand Up @@ -155,17 +160,17 @@ where
/// only cleared with `&mut self`.
unsafe fn extend_memo_lifetime<'this>(
&'this self,
memo: &memo::Memo<C::Output<'this>>,
) -> &'this memo::Memo<C::Output<'this>> {
memo: &memo::MemoConfigured<'this, C>,
) -> &'this memo::MemoConfigured<'this, C> {
std::mem::transmute(memo)
}

fn insert_memo<'db>(
&'db self,
zalsa: &'db Zalsa,
id: Id,
memo: memo::Memo<C::Output<'db>>,
) -> &'db memo::Memo<C::Output<'db>> {
memo: memo::MemoConfigured<'db, C>,
) -> &'db memo::MemoConfigured<'db, C> {
let memo = Arc::new(memo);
let db_memo = unsafe {
// Unsafety conditions: memo must be in the map (it's not yet, but it will be by the time this
Expand All @@ -184,6 +189,7 @@ where
impl<C> Ingredient for IngredientImpl<C>
where
C: Configuration,
for<'lt> <C::Lru as LruChoice>::LruCtor<<C as Configuration>::Output<'lt>>: Send + Sync,
{
fn ingredient_index(&self) -> IngredientIndex {
self.index
Expand Down Expand Up @@ -231,8 +237,13 @@ where
true
}

fn reset_for_new_revision(&mut self) {
fn reset_for_new_revision<'a>(
&mut self,
memo_table_for: &'a dyn Fn(crate::Id) -> &'a MemoTable,
) {
std::mem::take(&mut self.deleted_entries);
self.lru
.to_be_evicted(|evict| self.evict_value_from_memo_for(memo_table_for(evict)));
}

fn fmt_index(&self, index: Option<crate::Id>, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
Expand Down
10 changes: 5 additions & 5 deletions src/function/backdate.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::zalsa_local::QueryRevisions;
use crate::{function::memo::MemoConfigured, zalsa_local::QueryRevisions};

use super::{memo::Memo, Configuration, IngredientImpl};
use super::{Configuration, IngredientImpl, LruChoice};

impl<C> IngredientImpl<C>
where
Expand All @@ -11,11 +11,11 @@ where
/// on an old memo when a new memo has been produced to check whether there have been changed.
pub(super) fn backdate_if_appropriate(
&self,
old_memo: &Memo<C::Output<'_>>,
old_memo: &MemoConfigured<'_, C>,
revisions: &mut QueryRevisions,
value: &C::Output<'_>,
) {
if let Some(old_value) = &old_memo.value {
C::Lru::with_value(&old_memo.value, |old_value| {
// Careful: if the value became less durable than it
// used to be, that is a "breaking change" that our
// consumers must be aware of. Becoming *more* durable
Expand All @@ -31,6 +31,6 @@ where
assert!(old_memo.revisions.changed_at <= revisions.changed_at);
revisions.changed_at = old_memo.revisions.changed_at;
}
}
})
}
}
8 changes: 4 additions & 4 deletions src/function/diff_outputs.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::{memo::Memo, Configuration, IngredientImpl};
use super::{Configuration, IngredientImpl};
use crate::{
hash::FxHashSet, key::OutputDependencyIndex, zalsa_local::QueryRevisions, AsDynDatabase as _,
DatabaseKeyIndex, Event, EventKind,
function::memo::MemoConfigured, hash::FxHashSet, key::OutputDependencyIndex,
zalsa_local::QueryRevisions, AsDynDatabase as _, DatabaseKeyIndex, Event, EventKind,
};

impl<C> IngredientImpl<C>
Expand All @@ -17,7 +17,7 @@ where
&self,
db: &C::DbView,
key: DatabaseKeyIndex,
old_memo: &Memo<C::Output<'_>>,
old_memo: &MemoConfigured<'_, C>,
revisions: &mut QueryRevisions,
) {
// Iterate over the outputs of the `old_memo` and put them into a hashset
Expand Down
15 changes: 10 additions & 5 deletions src/function/execute.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use std::sync::Arc;

use crate::{
zalsa::ZalsaDatabase, zalsa_local::ActiveQueryGuard, Cycle, Database, Event, EventKind,
function::memo::MemoConfigured, zalsa::ZalsaDatabase, zalsa_local::ActiveQueryGuard, Cycle,
Database, Event, EventKind,
};

use super::{memo::Memo, Configuration, IngredientImpl};
use super::{memo::Memo, Configuration, IngredientImpl, LruChoice};

impl<C> IngredientImpl<C>
where
Expand All @@ -23,8 +24,8 @@ where
&'db self,
db: &'db C::DbView,
active_query: ActiveQueryGuard<'_>,
opt_old_memo: Option<Arc<Memo<C::Output<'_>>>>,
) -> &'db Memo<C::Output<'db>> {
opt_old_memo: Option<Arc<MemoConfigured<'_, C>>>,
) -> &'db MemoConfigured<'db, C> {
let zalsa = db.zalsa();
let revision_now = zalsa.current_revision();
let database_key_index = active_query.database_key_index;
Expand Down Expand Up @@ -84,6 +85,10 @@ where

tracing::debug!("{database_key_index:?}: read_upgrade: result.revisions = {revisions:#?}");

self.insert_memo(zalsa, id, Memo::new(Some(value), revision_now, revisions))
self.insert_memo(
zalsa,
id,
Memo::new(C::Lru::make_value(value), revision_now, revisions),
)
}
}
37 changes: 24 additions & 13 deletions src/function/fetch.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,28 @@
use super::{memo::Memo, Configuration, IngredientImpl};
use super::{Configuration, IngredientImpl};
use crate::{
accumulator::accumulated_map::InputAccumulatedValues, runtime::StampedValue,
zalsa::ZalsaDatabase, AsDynDatabase as _, Id,
accumulator::accumulated_map::InputAccumulatedValues, function::lru::LruChoice as _,
function::memo::MemoConfigured, runtime::StampedValue, zalsa::ZalsaDatabase,
AsDynDatabase as _, Id,
};

impl<C> IngredientImpl<C>
where
C: Configuration,
{
pub fn fetch<'db>(&'db self, db: &'db C::DbView, id: Id) -> &'db C::Output<'db> {
let (zalsa, zalsa_local) = db.zalsas();
let zalsa_local = db.zalsa_local();
zalsa_local.unwind_if_revision_cancelled(db.as_dyn_database());

let memo = self.refresh_memo(db, id);
let StampedValue {
value,
durability,
changed_at,
} = memo.revisions.stamped_value(memo.value.as_ref().unwrap());
} = memo
.revisions
.stamped_value(C::Lru::assert_ref(&memo.value));

if let Some(evicted) = self.lru.record_use(id) {
self.evict_value_from_memo_for(zalsa, evicted);
}
self.lru.record_use(id);

zalsa_local.report_tracked_read(
self.database_key_index(id).into(),
Expand All @@ -41,7 +42,7 @@ where
&'db self,
db: &'db C::DbView,
id: Id,
) -> &'db Memo<C::Output<'db>> {
) -> &'db MemoConfigured<'db, C> {
loop {
if let Some(memo) = self.fetch_hot(db, id).or_else(|| self.fetch_cold(db, id)) {
return memo;
Expand All @@ -50,11 +51,15 @@ where
}

#[inline]
fn fetch_hot<'db>(&'db self, db: &'db C::DbView, id: Id) -> Option<&'db Memo<C::Output<'db>>> {
fn fetch_hot<'db>(
&'db self,
db: &'db C::DbView,
id: Id,
) -> Option<&'db MemoConfigured<'db, C>> {
let zalsa = db.zalsa();
let memo_guard = self.get_memo_from_table_for(zalsa, id);
if let Some(memo) = &memo_guard {
if memo.value.is_some()
if !C::Lru::is_evicted(&memo.value)
&& self.shallow_verify_memo(db, zalsa, self.database_key_index(id), memo)
{
// Unsafety invariant: memo is present in memo_map and we have verified that it is
Expand All @@ -65,7 +70,11 @@ where
None
}

fn fetch_cold<'db>(&'db self, db: &'db C::DbView, id: Id) -> Option<&'db Memo<C::Output<'db>>> {
fn fetch_cold<'db>(
&'db self,
db: &'db C::DbView,
id: Id,
) -> Option<&'db MemoConfigured<'db, C>> {
let (zalsa, zalsa_local) = db.zalsas();
let database_key_index = self.database_key_index(id);

Expand All @@ -84,7 +93,9 @@ where
let zalsa = db.zalsa();
let opt_old_memo = self.get_memo_from_table_for(zalsa, id);
if let Some(old_memo) = &opt_old_memo {
if old_memo.value.is_some() && self.deep_verify_memo(db, old_memo, &active_query) {
if !C::Lru::is_evicted(&old_memo.value)
&& self.deep_verify_memo(db, old_memo, &active_query)
{
// Unsafety invariant: memo is present in memo_map and we have verified that it is
// still valid for the current revision.
return unsafe { Some(self.extend_memo_lifetime(old_memo)) };
Expand Down
Loading
Loading