Skip to content

Commit 7cf1ed3

Browse files
committed
finish implementing reentrance checks
Signed-off-by: Joel Dice <[email protected]>
1 parent 52a449e commit 7cf1ed3

File tree

2 files changed

+151
-0
lines changed

2 files changed

+151
-0
lines changed

crates/wasmtime/src/runtime/component/concurrent.rs

+38
Original file line numberDiff line numberDiff line change
@@ -1740,9 +1740,41 @@ pub(crate) extern "C" fn async_enter<T>(
17401740
}
17411741
}
17421742

1743+
fn may_enter<T>(
1744+
store: &mut StoreContextMut<T>,
1745+
mut guest_task: TableId<GuestTask>,
1746+
guest_instance: RuntimeComponentInstanceIndex,
1747+
) -> bool {
1748+
// Walk the task tree back to the root, looking for potential reentrance.
1749+
//
1750+
// TODO: This could be optimized by maintaining a per-`GuestTask` bitset
1751+
// such that each bit represents and instance which has been entered by that
1752+
// task or an ancestor of that task, in which case this would be a constant
1753+
// time check.
1754+
loop {
1755+
match &store
1756+
.concurrent_state()
1757+
.table
1758+
.get_mut(guest_task)
1759+
.unwrap()
1760+
.caller
1761+
{
1762+
Caller::Host(_) => break true,
1763+
Caller::Guest { task, instance } => {
1764+
if *instance == guest_instance {
1765+
break false;
1766+
} else {
1767+
guest_task = *task;
1768+
}
1769+
}
1770+
}
1771+
}
1772+
}
1773+
17431774
fn make_call<T>(
17441775
guest_task: TableId<GuestTask>,
17451776
callee: SendSyncPtr<VMFuncRef>,
1777+
callee_instance: RuntimeComponentInstanceIndex,
17461778
param_count: usize,
17471779
result_count: usize,
17481780
flags: Option<InstanceFlags>,
@@ -1753,6 +1785,10 @@ fn make_call<T>(
17531785
+ Sync
17541786
+ 'static {
17551787
move |mut cx: StoreContextMut<T>| {
1788+
if !may_enter(&mut cx, guest_task, callee_instance) {
1789+
bail!(crate::Trap::CannotEnterComponent);
1790+
}
1791+
17561792
let mut storage = [MaybeUninit::uninit(); MAX_FLAT_PARAMS];
17571793
let lower = cx
17581794
.concurrent_state()
@@ -2021,6 +2057,7 @@ pub(crate) extern "C" fn async_exit<T>(
20212057
let call = make_call(
20222058
guest_task,
20232059
callee,
2060+
callee_instance,
20242061
param_count,
20252062
result_count,
20262063
if callback.is_null() {
@@ -2123,6 +2160,7 @@ pub(crate) fn start_call<'a, T: Send, LowerParams: Copy, R: 'static>(
21232160
let call = make_call(
21242161
guest_task,
21252162
SendSyncPtr::new(callee),
2163+
component_instance,
21262164
mem::size_of::<LowerParams>() / mem::size_of::<ValRaw>(),
21272165
1,
21282166
if callback.is_none() {

tests/all/component_model/func.rs

+113
Original file line numberDiff line numberDiff line change
@@ -821,6 +821,119 @@ fn strings() -> Result<()> {
821821
Ok(())
822822
}
823823

824+
#[tokio::test]
825+
async fn async_reentrance() -> Result<()> {
826+
let component = r#"
827+
(component
828+
(core module $shim
829+
(import "" "task.return" (func $task-return (param i32)))
830+
(table (export "funcs") 1 1 funcref)
831+
(func (export "export") (param i32) (result i32)
832+
(call_indirect (i32.const 0) (local.get 0))
833+
)
834+
(func (export "callback") (param i32 i32 i32 i32) (result i32) unreachable)
835+
)
836+
(core type $task-return-type (func (param i32)))
837+
(core func $task-return (canon task.return $task-return-type))
838+
(core instance $shim (instantiate $shim
839+
(with "" (instance (export "task.return" (func $task-return))))
840+
))
841+
(func $shim-export (param "p1" u32) (result u32)
842+
(canon lift (core func $shim "export") async (callback (func $shim "callback")))
843+
)
844+
845+
(component $inner
846+
(import "import" (func $import (param "p1" u32) (result u32)))
847+
(core module $libc (memory (export "memory") 1))
848+
(core instance $libc (instantiate $libc))
849+
(core func $import (canon lower (func $import) async (memory $libc "memory")))
850+
851+
(core module $m
852+
(import "libc" "memory" (memory 1))
853+
(import "" "import" (func $import (param i32 i32) (result i32)))
854+
(import "" "task.return" (func $task-return (param i32)))
855+
(func (export "export") (param i32) (result i32)
856+
(i32.store offset=0 (i32.const 1200) (local.get 0))
857+
(call $import (i32.const 1200) (i32.const 1204))
858+
drop
859+
(call $task-return (i32.load offset=0 (i32.const 1204)))
860+
i32.const 0
861+
)
862+
(func (export "callback") (param i32 i32 i32 i32) (result i32) unreachable)
863+
)
864+
(core type $task-return-type (func (param i32)))
865+
(core func $task-return (canon task.return $task-return-type))
866+
(core instance $i (instantiate $m
867+
(with "" (instance
868+
(export "task.return" (func $task-return))
869+
(export "import" (func $import))
870+
))
871+
(with "libc" (instance $libc))
872+
))
873+
(func (export "export") (param "p1" u32) (result u32)
874+
(canon lift (core func $i "export") async (callback (func $i "callback")))
875+
)
876+
)
877+
(instance $inner (instantiate $inner (with "import" (func $shim-export))))
878+
879+
(core module $libc (memory (export "memory") 1))
880+
(core instance $libc (instantiate $libc))
881+
(core func $inner-export (canon lower (func $inner "export") async (memory $libc "memory")))
882+
883+
(core module $donut
884+
(import "" "funcs" (table 1 1 funcref))
885+
(import "libc" "memory" (memory 1))
886+
(import "" "import" (func $import (param i32 i32) (result i32)))
887+
(import "" "task.return" (func $task-return (param i32)))
888+
(func $host-export (export "export") (param i32) (result i32)
889+
(i32.store offset=0 (i32.const 1200) (local.get 0))
890+
(call $import (i32.const 1200) (i32.const 1204))
891+
drop
892+
(call $task-return (i32.load offset=0 (i32.const 1204)))
893+
i32.const 0
894+
)
895+
(func $guest-export (export "guest-export") (param i32) (result i32) unreachable)
896+
(func (export "callback") (param i32 i32 i32 i32) (result i32) unreachable)
897+
(func $start
898+
(table.set (i32.const 0) (ref.func $guest-export))
899+
)
900+
(start $start)
901+
)
902+
903+
(core instance $donut (instantiate $donut
904+
(with "" (instance
905+
(export "task.return" (func $task-return))
906+
(export "import" (func $inner-export))
907+
(export "funcs" (table $shim "funcs"))
908+
))
909+
(with "libc" (instance $libc))
910+
))
911+
(func (export "export") (param "p1" u32) (result u32)
912+
(canon lift (core func $donut "export") async (callback (func $donut "callback")))
913+
)
914+
)"#;
915+
916+
let mut config = Config::new();
917+
config.wasm_component_model_async(true);
918+
config.async_support(true);
919+
let engine = &Engine::new(&config)?;
920+
let component = Component::new(&engine, component)?;
921+
let mut store = Store::new(&engine, ());
922+
923+
let instance = Linker::new(&engine)
924+
.instantiate_async(&mut store, &component)
925+
.await?;
926+
927+
let func = instance.get_typed_func::<(u32,), (u32,)>(&mut store, "export")?;
928+
929+
match func.call_concurrent(&mut store, (42,)).await {
930+
Ok(_) => panic!(),
931+
Err(e) => assert!(format!("{e:?}").contains("cannot enter component instance")),
932+
}
933+
934+
Ok(())
935+
}
936+
824937
#[tokio::test]
825938
async fn missing_task_return_call_stackless() -> Result<()> {
826939
test_missing_task_return_call(r#"(component

0 commit comments

Comments
 (0)