Skip to content

Commit a051385

Browse files
Merge pull request #548 from theseus-rs/add-member-handles
feat: add member handles
2 parents 10c84cf + 5a8ba82 commit a051385

File tree

4 files changed

+217
-84
lines changed

4 files changed

+217
-84
lines changed

ristretto_vm/src/handles/member.rs

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
use ristretto_classloader::Method;
2+
use std::sync::Arc;
3+
4+
/// Represents a handle to a member in the Java Virtual Machine (JVM). This is used to dynamically
5+
/// invoke methods or access fields in a class.
6+
#[derive(Debug)]
7+
pub(crate) struct MemberHandle {
8+
pub(crate) method: Option<Arc<Method>>,
9+
pub(crate) field: Option<usize>,
10+
}
11+
12+
impl From<Arc<Method>> for MemberHandle {
13+
fn from(method: Arc<Method>) -> Self {
14+
MemberHandle {
15+
method: Some(method),
16+
field: None,
17+
}
18+
}
19+
}
20+
21+
impl From<usize> for MemberHandle {
22+
fn from(field: usize) -> Self {
23+
MemberHandle {
24+
method: None,
25+
field: Some(field),
26+
}
27+
}
28+
}
29+
30+
#[cfg(test)]
31+
mod tests {
32+
use super::*;
33+
use crate::Result;
34+
35+
#[tokio::test]
36+
async fn test_member_handle_from_method() -> Result<()> {
37+
let (_vm, thread) = crate::test::thread().await.expect("thread");
38+
let class = thread.class("java.lang.Object").await?;
39+
let method = class.try_get_method("hashCode", "()I")?;
40+
let member_handle: MemberHandle = method.into();
41+
assert_eq!(member_handle.method.expect("method").name(), "hashCode");
42+
assert!(member_handle.field.is_none());
43+
Ok(())
44+
}
45+
46+
#[tokio::test]
47+
async fn test_member_handle_from_field() -> Result<()> {
48+
let (_vm, thread) = crate::test::thread().await.expect("thread");
49+
let class = thread.class("java.lang.Integer").await?;
50+
let field = class.field_offset("serialVersionUID")?;
51+
let method_handle: MemberHandle = field.into();
52+
assert!(method_handle.method.is_none());
53+
assert_eq!(method_handle.field.expect("name"), field);
54+
Ok(())
55+
}
56+
}

ristretto_vm/src/handles/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
66
mod file;
77
mod manager;
8+
mod member;
89
mod thread;
910

1011
pub(crate) use file::{FileHandle, FileModeFlags};
1112
pub(crate) use manager::HandleManager;
13+
pub(crate) use member::MemberHandle;
1214
pub(crate) use thread::ThreadHandle;

ristretto_vm/src/intrinsic_methods/java/lang/invoke/methodhandlenatives.rs

