diff --git a/Cargo.lock b/Cargo.lock
index 29f3895528..87f328ec70 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2187,6 +2187,7 @@ dependencies = [
 name = "carbide-scout"
 version = "0.0.1"
 dependencies = [
+ "axum 0.8.6",
  "carbide-host-support",
  "carbide-libmlx",
  "carbide-machine-validation",
@@ -2209,7 +2210,9 @@ dependencies = [
  "reqwest",
  "serde",
  "serde_json",
+ "sha2 0.10.9",
  "smbios-lib",
+ "tempfile",
  "thiserror 2.0.17",
  "tokio",
  "tokio-stream",
diff --git a/crates/rpc/proto/forge.proto b/crates/rpc/proto/forge.proto
index f7597f4106..af988f8e41 100644
--- a/crates/rpc/proto/forge.proto
+++ b/crates/rpc/proto/forge.proto
@@ -6403,6 +6403,7 @@ message ScoutStreamApiBoundMessage {
     mlx_device.MlxDeviceConfigSyncResponse mlx_device_config_sync_response = 12;
     mlx_device.MlxDeviceConfigCompareResponse mlx_device_config_compare_response = 13;
     ScoutStreamAgentPingResponse scout_stream_agent_ping_response = 14;
+    ScoutRemoteExecResponse scout_remote_exec_response = 15;
   }
 }
 
@@ -6432,6 +6433,7 @@ message ScoutStreamScoutBoundMessage {
     mlx_device.MlxDeviceConfigSyncRequest mlx_device_config_sync_request = 13;
     mlx_device.MlxDeviceConfigCompareRequest mlx_device_config_compare_request = 14;
     ScoutStreamAgentPingRequest scout_stream_agent_ping_request = 15;
+    ScoutRemoteExecRequest scout_remote_exec_request = 16;
   }
 }
 
@@ -6493,6 +6495,36 @@ message ScoutStreamAgentPingResponse {
   }
 }
 
