diff --git a/crates/wassette/src/lib.rs b/crates/wassette/src/lib.rs index 7628157a..468f99c0 100644 --- a/crates/wassette/src/lib.rs +++ b/crates/wassette/src/lib.rs @@ -9,7 +9,7 @@ use component2json::{ component_exports_to_json_schema, component_exports_to_tools, create_placeholder_results, json_to_vals, vals_to_json, FunctionIdentifier, ToolMetadata, }; -use futures::TryStreamExt; +use futures::{future, TryStreamExt}; use policy::{ AccessType, EnvironmentPermission, NetworkHostPermission, NetworkPermission, PolicyParser, StoragePermission, @@ -543,52 +543,108 @@ impl LifecycleManager { let mut components = HashMap::new(); let mut policy_registry = PolicyRegistry::default(); - let loaded_components = + // Scan for component files + let scanned_components = tokio_stream::wrappers::ReadDirStream::new(tokio::fs::read_dir(&plugin_dir).await?) .map_err(anyhow::Error::from) - .try_filter_map(|entry| { - let value = engine.clone(); - async move { load_component_from_entry(value, entry).await } - }) + .try_filter_map(|entry| async move { scan_component_from_entry(entry).await }) .try_collect::>() .await?; - for (component, name) in loaded_components.into_iter() { - let tool_metadata = component_exports_to_tools(&component, &engine, true); - registry - .register_tools(&name, tool_metadata) - .context("unable to insert component into registry")?; - components.insert(name.clone(), Arc::new(component)); - - // Check for co-located policy file and restore policy association - let policy_path = plugin_dir.as_ref().join(format!("{name}.policy.yaml")); - if policy_path.exists() { - match tokio::fs::read_to_string(&policy_path).await { - Ok(policy_content) => match PolicyParser::parse_str(&policy_content) { - Ok(policy) => { - match wasistate::create_wasi_state_template_from_policy( - &policy, - plugin_dir.as_ref(), - ) { - Ok(wasi_template) => { - policy_registry - .component_policies - .insert(name.clone(), Arc::new(wasi_template)); - info!(component_id = %name, "Restored policy association from co-located file"); - } + info!( + "Found {} components to load in parallel", + scanned_components.len() + ); + + // Load all components in parallel for faster startup with parallelization + let component_loading_tasks = scanned_components + .into_iter() + .map(|(component_path, component_id)| { + let engine = engine.clone(); + let plugin_dir = plugin_dir.as_ref().to_path_buf(); + async move { + let start_time = Instant::now(); + + // Load and compile the component in a blocking task + let component = { + let engine = engine.clone(); + let path = component_path.clone(); + tokio::task::spawn_blocking(move || Component::from_file(&engine, path)).await?? + }; + + // Generate tool metadata + let tool_metadata = component_exports_to_tools(&component, &engine, true); + + // Load co-located policy if it exists + let policy_template = { + let policy_path = plugin_dir.join(format!("{component_id}.policy.yaml")); + if policy_path.exists() { + match tokio::fs::read_to_string(&policy_path).await { + Ok(policy_content) => match PolicyParser::parse_str(&policy_content) { + Ok(policy) => { + match wasistate::create_wasi_state_template_from_policy( + &policy, + &plugin_dir, + ) { + Ok(wasi_template) => { + info!(component_id = %component_id, "Loaded policy association from co-located file"); + Some(Arc::new(wasi_template)) + } + Err(e) => { + warn!(component_id = %component_id, error = %e, "Failed to create WASI template from policy"); + None + } + } + } + Err(e) => { + warn!(component_id = %component_id, error = %e, "Failed to parse co-located policy file"); + None + } + }, Err(e) => { - warn!(component_id = %name, error = %e, "Failed to create WASI template from policy"); + warn!(component_id = %component_id, error = %e, "Failed to read co-located policy file"); + None } } + } else { + None } - Err(e) => { - warn!(component_id = %name, error = %e, "Failed to parse co-located policy file"); - } - }, - Err(e) => { - warn!(component_id = %name, error = %e, "Failed to read co-located policy file"); - } + }; + + info!(component_id = %component_id, elapsed = ?start_time.elapsed(), "Component loaded successfully"); + + Ok::<_, anyhow::Error>((component_id, Arc::new(component), tool_metadata, policy_template)) + } + }); + + // Await all component loading tasks in parallel, filtering out failed components + let component_results = future::join_all(component_loading_tasks).await; + let loaded_components: Vec<_> = component_results + .into_iter() + .filter_map(|result| match result { + Ok((component_id, component, tool_metadata, policy_template)) => { + Some((component_id, component, tool_metadata, policy_template)) } + Err(e) => { + warn!(error = %e, "Failed to load component, skipping"); + None + } + }) + .collect(); + + // Now register all loaded components + for (component_id, component, tool_metadata, policy_template) in loaded_components { + registry + .register_tools(&component_id, tool_metadata) + .with_context(|| { + format!("unable to register tools for component {component_id}") + })?; + components.insert(component_id.clone(), component); + + if let Some(policy_template) = policy_template { + policy_registry + .component_policies + .insert(component_id, policy_template); } } @@ -600,7 +656,10 @@ impl LifecycleManager { .await .context("Failed to create downloads directory")?; - info!("LifecycleManager initialized successfully"); + info!( + "LifecycleManager initialized successfully with {} components loaded in parallel", + components.len() + ); Ok(Self { engine, components: Arc::new(RwLock::new(components)), @@ -675,6 +734,7 @@ impl LifecycleManager { pub async fn uninstall_component(&self, id: &str) -> Result<()> { debug!("Uninstalling component"); self.unload_component(id).await; + let component_file = self.component_path(id); tokio::fs::remove_file(&component_file) .await @@ -689,40 +749,43 @@ impl LifecycleManager { #[instrument(skip(self))] pub async fn get_component_id_for_tool(&self, tool_name: &str) -> Result { let registry = self.registry.read().await; - let tool_infos = registry - .get_tool_info(tool_name) - .context("Tool not found")?; - - if tool_infos.len() > 1 { - bail!( - "Multiple components found for tool '{}': {}", - tool_name, - tool_infos - .iter() - .map(|info| info.component_id.as_str()) - .collect::>() - .join(", ") - ); + if let Some(tool_infos) = registry.get_tool_info(tool_name) { + if tool_infos.len() > 1 { + bail!( + "Multiple components found for tool '{}': {}", + tool_name, + tool_infos + .iter() + .map(|info| info.component_id.as_str()) + .collect::>() + .join(", ") + ); + } + Ok(tool_infos[0].component_id.clone()) + } else { + bail!("Tool not found: {}", tool_name); } - - Ok(tool_infos[0].component_id.clone()) } /// Lists all available tools across all components #[instrument(skip(self))] pub async fn list_tools(&self) -> Vec { + // All components are loaded at startup with parallel loading self.registry.read().await.list_tools() } /// Returns the requested component. Returns `None` if the component is not found. #[instrument(skip(self))] pub async fn get_component(&self, component_id: &str) -> Option> { + // All components are loaded at startup with parallel loading self.components.read().await.get(component_id).cloned() } #[instrument(skip(self))] pub async fn list_components(&self) -> Vec { - self.components.read().await.keys().cloned().collect() + // All components are loaded at startup with parallel loading + let components = self.components.read().await; + components.keys().cloned().collect() } /// Gets the schema for a specific component @@ -1236,11 +1299,7 @@ impl LifecycleManager { } } -async fn load_component_from_entry( - engine: Arc, - entry: DirEntry, -) -> Result> { - let start_time = Instant::now(); +async fn scan_component_from_entry(entry: DirEntry) -> Result> { let is_file = entry .metadata() .await @@ -1255,16 +1314,12 @@ async fn load_component_from_entry( return Ok(None); } let entry_path = entry.path(); - let component = - tokio::task::spawn_blocking(move || Component::from_file(&engine, entry_path)).await??; - let name = entry - .path() + let name = entry_path .file_stem() .and_then(|s| s.to_str()) .map(String::from) .context("wasm file didn't have a valid file name")?; - info!(component_id = %name, elapsed = ?start_time.elapsed(), "component loaded"); - Ok(Some((component, name))) + Ok(Some((entry_path, name))) } #[cfg(test)] @@ -1368,6 +1423,76 @@ mod tests { Ok(()) } + #[test(tokio::test)] + async fn test_parallel_loading_performance() -> Result<()> { + let tempdir = tempfile::tempdir()?; + + // Create a mock WASM component file in the directory + let component_path = tempdir.path().join("test_component.wasm"); + std::fs::write(&component_path, b"mock wasm bytes")?; + + let start_time = std::time::Instant::now(); + + // Create a new LifecycleManager - this should load all components in parallel + let manager = LifecycleManager::new(&tempdir).await?; + + let initialization_time = start_time.elapsed(); + + // With parallel loading, initialization will take longer than lazy loading + // but should still be reasonable for a single mock component + println!( + "✅ Parallel loading initialization completed in {:?}", + initialization_time + ); + + // Components should be loaded and available immediately + let components = manager.list_components().await; + // Note: This will be 0 because our mock WASM file is invalid and compilation will fail + // But the parallel loading path is still exercised + assert_eq!( + components.len(), + 0, + "Invalid WASM components should not be loaded" + ); + + // Tools should be available immediately (but empty due to failed compilation) + let tools = manager.list_tools().await; + assert_eq!(tools.len(), 0); + + Ok(()) + } + + #[test(tokio::test)] + async fn test_parallel_loading_component_access() -> Result<()> { + let tempdir = tempfile::tempdir()?; + + // Create a mock WASM component file + let component_path = tempdir.path().join("test_component.wasm"); + std::fs::write(&component_path, b"mock wasm bytes")?; + + let manager = LifecycleManager::new(&tempdir).await?; + + // With parallel loading, components are processed at startup + // but invalid components are filtered out + let components = manager.list_components().await; + assert_eq!( + components.len(), + 0, + "Invalid WASM components should not be loaded" + ); + + // Try to get the component - this should return None since compilation failed + let component_result = manager.get_component("test_component").await; + assert!( + component_result.is_none(), + "Expected None due to invalid WASM" + ); + + println!("✅ Parallel loading correctly filters out invalid components"); + + Ok(()) + } + #[test(tokio::test)] async fn test_load_and_unload_component() -> Result<()> { let manager = create_test_manager().await?;