Lines changed: 151 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
use crate::Error::InternalError;
2-
use crate::Result;
32
use crate::intrinsic_methods::java::lang::class::get_class;
43
use crate::parameters::Parameters;
54
use crate::thread::Thread;
5+
use crate::{JavaObject, Result};
66
use async_recursion::async_recursion;
77
use bitflags::bitflags;
88
use ristretto_classfile::VersionSpecification::{
@@ -202,7 +202,6 @@ pub(crate) async fn register_natives(
202202
Ok(None)
203203
}
204204

205-
#[expect(clippy::too_many_lines)]
206205
pub(crate) async fn resolve(
207206
thread: Arc<Thread>,
208207
member_self: Value,
@@ -221,18 +220,62 @@ pub(crate) async fn resolve(
221220
let class_object = class_object.as_object_ref()?;
222221
class_object.value("name")?.as_string()?
223222
};
224-
let class = thread.class(class_name.clone()).await?;
225-
let _reference_kind = get_reference_kind(flags)?;
223+
let class = thread.class(class_name).await?;
226224
let member_name_flags = MemberNameFlags::from_bits_truncate(flags);
227225

228-
// Handle methods/constructors
229226
if member_name_flags.contains(MemberNameFlags::IS_METHOD)
230227
|| member_name_flags.contains(MemberNameFlags::IS_CONSTRUCTOR)
231228
{
232-
let method_type = {
233-
let member_self = member_self.as_object_ref()?;
234-
member_self.value("type")?
235-
};
229+
resolve_method(
230+
&thread,
231+
member_self,
232+
&caller,
233+
lookup_mode_flags,
234+
speculative_resolve,
235+
&name,
236+
flags,
237+
&class,
238+
)
239+
.await
240+
} else if member_name_flags.contains(MemberNameFlags::IS_FIELD) {
241+
resolve_field(
242+
&thread,
243+
member_self,
244+
caller,
245+
lookup_mode_flags,
246+
speculative_resolve,
247+
name,
248+
flags,
249+
&class,
250+
)
251+
.await
252+
} else {
253+
Err(InternalError(format!(
254+
"Unsupported member name flag: {member_name_flags:?}"
255+
)))
256+
}
257+
}
258+
259+
/// Resolves a method in the given class, checking access permissions and returning the member self
260+
/// if successful.
261+
#[expect(clippy::too_many_arguments)]
262+
async fn resolve_method(
263+
thread: &Thread,
264+
member_self: Value,
265+
caller: &Option<Arc<Class>>,
266+
lookup_mode_flags: &LookupModeFlags,
267+
speculative_resolve: bool,
268+
name: &Value,
269+
flags: i32,
270+
class: &Arc<Class>,
271+
) -> Result<Option<Value>> {
272+
let _reference_kind = get_reference_kind(flags)?;
273+
let method_type = {
274+
let member_self = member_self.as_object_ref()?;
275+
member_self.value("type")?
276+
};
277+
278+
let (parameter_descriptors, return_descriptor) = {
236279
let method_type = method_type.as_object_ref()?;
237280
let parameter_types = method_type.value("ptypes")?;
238281
let parameters: Vec<Value> = parameter_types.try_into()?;
@@ -247,82 +290,108 @@ pub(crate) async fn resolve(
247290
let return_type = return_type.as_object_ref()?;
248291
let return_class_name = return_type.value("name")?.as_string()?;
249292
let return_descriptor = Class::convert_to_descriptor(&return_class_name);
293+
(parameter_descriptors, return_descriptor)
294+
};
250295

251-
let method_name = name.as_string()?;
252-
let method_descriptor = format!("({}){return_descriptor}", parameter_descriptors.concat());
253-
let method = match class_name.as_str() {
254-
"java.lang.invoke.DelegatingMethodHandle$Holder"
255-
| "java.lang.invoke.DirectMethodHandle$Holder"
256-
| "java.lang.invoke.Invokers$Holder" => {
257-
resolve_holder_methods(class.clone(), &method_name, &method_descriptor)?
258-
}
259-
_ => class.try_get_method(method_name.clone(), method_descriptor.clone())?,
296+
let method_name = name.as_string()?;
297+
let method_descriptor = format!("({}){return_descriptor}", parameter_descriptors.concat());
298+
let method = match class.name() {
299+
"java.lang.invoke.DelegatingMethodHandle$Holder"
300+
| "java.lang.invoke.DirectMethodHandle$Holder"
301+
| "java.lang.invoke.Invokers$Holder" => {
302+
resolve_holder_methods(class.clone(), &method_name, &method_descriptor)?
303+
}
304+
_ => class.try_get_method(&method_name, &method_descriptor)?,
305+
};
306+
307+
// Access control enforcement
308+
let method_access_flags = method.access_flags();
309+
if !check_method_access(caller, class, *method_access_flags, *lookup_mode_flags)? {
310+
return if speculative_resolve {
311+
// If speculative, return None (fail silently)
312+
Ok(None)
313+
} else {
314+
Err(IllegalAccessError(format!(
315+
"member is {}: {}.{method_name}{method_descriptor}",
316+
if method_access_flags.contains(MethodAccessFlags::PRIVATE) {
317+
"private"
318+
} else {
319+
"inaccessible"
320+
},
321+
class.name(),
322+
))
323+
.into())
260324
};
325+
}
261326

262-
// Access control enforcement
263-
let method_access_flags = method.access_flags();
264-
if !check_method_access(caller, &class, *method_access_flags, *lookup_mode_flags)? {
265-
return if speculative_resolve {
266-
// If speculative, return None (fail silently)
267-
Ok(None)
268-
} else {
269-
Err(IllegalAccessError(format!(
270-
"member is {}: {}.{}{}",
271-
if method_access_flags.contains(MethodAccessFlags::PRIVATE) {
272-
"private"
273-
} else {
274-
"inaccessible"
275-
},
276-
class_name,
277-
method_name,
278-
method_descriptor,
279-
))
280-
.into())
281-
};
282-
}
327+
let modifiers = i32::from(method_access_flags.bits());
328+
let flags = flags | modifiers;
329+
{
330+
let vm = thread.vm()?;
331+
let member_handles = vm.member_handles();
332+
let method_signature =
333+
format!("{}.{}{}", class.name(), method.name(), method.descriptor(),);
334+
member_handles
335+
.insert(method_signature, method.into())
336+
.await?;
337+
let _vmindex = method_descriptor.to_object(thread).await?;
338+
let mut member_self = member_self.as_object_mut()?;
339+
member_self.set_value("flags", Value::from(flags))?;
340+
// member_self.set_value("vmindex", vmindex)?;
341+
}
342+
Ok(Some(member_self))
343+
}
283344

284-
let modifiers = i32::from(method_access_flags.bits());
285-
let flags = flags | modifiers;
286-
{
287-
let mut member_self = member_self.as_object_mut()?;
288-
member_self.set_value("flags", Value::from(flags))?;
289-
}
290-
Ok(Some(member_self))
291-
}
292-
// Handle fields (for both normal field and VarHandle)
293-
else if member_name_flags.contains(MemberNameFlags::IS_FIELD) {
294-
let field_name = name.as_string()?;
295-
let field = class.declared_field(&field_name)?;
296-
let field_access_flags = field.access_flags();
297-
if !check_field_access(caller, &class, *field_access_flags, *lookup_mode_flags)? {
298-
return if speculative_resolve {
299-
Ok(None)
300-
} else {
301-
Err(IllegalAccessError(format!(
302-
"member is {}: {}.{}",
303-
if field_access_flags.contains(FieldAccessFlags::PRIVATE) {
304-
"private"
305-
} else {
306-
"inaccessible"
307-
},
308-
class_name,
309-
field_name,
310-
))
311-
.into())
312-
};
313-
}
314-
let modifiers = i32::from(field_access_flags.bits());
315-
let flags = flags | modifiers;
316-
{
317-
let mut member_self = member_self.as_object_mut()?;
318-
member_self.set_value("flags", Value::from(flags))?;
319-
}
320-
Ok(Some(member_self))
321-
} else {
322-
Err(InternalError(format!(
323-
"Unsupported member name flag: {member_name_flags:?}"
324-
)))
345+
/// Resolves a field in the given class, checking access permissions and returning the member self
346+
/// if successful.
347+
#[expect(clippy::too_many_arguments)]
348+
async fn resolve_field(
349+
thread: &Thread,
350+
member_self: Value,
351+
caller: Option<Arc<Class>>,
352+
lookup_mode_flags: &LookupModeFlags,
353+
speculative_resolve: bool,
354+
name: Value,
355+
flags: i32,
356+
class: &Arc<Class>,
357+
) -> Result<Option<Value>> {
358+
let _reference_kind = get_reference_kind(flags)?;
359+
let field_name = name.as_string()?;
360+
let field = class.declared_field(&field_name)?;
361+
let field_access_flags = field.access_flags();
362+
if !check_field_access(caller, class, *field_access_flags, *lookup_mode_flags)? {
363+
return if speculative_resolve {
364+
Ok(None)
365+
} else {
366+
Err(IllegalAccessError(format!(
367+
"member is {}: {}.{}",
368+
if field_access_flags.contains(FieldAccessFlags::PRIVATE) {
369+
"private"
370+
} else {
371+
"inaccessible"
372+
},
373+
class.name(),
374+
field_name,
375+
))
376+
.into())
377+
};
325378
}
379+
let modifiers = i32::from(field_access_flags.bits());
380+
let flags = flags | modifiers;
381+
{
382+
let vm = thread.vm()?;
383+
let member_handles = vm.member_handles();
384+
let field_offset = class.field_offset(&field_name)?;
385+
let field_signature = format!("{}.{field_name}", class.name(),);
386+
member_handles
387+
.insert(field_signature.clone(), field_offset.into())
388+
.await?;
389+
let _vmindex = field_signature.to_object(thread).await?;
390+
let mut member_self = member_self.as_object_mut()?;
391+
member_self.set_value("flags", Value::from(flags))?;
392+
// member_self.set_value("vmindex", vmindex)?;
393+
}
394+
Ok(Some(member_self))
326395
}
327396

328397
/// Extracts the reference kind from the flags of a member name.
@@ -340,9 +409,8 @@ fn get_reference_kind(flags: i32) -> Result<ReferenceKind> {
340409
/// # References
341410
///
342411
/// - [JLS §6.6 Access Control](https://docs.oracle.com/javase/specs/jls/se24/html/jls-6.html#jls-6.6)
343-
#[expect(clippy::needless_pass_by_value)]
344412
pub fn check_method_access(
345-
caller: Option<Arc<Class>>,
413+
caller: &Option<Arc<Class>>,
346414
declaring: &Arc<Class>,
347415
method_access_flags: MethodAccessFlags,
348416
lookup_mode_flags: LookupModeFlags,
@@ -356,7 +424,7 @@ pub fn check_method_access(
356424
return Ok(true);
357425
}
358426

359-
let Some(ref caller) = caller else {
427+
let Some(caller) = caller else {
360428
return Ok(false);
361429
};
362430

ristretto_vm/src/vm.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::Error::InternalError;
22
use crate::call_site_cache::CallSiteCache;
3-
use crate::handles::{FileHandle, HandleManager, ThreadHandle};
3+
use crate::handles::{FileHandle, HandleManager, MemberHandle, ThreadHandle};
44
use crate::intrinsic_methods::MethodRegistry;
55
use crate::java_object::JavaObject;
66
use crate::rust_value::RustValue;
@@ -44,6 +44,7 @@ pub struct VM {
4444
next_thread_id: AtomicU64,
4545
thread_handles: HandleManager<u64, ThreadHandle>,
4646
file_handles: HandleManager<String, FileHandle>,
47+
member_handles: HandleManager<String, MemberHandle>,
4748
string_pool: StringPool,
4849
call_site_cache: CallSiteCache,
4950
}
@@ -159,6 +160,7 @@ impl VM {
159160
next_thread_id: AtomicU64::new(1),
160161
thread_handles: HandleManager::new(),
161162
file_handles: HandleManager::new(),
163+
member_handles: HandleManager::new(),
162164
string_pool: StringPool::new(),
163165
call_site_cache: CallSiteCache::new(),
164166
});
@@ -273,6 +275,11 @@ impl VM {
273275
&self.file_handles
274276
}
275277

278+
/// Get the VM member handles used for dynamic invocation
279+
pub(crate) fn member_handles(&self) -> &HandleManager<String, MemberHandle> {
280+
&self.member_handles
281+
}
282+
276283
/// Initialize the VM
277284
///
278285
/// # Errors

0 commit comments

Comments
 (0)