Skip to content

Commit 56cad93

Browse files
Merge pull request #547 from theseus-rs/refactor-callsitecache
refactor: update CallSiteCache to use a DashMap instead of RwLock<HashMap>
2 parents 1045cbe + 5332575 commit 56cad93

File tree

5 files changed

+73
-69
lines changed

5 files changed

+73
-69
lines changed

Cargo.lock

Lines changed: 23 additions & 11 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ test-log = "0.2.18"
6868
thiserror = "2.0.12"
6969
thread-priority = "2.1.0"
7070
tokio = { version = "1.47.1", default-features = false, features = ["macros", "rt", "sync"] }
71-
tracing = { version = "0.1.41", default-features = false, features = ["release_max_level_info", "std"] }
71+
tracing = "0.1.41"
7272
tracing-subscriber = "0.3.19"
7373
walkdir = "2.5.0"
7474
whoami = "1.6.0"

ristretto_cli/src/logging.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ pub(crate) fn initialize() -> Result<()> {
1313

1414
let format = tracing_subscriber::fmt::format()
1515
.with_level(true)
16+
.with_target(false)
17+
.with_thread_ids(false)
1618
.with_thread_names(true)
1719
.with_timer(fmt::time::uptime())
1820
.compact();

ristretto_vm/src/call_site_cache.rs

Lines changed: 25 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@
33
//! This module provides a thread-safe cache that tracks the resolution state of invokedynamic call
44
//! sites.
55
6-
use crate::Error::{InternalError, PoisonedLock};
6+
use crate::Error::InternalError;
77
use crate::Result;
8+
use dashmap::DashMap;
89
use ristretto_classloader::Value;
9-
use std::collections::HashMap;
10-
use std::sync::RwLock;
1110

1211
/// Unique identifier for an invokedynamic call site
1312
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
@@ -41,14 +40,14 @@ pub enum CallSiteState {
4140
#[derive(Debug)]
4241
pub struct CallSiteCache {
4342
/// Maps call site keys to their resolution states
44-
states: RwLock<HashMap<CallSiteKey, CallSiteState>>,
43+
states: DashMap<CallSiteKey, CallSiteState>,
4544
}
4645

4746
impl CallSiteCache {
4847
/// Create a new empty call site cache
4948
pub fn new() -> Self {
5049
Self {
51-
states: RwLock::new(HashMap::new()),
50+
states: DashMap::new(),
5251
}
5352
}
5453

@@ -83,40 +82,29 @@ impl CallSiteCache {
8382
debug!("CallSiteCache: Checking cache for key: {key:?}");
8483

8584
// Check current state
86-
{
87-
let map = self
88-
.states
89-
.read()
90-
.map_err(|error| PoisonedLock(format!("Failed to acquire cache lock: {error}")))?;
91-
92-
match map.get(&key) {
93-
Some(CallSiteState::InProgress) => {
94-
debug!("CallSiteCache: RECURSION DETECTED for key: {key:?}",);
85+
match self.states.get(&key) {
86+
Some(ref state) => match &**state {
87+
CallSiteState::InProgress => {
88+
debug!("CallSiteCache: RECURSION DETECTED for key: {key:?}");
9589
return Err(InternalError(format!(
9690
"Recursive invokedynamic call site resolution detected for class '{}' at index {}",
9791
key.class_name, key.instruction_index
9892
)));
9993
}
100-
Some(CallSiteState::Resolved(value)) => {
94+
CallSiteState::Resolved(value) => {
10195
debug!("CallSiteCache: Returning cached result for key: {key:?}");
10296
return Ok(value.clone());
10397
}
104-
None => {
105-
debug!("CallSiteCache: Key not found in cache, will resolve: {key:?}");
106-
// Call site not yet resolved, continue to resolution
107-
}
98+
},
99+
None => {
100+
debug!("CallSiteCache: Key not found in cache, will resolve: {key:?}");
101+
// Call site not yet resolved, continue to resolution
108102
}
109103
}
110104

111105
// Mark as in progress
112-
{
113-
let mut map = self
114-
.states
115-
.write()
116-
.map_err(|error| PoisonedLock(format!("Failed to acquire cache lock: {error}")))?;
117-
debug!("CallSiteCache: Marking as InProgress: {key:?}");
118-
map.insert(key.clone(), CallSiteState::InProgress);
119-
}
106+
debug!("CallSiteCache: Marking as InProgress: {key:?}");
107+
self.states.insert(key.clone(), CallSiteState::InProgress);
120108

121109
// Perform resolution
122110
debug!("CallSiteCache: Starting resolution for key: {key:?}");
@@ -129,51 +117,32 @@ impl CallSiteCache {
129117
// Update cache based on result
130118
if let Ok(value) = &result {
131119
// Store successful resolution
132-
let mut map = self
133-
.states
134-
.write()
135-
.map_err(|error| PoisonedLock(format!("Failed to acquire cache lock: {error}")))?;
136120
debug!("CallSiteCache: Caching successful result for key: {key:?}",);
137-
map.insert(key, CallSiteState::Resolved(value.clone()));
121+
self.states
122+
.insert(key, CallSiteState::Resolved(value.clone()));
138123
} else {
139124
// Remove in-progress marker on failure to allow retry
140-
let mut map = self
141-
.states
142-
.write()
143-
.map_err(|error| PoisonedLock(format!("Failed to acquire cache lock: {error}")))?;
144125
debug!("CallSiteCache: Removing failed resolution from cache for key: {key:?}");
145-
map.remove(&key);
126+
self.states
127+
.remove_if(&key, |_, state| matches!(state, CallSiteState::InProgress));
146128
}
147129

148130
result
149131
}
150132

151133
/// Clear all cached call sites
152-
pub fn clear(&self) -> Result<()> {
153-
let mut map = self
154-
.states
155-
.write()
156-
.map_err(|error| PoisonedLock(format!("Failed to acquire cache lock: {error}")))?;
157-
map.clear();
158-
Ok(())
134+
pub fn clear(&self) {
135+
self.states.clear();
159136
}
160137

161138
/// Get the number of cached call sites
162-
pub fn len(&self) -> Result<usize> {
163-
let map = self
164-
.states
165-
.read()
166-
.map_err(|error| PoisonedLock(format!("Failed to acquire cache lock: {error}")))?;
167-
Ok(map.len())
139+
pub fn len(&self) -> usize {
140+
self.states.len()
168141
}
169142

170143
/// Check if the cache is empty
171-
pub fn is_empty(&self) -> Result<bool> {
172-
let map = self
173-
.states
174-
.read()
175-
.map_err(|error| PoisonedLock(format!("Failed to acquire cache lock: {error}")))?;
176-
Ok(map.is_empty())
144+
pub fn is_empty(&self) -> bool {
145+
self.states.is_empty()
177146
}
178147
}
179148

ristretto_vm/src/intrinsic_methods/java/lang/invoke/methodhandlenatives.rs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ use bitflags::bitflags;
88
use ristretto_classfile::VersionSpecification::{
99
Any, Between, GreaterThan, GreaterThanOrEqual, LessThanOrEqual,
1010
};
11-
use ristretto_classfile::{FieldAccessFlags, JAVA_8, JAVA_11, JAVA_17, JAVA_21, MethodAccessFlags};
11+
use ristretto_classfile::{
12+
FieldAccessFlags, JAVA_8, JAVA_11, JAVA_17, JAVA_21, MethodAccessFlags, ReferenceKind,
13+
};
1214
use ristretto_classloader::Error::IllegalAccessError;
1315
use ristretto_classloader::{Class, Method, Value};
1416
use ristretto_macros::intrinsic_method;
@@ -220,6 +222,7 @@ pub(crate) async fn resolve(
220222
class_object.value("name")?.as_string()?
221223
};
222224
let class = thread.class(class_name.clone()).await?;
225+
let _reference_kind = get_reference_kind(flags)?;
223226
let member_name_flags = MemberNameFlags::from_bits_truncate(flags);
224227

225228
// Handle methods/constructors
@@ -322,6 +325,15 @@ pub(crate) async fn resolve(
322325
}
323326
}
324327

328+
/// Extracts the reference kind from the flags of a member name.
329+
fn get_reference_kind(flags: i32) -> Result<ReferenceKind> {
330+
let flags = flags as u32;
331+
let shift = MemberNameFlags::REFERENCE_KIND_SHIFT.bits();
332+
let mask = MemberNameFlags::REFERENCE_KIND_MASK.bits() as u32;
333+
let reference_kind = ((flags >> shift) & mask) as u8;
334+
ReferenceKind::try_from(reference_kind).map_err(Into::into)
335+
}
336+
325337
/// Returns true if `caller` is permitted to access a method of `declaring` with the given access
326338
/// flags.
327339
///
@@ -682,6 +694,15 @@ mod tests {
682694
Ok(())
683695
}
684696

697+
#[test]
698+
fn test_get_reference_kind() -> Result<()> {
699+
assert_eq!(
700+
get_reference_kind(0x0601_0000)?,
701+
ReferenceKind::InvokeStatic
702+
);
703+
Ok(())
704+
}
705+
685706
#[tokio::test]
686707
async fn test_set_call_site_target_normal() -> Result<()> {
687708
let (_vm, thread) = crate::test::thread().await.expect("thread");

0 commit comments

Comments
 (0)