Skip to content

Commit d35d87c

Browse files
committed
fix async tests when component-model-async enabled
Enabling this feature for all tests revealed various missing pieces in the new `concurrent.rs` fiber mechanism, which I've addressed. This adds a bunch of ugly `#[cfg(feature = "component-model-async")]` guards, but those will all go away once I unify the two async fiber implementations. Signed-off-by: Joel Dice <[email protected]>
1 parent c166a9f commit d35d87c

File tree

16 files changed

+595
-218
lines changed

16 files changed

+595
-218
lines changed

Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ rustix = { workspace = true, features = ["mm", "param", "process"] }
8686

8787
[dev-dependencies]
8888
# depend again on wasmtime to activate its default features for tests
89-
wasmtime = { workspace = true, features = ['default', 'winch', 'pulley', 'all-arch', 'call-hook', 'memory-protection-keys', 'signals-based-traps'] }
89+
wasmtime = { workspace = true, features = ['default', 'winch', 'pulley', 'all-arch', 'call-hook', 'memory-protection-keys', 'signals-based-traps', 'component-model-async'] }
9090
env_logger = { workspace = true }
9191
log = { workspace = true }
9292
filecheck = { workspace = true }

benches/call.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ fn bench_host_to_wasm<Params, Results>(
135135
typed_results: Results,
136136
) where
137137
Params: WasmParams + ToVals + Copy,
138-
Results: WasmResults + ToVals + Copy + PartialEq + Debug,
138+
Results: WasmResults + ToVals + Copy + PartialEq + Debug + Sync + 'static,
139139
{
140140
// Benchmark the "typed" version, which should be faster than the versions
141141
// below.

crates/wasmtime/src/runtime/component/concurrent.rs

+113-52
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ use {
2626
future::Future,
2727
marker::PhantomData,
2828
mem::{self, MaybeUninit},
29+
ops::Range,
2930
pin::{pin, Pin},
3031
ptr::{self, NonNull},
3132
sync::{Arc, Mutex},
@@ -323,6 +324,23 @@ impl<T: Copy> Drop for Reset<T> {
323324
}
324325
}
325326

327+
#[derive(Clone, Copy)]
328+
struct PollContext {
329+
future_context: *mut Context<'static>,
330+
guard_range_start: *mut u8,
331+
guard_range_end: *mut u8,
332+
}
333+
334+
impl Default for PollContext {
335+
fn default() -> PollContext {
336+
PollContext {
337+
future_context: core::ptr::null_mut(),
338+
guard_range_start: core::ptr::null_mut(),
339+
guard_range_end: core::ptr::null_mut(),
340+
}
341+
}
342+
}
343+
326344
struct AsyncState {
327345
current_suspend: UnsafeCell<
328346
*mut Suspend<
@@ -331,7 +349,7 @@ struct AsyncState {
331349
(Option<*mut dyn VMStore>, Result<()>),
332350
>,
333351
>,
334-
current_poll_cx: UnsafeCell<*mut Context<'static>>,
352+
current_poll_cx: UnsafeCell<PollContext>,
335353
}
336354

337355
unsafe impl Send for AsyncState {}
@@ -344,26 +362,35 @@ pub(crate) struct AsyncCx {
344362
(Option<*mut dyn VMStore>, Result<()>),
345363
>,
346364
current_stack_limit: *mut usize,
347-
current_poll_cx: *mut *mut Context<'static>,
365+
current_poll_cx: *mut PollContext,
348366
track_pkey_context_switch: bool,
349367
}
350368

351369
impl AsyncCx {
352370
pub(crate) fn new<T>(store: &mut StoreContextMut<T>) -> Self {
353-
Self {
354-
current_suspend: store.concurrent_state().async_state.current_suspend.get(),
355-
current_stack_limit: store.0.runtime_limits().stack_limit.get(),
356-
current_poll_cx: store.concurrent_state().async_state.current_poll_cx.get(),
357-
track_pkey_context_switch: store.has_pkey(),
371+
Self::try_new(store).unwrap()
372+
}
373+
374+
pub(crate) fn try_new<T>(store: &mut StoreContextMut<T>) -> Option<Self> {
375+
let current_poll_cx = store.concurrent_state().async_state.current_poll_cx.get();
376+
if unsafe { (*current_poll_cx).future_context.is_null() } {
377+
None
378+
} else {
379+
Some(Self {
380+
current_suspend: store.concurrent_state().async_state.current_suspend.get(),
381+
current_stack_limit: store.0.runtime_limits().stack_limit.get(),
382+
current_poll_cx,
383+
track_pkey_context_switch: store.has_pkey(),
384+
})
358385
}
359386
}
360387

361388
unsafe fn poll<U>(&self, mut future: Pin<&mut (dyn Future<Output = U> + Send)>) -> Poll<U> {
362389
let poll_cx = *self.current_poll_cx;
363390
let _reset = Reset(self.current_poll_cx, poll_cx);
364-
*self.current_poll_cx = ptr::null_mut();
365-
assert!(!poll_cx.is_null());
366-
future.as_mut().poll(&mut *poll_cx)
391+
*self.current_poll_cx = PollContext::default();
392+
assert!(!poll_cx.future_context.is_null());
393+
future.as_mut().poll(&mut *poll_cx.future_context)
367394
}
368395

369396
pub(crate) unsafe fn block_on<'a, T, U>(
@@ -420,6 +447,13 @@ pub struct ConcurrentState<T> {
420447
_phantom: PhantomData<T>,
421448
}
422449

450+
impl<T> ConcurrentState<T> {
451+
pub(crate) fn async_guard_range(&self) -> Range<*mut u8> {
452+
let context = unsafe { *self.async_state.current_poll_cx.get() };
453+
context.guard_range_start..context.guard_range_end
454+
}
455+
}
456+
423457
impl<T> Default for ConcurrentState<T> {
424458
fn default() -> Self {
425459
Self {
@@ -428,7 +462,7 @@ impl<T> Default for ConcurrentState<T> {
428462
futures: ReadyChunks::new(FuturesUnordered::new(), 1024),
429463
async_state: AsyncState {
430464
current_suspend: UnsafeCell::new(ptr::null_mut()),
431-
current_poll_cx: UnsafeCell::new(ptr::null_mut()),
465+
current_poll_cx: UnsafeCell::new(PollContext::default()),
432466
},
433467
instance_states: HashMap::new(),
434468
yielding: HashSet::new(),
@@ -622,7 +656,7 @@ pub(crate) fn poll_and_block<'a, T, R: Send + Sync + 'static>(
622656

623657
pub(crate) async fn on_fiber<'a, R: Send + Sync + 'static, T: Send>(
624658
mut store: StoreContextMut<'a, T>,
625-
instance: RuntimeComponentInstanceIndex,
659+
instance: Option<RuntimeComponentInstanceIndex>,
626660
func: impl FnOnce(&mut StoreContextMut<T>) -> R + Send,
627661
) -> Result<(R, StoreContextMut<'a, T>)> {
628662
let result = Arc::new(Mutex::new(None));
@@ -634,7 +668,21 @@ pub(crate) async fn on_fiber<'a, R: Send + Sync + 'static, T: Send>(
634668
}
635669
})?;
636670

637-
store = poll_fn(store, move |_, mut store| {
671+
let guard_range = fiber
672+
.fiber
673+
.as_ref()
674+
.unwrap()
675+
.stack()
676+
.guard_range()
677+
.map(|r| {
678+
(
679+
NonNull::new(r.start).map(SendSyncPtr::new),
680+
NonNull::new(r.end).map(SendSyncPtr::new),
681+
)
682+
})
683+
.unwrap_or((None, None));
684+
685+
store = poll_fn(store, guard_range, move |_, mut store| {
638686
match resume_fiber(&mut fiber, store.take(), Ok(())) {
639687
Ok(Ok((store, result))) => Ok(result.map(|()| store)),
640688
Ok(Err(s)) => Err(s),
@@ -761,36 +809,40 @@ fn resume_stackful<'a, T>(
761809
match resume_fiber(&mut fiber, Some(store), Ok(()))? {
762810
Ok((mut store, result)) => {
763811
result?;
764-
store = maybe_resume_next_task(store, fiber.instance)?;
765-
for (event, call, _) in mem::take(
766-
&mut store
767-
.concurrent_state()
768-
.table
769-
.get_mut(guest_task)
770-
.with_context(|| format!("bad handle: {}", guest_task.rep()))?
771-
.events,
772-
) {
773-
if event == events::EVENT_CALL_DONE {
774-
log::trace!("resume_stackful will delete call {}", call.rep());
775-
call.delete_all_from(store.as_context_mut())?;
776-
}
777-
}
778-
match &store.concurrent_state().table.get(guest_task)?.caller {
779-
Caller::Host(_) => {
780-
log::trace!("resume_stackful will delete task {}", guest_task.rep());
781-
AnyTask::Guest(guest_task).delete_all_from(store.as_context_mut())?;
782-
Ok(store)
812+
if let Some(instance) = fiber.instance {
813+
store = maybe_resume_next_task(store, instance)?;
814+
for (event, call, _) in mem::take(
815+
&mut store
816+
.concurrent_state()
817+
.table
818+
.get_mut(guest_task)
819+
.with_context(|| format!("bad handle: {}", guest_task.rep()))?
820+
.events,
821+
) {
822+
if event == events::EVENT_CALL_DONE {
823+
log::trace!("resume_stackful will delete call {}", call.rep());
824+
call.delete_all_from(store.as_context_mut())?;
825+
}
783826
}
784-
Caller::Guest { task, .. } => {
785-
let task = *task;
786-
maybe_send_event(
787-
store,
788-
task,
789-
events::EVENT_CALL_DONE,
790-
AnyTask::Guest(guest_task),
791-
0,
792-
)
827+
match &store.concurrent_state().table.get(guest_task)?.caller {
828+
Caller::Host(_) => {
829+
log::trace!("resume_stackful will delete task {}", guest_task.rep());
830+
AnyTask::Guest(guest_task).delete_all_from(store.as_context_mut())?;
831+
Ok(store)
832+
}
833+
Caller::Guest { task, .. } => {
834+
let task = *task;
835+
maybe_send_event(
836+
store,
837+
task,
838+
events::EVENT_CALL_DONE,
839+
AnyTask::Guest(guest_task),
840+
0,
841+
)
842+
}
793843
}
844+
} else {
845+
Ok(store)
794846
}
795847
}
796848
Err(new_store) => {
@@ -1029,7 +1081,7 @@ struct StoreFiber<'a> {
10291081
(Option<*mut dyn VMStore>, Result<()>),
10301082
>,
10311083
stack_limit: *mut usize,
1032-
instance: RuntimeComponentInstanceIndex,
1084+
instance: Option<RuntimeComponentInstanceIndex>,
10331085
}
10341086

10351087
impl<'a> Drop for StoreFiber<'a> {
@@ -1054,7 +1106,7 @@ unsafe impl<'a> Sync for StoreFiber<'a> {}
10541106

10551107
fn make_fiber<'a, T>(
10561108
store: &mut StoreContextMut<T>,
1057-
instance: RuntimeComponentInstanceIndex,
1109+
instance: Option<RuntimeComponentInstanceIndex>,
10581110
fun: impl FnOnce(StoreContextMut<T>) -> Result<()> + 'a,
10591111
) -> Result<StoreFiber<'a>> {
10601112
let engine = store.engine().clone();
@@ -1118,9 +1170,11 @@ unsafe fn resume_fiber_raw<'a>(
11181170
fn poll_ready<'a, T>(mut store: StoreContextMut<'a, T>) -> Result<StoreContextMut<'a, T>> {
11191171
unsafe {
11201172
let cx = *store.concurrent_state().async_state.current_poll_cx.get();
1121-
assert!(!cx.is_null());
1122-
while let Poll::Ready(Some(ready)) =
1123-
store.concurrent_state().futures.poll_next_unpin(&mut *cx)
1173+
assert!(!cx.future_context.is_null());
1174+
while let Poll::Ready(Some(ready)) = store
1175+
.concurrent_state()
1176+
.futures
1177+
.poll_next_unpin(&mut *cx.future_context)
11241178
{
11251179
match handle_ready(store, ready) {
11261180
Ok(s) => {
@@ -1691,7 +1745,7 @@ fn do_start_call<'a, T>(
16911745
cx
16921746
}
16931747
} else {
1694-
let mut fiber = make_fiber(&mut cx, callee_instance, move |mut cx| {
1748+
let mut fiber = make_fiber(&mut cx, Some(callee_instance), move |mut cx| {
16951749
if !async_ {
16961750
cx.concurrent_state()
16971751
.instance_states
@@ -2017,12 +2071,12 @@ pub(crate) async fn poll_until<'a, T: Send, U>(
20172071
.await;
20182072

20192073
if ready.is_some() {
2020-
store = poll_fn(store, move |_, mut store| {
2074+
store = poll_fn(store, (None, None), move |_, mut store| {
20212075
Ok(handle_ready(store.take().unwrap(), ready.take().unwrap()))
20222076
})
20232077
.await?;
20242078
} else {
2025-
let (s, resumed) = poll_fn(store, move |_, mut store| {
2079+
let (s, resumed) = poll_fn(store, (None, None), move |_, mut store| {
20262080
Ok(unyield(store.take().unwrap()))
20272081
})
20282082
.await?;
@@ -2039,7 +2093,7 @@ pub(crate) async fn poll_until<'a, T: Send, U>(
20392093
Either::Left((None, future_again)) => break Ok((store, future_again.await)),
20402094
Either::Left((Some(ready), future_again)) => {
20412095
let mut ready = Some(ready);
2042-
store = poll_fn(store, move |_, mut store| {
2096+
store = poll_fn(store, (None, None), move |_, mut store| {
20432097
Ok(handle_ready(store.take().unwrap(), ready.take().unwrap()))
20442098
})
20452099
.await?;
@@ -2052,13 +2106,14 @@ pub(crate) async fn poll_until<'a, T: Send, U>(
20522106

20532107
async fn poll_fn<'a, T, R>(
20542108
mut store: StoreContextMut<'a, T>,
2109+
guard_range: (Option<SendSyncPtr<u8>>, Option<SendSyncPtr<u8>>),
20552110
mut fun: impl FnMut(
20562111
&mut Context,
20572112
Option<StoreContextMut<'a, T>>,
20582113
) -> Result<R, Option<StoreContextMut<'a, T>>>,
20592114
) -> R {
20602115
#[derive(Clone, Copy)]
2061-
struct PollCx(*mut *mut Context<'static>);
2116+
struct PollCx(*mut PollContext);
20622117

20632118
unsafe impl Send for PollCx {}
20642119

@@ -2068,7 +2123,13 @@ async fn poll_fn<'a, T, R>(
20682123

20692124
move |cx| unsafe {
20702125
let _reset = Reset(poll_cx.0, *poll_cx.0);
2071-
*poll_cx.0 = mem::transmute::<&mut Context<'_>, *mut Context<'static>>(cx);
2126+
let guard_range_start = guard_range.0.map(|v| v.as_ptr()).unwrap_or(ptr::null_mut());
2127+
let guard_range_end = guard_range.1.map(|v| v.as_ptr()).unwrap_or(ptr::null_mut());
2128+
*poll_cx.0 = PollContext {
2129+
future_context: mem::transmute::<&mut Context<'_>, *mut Context<'static>>(cx),
2130+
guard_range_start,
2131+
guard_range_end,
2132+
};
20722133
#[allow(dropping_copy_types)]
20732134
drop(poll_cx);
20742135

crates/wasmtime/src/runtime/component/func.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ impl Func {
332332
let instance = store.0[self.0].component_instance;
333333
// TODO: do we need to return the store here due to the possible
334334
// invalidation of the reference we were passed?
335-
concurrent::on_fiber(store, instance, move |store| {
335+
concurrent::on_fiber(store, Some(instance), move |store| {
336336
self.call_impl(store, params, results)
337337
})
338338
.await?
@@ -367,7 +367,7 @@ impl Func {
367367
let instance = store.0[self.0].component_instance;
368368
// TODO: do we need to return the store here due to the possible
369369
// invalidation of the reference we were passed?
370-
concurrent::on_fiber(store, instance, move |store| {
370+
concurrent::on_fiber(store, Some(instance), move |store| {
371371
self.start_call(store.as_context_mut(), params)
372372
})
373373
.await?

crates/wasmtime/src/runtime/component/func/typed.rs

+6-4
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,11 @@ where
199199
let instance = store.0[self.func.0].component_instance;
200200
// TODO: do we need to return the store here due to the possible
201201
// invalidation of the reference we were passed?
202-
concurrent::on_fiber(store, instance, move |store| self.call_impl(store, params))
203-
.await?
204-
.0
202+
concurrent::on_fiber(store, Some(instance), move |store| {
203+
self.call_impl(store, params)
204+
})
205+
.await?
206+
.0
205207
}
206208
#[cfg(not(feature = "component-model-async"))]
207209
{
@@ -236,7 +238,7 @@ where
236238
let instance = store.0[self.func.0].component_instance;
237239
// TODO: do we need to return the store here due to the possible
238240
// invalidation of the reference we were passed?
239-
concurrent::on_fiber(store, instance, move |store| {
241+
concurrent::on_fiber(store, Some(instance), move |store| {
240242
self.start_call(store.as_context_mut(), params)
241243
})
242244
.await?

crates/wasmtime/src/runtime/component/instance.rs

+14-2
Original file line numberDiff line numberDiff line change
@@ -864,12 +864,24 @@ impl<T> InstancePre<T> {
864864
where
865865
T: Send + 'static,
866866
{
867-
let mut store = store.as_context_mut();
867+
let store = store.as_context_mut();
868868
assert!(
869869
store.0.async_support(),
870870
"must use sync instantiation when async support is disabled"
871871
);
872-
store.on_fiber(|store| self.instantiate_impl(store)).await?
872+
#[cfg(feature = "component-model-async")]
873+
{
874+
// TODO: do we need to return the store here due to the possible
875+
// invalidation of the reference we were passed?
876+
concurrent::on_fiber(store, None, move |store| self.instantiate_impl(store))
877+
.await?
878+
.0
879+
}
880+
#[cfg(not(feature = "component-model-async"))]
881+
{
882+
let mut store = store;
883+
store.on_fiber(|store| self.instantiate_impl(store)).await?
884+
}
873885
}
874886

875887
fn instantiate_impl(&self, mut store: impl AsContextMut<Data = T>) -> Result<Instance>

0 commit comments

Comments
 (0)