+// ScoutRemoteExecRequest is sent from carbide-api to the scout agent
+// to download files and execute a script on the host.
+message ScoutRemoteExecRequest {
+  string component_type = 1;
+  string target_version = 2;
+  string script_url = 3;
+  // Maximum time the script may run, in seconds. Must be greater than zero.
+  uint32 timeout_seconds = 4;
+  // Files to download before running the script.
+  // Keys are download URLs, values are expected SHA-256 hex checksums.
+  // Scout will verify each file after download and reject execution
+  // if any checksum does not match.
+  map<string, string> download_files = 5;
+}
+
+// ScoutRemoteExecResponse is the result of a scout remote execution,
+// sent from scout back to carbide-api.
+message ScoutRemoteExecResponse {
+  // True only when the script ran and exited with status 0.
+  bool success = 1;
+  // Script exit status, or -1 when none is available (error is set, or
+  // the script was terminated by a signal).
+  int32 exit_code = 2;
+  string stdout = 3;
+  string stderr = 4;
+  // Set when the request was rejected, a download or checksum check
+  // failed, the script could not be run, or it timed out.
+  string error = 5;
+}
+
 // ScoutStreamConnectionInfo contains information about an
 // active scout agent connection.
 message ScoutStreamConnectionInfo {
diff --git a/crates/scout/Cargo.toml b/crates/scout/Cargo.toml
index 396f9c56bd..f21ee5975f 100644
--- a/crates/scout/Cargo.toml
+++ b/crates/scout/Cargo.toml
@@ -66,10 +66,15 @@ reqwest = { default-features = false, features = [
     "rustls-tls",
     "stream",
 ], workspace = true }
+sha2 = { workspace = true }
+tempfile = { workspace = true }
 futures-util = { workspace = true }
 prost-types = { workspace = true }
 x509-parser = { workspace = true }
 
+[dev-dependencies]
+axum = { workspace = true }
+
 [build-dependencies]
 carbide-version = { path = "../version" }
 
diff --git a/crates/scout/src/client.rs b/crates/scout/src/client.rs
index 437dc215ce..7e406c7941 100644
--- a/crates/scout/src/client.rs
+++ b/crates/scout/src/client.rs
@@ -38,3 +38,24 @@ pub(crate) async fn create_forge_client(
         .map_err(|err| CarbideClientError::TransportError(err.to_string()))?;
     Ok(client)
 }
+
+// create_http_client builds a reqwest HTTP client configured with the same
+// mTLS certificates used for gRPC communication with carbide-api.
+pub(crate) fn create_http_client(config: &Options) -> CarbideClientResult<reqwest::Client> {
+    let root_ca = std::fs::read(&config.root_ca)?;
+    let root_cert = reqwest::Certificate::from_pem(&root_ca)
+        .map_err(|e| CarbideClientError::TransportError(e.to_string()))?;
+
+    let client_cert = std::fs::read(&config.client_cert)?;
+    let client_key = std::fs::read(&config.client_key)?;
+    // Join with a newline so a certificate file without a trailing newline
+    // does not fuse its END marker with the key's BEGIN marker.
+    let identity = reqwest::Identity::from_pem(&[client_cert, client_key].join(&b'\n'))
+        .map_err(|e| CarbideClientError::TransportError(e.to_string()))?;
+
+    reqwest::Client::builder()
+        .add_root_certificate(root_cert)
+        .identity(identity)
+        .build()
+        .map_err(|e| CarbideClientError::TransportError(e.to_string()))
+}
diff --git a/crates/scout/src/main.rs b/crates/scout/src/main.rs
index d20d3620f7..a968637547 100644
--- a/crates/scout/src/main.rs
+++ b/crates/scout/src/main.rs
@@ -53,6 +53,7 @@ mod discovery;
 mod machine_validation;
 mod mlx_device;
 mod register;
+mod remote_exec;
 mod stream;
 
 struct DevEnv {
diff --git a/crates/scout/src/remote_exec.rs b/crates/scout/src/remote_exec.rs
new file mode 100644
index 0000000000..ee4024fdb7
--- /dev/null
+++ b/crates/scout/src/remote_exec.rs
@@ -0,0 +1,404 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+use std::path::Path;
+
+use futures_util::TryStreamExt;
+use rpc::forge::{ScoutRemoteExecRequest, ScoutRemoteExecResponse};
+use sha2::{Digest, Sha256};
+use tokio::io::{AsyncReadExt, AsyncWriteExt};
+
+// handle_remote_exec downloads files and a script from carbide-api,
+// then executes the script on the host. Rejected requests and download,
+// checksum, spawn and timeout failures are reported through the `error`
+// field of the response (with exit_code -1) rather than as an Err.
+pub async fn handle_remote_exec(
+    client: &reqwest::Client,
+    request: ScoutRemoteExecRequest,
+) -> ScoutRemoteExecResponse {
+    match run_remote_exec(client, &request).await {
+        Ok(response) => response,
+        Err(e) => ScoutRemoteExecResponse {
+            success: false,
+            exit_code: -1,
+            stdout: String::new(),
+            stderr: String::new(),
+            error: format!("remote execution failed: {e}"),
+        },
+    }
+}
+
+// run_remote_exec performs the download, verify and execute steps for a
+// single request inside a temporary working directory, which is removed
+// when the function returns.
+async fn run_remote_exec(
+    client: &reqwest::Client,
+    request: &ScoutRemoteExecRequest,
+) -> Result<ScoutRemoteExecResponse, Box<dyn std::error::Error + Send + Sync>> {
+    tracing::info!(
+        "[remote_exec] starting for component={} version={}",
+        request.component_type,
+        request.target_version,
+    );
+
+    // A zero timeout (the proto3 default when the field is unset) would kill
+    // the script immediately after spawning it, so reject it before any work
+    // is done on the host.
+    if request.timeout_seconds == 0 {
+        return Err("timeout_seconds must be greater than zero".into());
+    }
+
+    let work_dir = tempfile::tempdir()?;
+
+    // Download the script.
+    let script_path = download_file(client, &request.script_url, work_dir.path()).await?;
+    tracing::info!("[remote_exec] script downloaded to {:?}", script_path);
+
+    // Download files and verify checksums.
+    let download_dir = work_dir.path().join("downloads");
+    tokio::fs::create_dir_all(&download_dir).await?;
+    for (url, expected_sha256) in &request.download_files {
+        let dest = download_file(client, url, &download_dir).await?;
+        let actual = sha256_file(&dest).await?;
+        // Hex digests are case-insensitive; sha256_file always yields lowercase.
+        if !actual.eq_ignore_ascii_case(expected_sha256) {
+            return Err(format!(
+                "checksum mismatch for {url}: expected {expected_sha256}, got {actual}"
+            )
+            .into());
+        }
+        tracing::info!("[remote_exec] checksum verified for {url}");
+    }
+
+    tracing::info!(
+        "[remote_exec] files downloaded. Executing script {:?}",
+        script_path,
+    );
+
+    // Execute the script with env vars for context.
+    // kill_on_drop ensures the child process is terminated if the timeout fires,
+    // preventing orphaned processes and races with tempdir cleanup.
+    let child = tokio::process::Command::new("sh")
+        .arg(&script_path)
+        .env("DOWNLOAD_DIR", &download_dir)
+        .env("COMPONENT_TYPE", &request.component_type)
+        .env("TARGET_VERSION", &request.target_version)
+        .current_dir(work_dir.path())
+        .stdout(std::process::Stdio::piped())
+        .stderr(std::process::Stdio::piped())
+        .kill_on_drop(true)
+        .spawn()?;
+
+    let timeout = std::time::Duration::from_secs(request.timeout_seconds.into());
+    let result = tokio::time::timeout(timeout, child.wait_with_output()).await;
+
+    match result {
+        Ok(Ok(output)) => {
+            // from_utf8_lossy always allocates a new string from the stdout/stderr, even if it's valid utf8.
+            // it's possible the stdout can get quite large, so it's probably best to avoid it in the happy path.
+            let stdout = String::from_utf8(output.stdout)
+                .unwrap_or_else(|e| String::from_utf8_lossy(&e.into_bytes()).into_owned());
+            let stderr = String::from_utf8(output.stderr)
+                .unwrap_or_else(|e| String::from_utf8_lossy(&e.into_bytes()).into_owned());
+            let exit_code = output.status.code().unwrap_or(-1);
+            let success = output.status.success();
+
+            if !stdout.is_empty() {
+                tracing::info!("[remote_exec] stdout: {stdout}");
+            }
+            if !stderr.is_empty() {
+                tracing::warn!("[remote_exec] stderr: {stderr}");
+            }
+
+            Ok(ScoutRemoteExecResponse {
+                success,
+                exit_code,
+                stdout,
+                stderr,
+                error: String::new(),
+            })
+        }
+        Ok(Err(e)) => Err(format!("failed to execute script: {e}").into()),
+        Err(_) => Ok(ScoutRemoteExecResponse {
+            success: false,
+            exit_code: -1,
+            stdout: String::new(),
+            stderr: String::new(),
+            error: format!("script timed out after {} seconds", request.timeout_seconds),
+        }),
+    }
+}
+
+// download_file downloads a file from the given URL into the target directory,
+// preserving the filename from the URL path.
+async fn download_file(
+    client: &reqwest::Client,
+    url: &str,
+    target_dir: &Path,
+) -> Result<std::path::PathBuf, Box<dyn std::error::Error + Send + Sync>> {
+    let parsed = reqwest::Url::parse(url)?;
+    let segment = parsed
+        .path_segments()
+        .and_then(|mut s| s.next_back())
+        .filter(|s| !s.is_empty())
+        .ok_or_else(|| format!("cannot extract filename from URL: {url}"))?;
+
+    let filename = Path::new(segment)
+        .file_name()
+        .ok_or_else(|| format!("invalid filename in URL: {url}"))?;
+
+    let dest = target_dir.join(filename);
+
+    tracing::info!("[remote_exec] downloading {url} -> {dest:?}");
+
+    let response = client.get(url).send().await?.error_for_status()?;
+    let mut stream = response.bytes_stream();
+
+    let mut file = tokio::fs::File::create(&dest).await?;
+    while let Some(chunk) = stream.try_next().await? {
+        file.write_all(&chunk).await?;
+    }
+    file.flush().await?;
+
+    Ok(dest)
+}
+
+// sha256_file returns the lowercase hex SHA-256 digest of the file at `path`.
+// The file is hashed in fixed-size chunks so large firmware images are never
+// held in memory in full.
+async fn sha256_file(path: &Path) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
+    let mut file = tokio::fs::File::open(path).await?;
+    let mut hasher = Sha256::new();
+    let mut buf = vec![0u8; 64 * 1024];
+    loop {
+        let read = file.read(&mut buf).await?;
+        if read == 0 {
+            break;
+        }
+        hasher.update(&buf[..read]);
+    }
+    Ok(format!("{:x}", hasher.finalize()))
+}
+
+#[cfg(test)]
+mod tests {
+    use axum::Router;
+    use axum::routing::get;
+    use tokio::net::TcpListener;
+
+    use super::*;
+
+    // start_file_server spins up a lightweight HTTP server that serves
+    // static content at the given routes. Returns the base URL.
+    async fn start_file_server(routes: Vec<(&'static str, &'static str)>) -> String {
+        let mut app = Router::new();
+        for (path, body) in routes {
+            app = app.route(path, get(move || async move { body }));
+        }
+
+        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+        let addr = listener.local_addr().unwrap();
+        tokio::spawn(async move {
+            axum::serve(listener, app).await.unwrap();
+        });
+
+        format!("http://{addr}")
+    }
+
+    fn sha256_hex(data: &str) -> String {
+        format!("{:x}", Sha256::digest(data.as_bytes()))
+    }
+
+    #[tokio::test]
+    async fn test_successful_upgrade() {
+        let script = "#!/bin/sh\necho \"upgrade complete\"";
+        let firmware_content = "binary-data";
+        let base = start_file_server(vec![
+            ("/scripts/upgrade.sh", script),
+            ("/firmware/blob.bin", firmware_content),
+        ])
+        .await;
+
+        let fw_url = format!("{base}/firmware/blob.bin");
+        let request = ScoutRemoteExecRequest {
+            component_type: "cpld".into(),
+            target_version: "1.2.3".into(),
+            script_url: format!("{base}/scripts/upgrade.sh"),
+            timeout_seconds: 30,
+            download_files: [(fw_url, sha256_hex(firmware_content))]
+                .into_iter()
+                .collect(),
+        };
+
+        let response = handle_remote_exec(&reqwest::Client::new(), request).await;
+
+        assert!(
+            response.success,
+            "expected success, got error: {}",
+            response.error
+        );
+        assert_eq!(response.exit_code, 0);
+        assert!(response.stdout.contains("upgrade complete"));
+        assert!(response.error.is_empty());
+    }
+
+    #[tokio::test]
+    async fn test_script_failure_returns_exit_code() {
+        let script = "#!/bin/sh\necho \"something went wrong\" >&2\nexit 42";
+        let base = start_file_server(vec![("/scripts/fail.sh", script)]).await;
+
+        let request = ScoutRemoteExecRequest {
+            component_type: "bios".into(),
+            target_version: "2.0.0".into(),
+            script_url: format!("{base}/scripts/fail.sh"),
+            timeout_seconds: 30,
+            download_files: Default::default(),
+        };
+
+        let response = handle_remote_exec(&reqwest::Client::new(), request).await;
+
+        assert!(!response.success);
+        assert_eq!(response.exit_code, 42);
+        assert!(response.stderr.contains("something went wrong"));
+    }
+
+    #[tokio::test]
+    async fn test_script_timeout() {
+        let script = "#!/bin/sh\nsleep 60";
+        let base = start_file_server(vec![("/scripts/slow.sh", script)]).await;
+
+        let request = ScoutRemoteExecRequest {
+            component_type: "cpld".into(),
+            target_version: "1.0.0".into(),
+            script_url: format!("{base}/scripts/slow.sh"),
+            timeout_seconds: 1,
+            download_files: Default::default(),
+        };
+
+        let response = handle_remote_exec(&reqwest::Client::new(), request).await;
+
+        assert!(!response.success);
+        assert!(response.error.contains("timed out"));
+    }
+
+    #[tokio::test]
+    async fn test_script_receives_env_vars() {
+        let script =
+            "#!/bin/sh\necho \"comp=$COMPONENT_TYPE ver=$TARGET_VERSION dir=$DOWNLOAD_DIR\"";
+        let base = start_file_server(vec![("/scripts/env.sh", script)]).await;
+
+        let request = ScoutRemoteExecRequest {
+            component_type: "cpldmb".into(),
+            target_version: "3.4.5".into(),
+            script_url: format!("{base}/scripts/env.sh"),
+            timeout_seconds: 30,
+            download_files: Default::default(),
+        };
+
+        let response = handle_remote_exec(&reqwest::Client::new(), request).await;
+
+        assert!(response.success, "error: {}", response.error);
+        assert!(response.stdout.contains("comp=cpldmb"));
+        assert!(response.stdout.contains("ver=3.4.5"));
+        assert!(response.stdout.contains("dir="));
+    }
+
+    #[tokio::test]
+    async fn test_download_failure() {
+        // Point at a URL that will 404.
+        let base = start_file_server(vec![]).await;
+
+        let request = ScoutRemoteExecRequest {
+            component_type: "cpld".into(),
+            target_version: "1.0.0".into(),
+            script_url: format!("{base}/scripts/nonexistent.sh"),
+            timeout_seconds: 30,
+            download_files: Default::default(),
+        };
+
+        let response = handle_remote_exec(&reqwest::Client::new(), request).await;
+
+        assert!(!response.success);
+        assert!(!response.error.is_empty());
+    }
+
+    #[tokio::test]
+    async fn test_checksum_mismatch() {
+        let script = "#!/bin/sh\necho ok";
+        let base = start_file_server(vec![
+            ("/scripts/upgrade.sh", script),
+            ("/firmware/fw.bin", "actual-content"),
+        ])
+        .await;
+
+        let fw_url = format!("{base}/firmware/fw.bin");
+        let request = ScoutRemoteExecRequest {
+            component_type: "cpld".into(),
+            target_version: "1.0.0".into(),
+            script_url: format!("{base}/scripts/upgrade.sh"),
+            timeout_seconds: 30,
+            download_files: [(fw_url, "bad_checksum".to_string())].into_iter().collect(),
+        };
+
+        let response = handle_remote_exec(&reqwest::Client::new(), request).await;
+
+        assert!(!response.success);
+        assert!(response.error.contains("checksum mismatch"));
+    }
+
+    #[tokio::test]
+    async fn test_checksum_is_case_insensitive() {
+        let script = "#!/bin/sh\necho ok";
+        let firmware_content = "binary-data";
+        let base = start_file_server(vec![
+            ("/scripts/upgrade.sh", script),
+            ("/firmware/fw.bin", firmware_content),
+        ])
+        .await;
+
+        let fw_url = format!("{base}/firmware/fw.bin");
+        let request = ScoutRemoteExecRequest {
+            component_type: "cpld".into(),
+            target_version: "1.0.0".into(),
+            script_url: format!("{base}/scripts/upgrade.sh"),
+            timeout_seconds: 30,
+            download_files: [(fw_url, sha256_hex(firmware_content).to_uppercase())]
+                .into_iter()
+                .collect(),
+        };
+
+        let response = handle_remote_exec(&reqwest::Client::new(), request).await;
+
+        assert!(response.success, "error: {}", response.error);
+    }
+
+    #[tokio::test]
+    async fn test_zero_timeout_is_rejected() {
+        let request = ScoutRemoteExecRequest {
+            component_type: "cpld".into(),
+            target_version: "1.0.0".into(),
+            script_url: "http://127.0.0.1:1/scripts/unused.sh".into(),
+            timeout_seconds: 0,
+            download_files: Default::default(),
+        };
+
+        let response = handle_remote_exec(&reqwest::Client::new(), request).await;
+
+        assert!(!response.success);
+        assert!(response.error.contains("timeout_seconds"));
+    }
+}
diff --git a/crates/scout/src/stream.rs b/crates/scout/src/stream.rs
index a2d7a4307b..328c7ae5c3 100644
--- a/crates/scout/src/stream.rs
+++ b/crates/scout/src/stream.rs
@@ -24,7 +24,7 @@ use rpc::protos::forge::{scout_stream_api_bound_message, scout_stream_scout_boun
 use tokio::sync::mpsc;
 
 use crate::cfg::Options;
-use crate::{client, mlx_device};
+use crate::{client, mlx_device, remote_exec};
 
 // ScoutStreamError represents errors that can
 // occur during the life of a scout stream connection.
@@ -90,6 +90,9 @@ async fn run_scout_stream_loop(
         .await
         .map_err(|e| ScoutStreamError::ClientError(e.to_string()))?;
 
+    let http_client = client::create_http_client(options)
+        .map_err(|e| ScoutStreamError::ClientError(e.to_string()))?;
+
     // Create channels for bidirectional streaming.
     let (tx, rx) = mpsc::channel::<ScoutStreamApiBoundMessage>(100);
     let request_stream = tokio_stream::wrappers::ReceiverStream::new(rx);
@@ -132,7 +135,9 @@ async fn run_scout_stream_loop(
 
         // Handle the oneof message type from the ScoutStreamScoutBoundMessage,
         // generating a follow-up ScoutStreamApiBoundMessage "response".
-        let payload = handle_scout_stream_api_bound_message(flow_uuid, machine_id, request);
+        let payload =
+            handle_scout_stream_api_bound_message(&http_client, flow_uuid, machine_id, request)
+                .await;
 
         // And then send the response back to carbide-api.
         if let Err(e) = tx.send(payload).await {
@@ -150,7 +155,8 @@ async fn run_scout_stream_loop(
 
 // handle_scout_stream_api_bound_message routes incoming oneof-based requests
 // to the appropriate handler.
-fn handle_scout_stream_api_bound_message(
+async fn handle_scout_stream_api_bound_message(
+    http_client: &reqwest::Client,
     flow_uuid: uuid::Uuid,
     machine_id: MachineId,
     request: scout_stream_scout_bound_message::Payload,
@@ -258,6 +264,13 @@ fn handle_scout_stream_api_bound_message(
                 scout_stream_api_bound_message::Payload::MlxDeviceConfigCompareResponse(response),
             )
         }
+        scout_stream_scout_bound_message::Payload::ScoutRemoteExecRequest(req) => {
+            let response = remote_exec::handle_remote_exec(http_client, req).await;
+            ScoutStreamApiBoundMessage::from_flow(
+                flow_uuid,
+                scout_stream_api_bound_message::Payload::ScoutRemoteExecResponse(response),
+            )
+        }
     }
 }
 