diff --git a/.clippy.toml b/.clippy.toml index 8987fce2..6b3b5fee 100644 --- a/.clippy.toml +++ b/.clippy.toml @@ -1,3 +1,3 @@ -excessive-nesting-threshold = 4 +excessive-nesting-threshold = 6 too-many-arguments-threshold = 10 allowed-idents-below-min-chars = ["..", "k", "v", "f", "re", "id", "Ok", "'_"] diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index fc9c207a..1d6291d6 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -17,12 +17,12 @@ jobs: - name: Install Rust + components uses: actions-rust-lang/setup-rust-toolchain@v1 with: - toolchain: 1.87 + toolchain: 1.90.0 components: rustfmt,clippy - name: Install code coverage uses: taiki-e/install-action@cargo-llvm-cov - name: Run syntax and style tests - run: cargo clippy --no-default-features --features=test --all-targets -- -D warnings + run: cargo clippy --all-targets -- -D warnings - name: Run format test run: cargo fmt --check - name: Run integration tests w/ coverage report @@ -57,15 +57,11 @@ jobs: uv pip install eclipse-zenoh -p ~/.local/share/base . ~/.local/share/base/bin/activate maturin develop --uv - - name: Run smoke test + - name: Run Python tests env: RUST_BACKTRACE: full run: | . ~/.local/share/base/bin/activate - python tests/extra/python/smoke_test.py -- tests/.tmp - - name: Run agent test - env: - RUST_BACKTRACE: full - run: | - . 
~/.local/share/base/bin/activate - python tests/extra/python/agent_test.py -- tests/.tmp + for py_file in tests/extra/python/*.py; do + python "$py_file" -- tests/.tmp + done diff --git a/.vscode/settings.json b/.vscode/settings.json index 695f87c2..ae1d1abe 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -9,14 +9,10 @@ ], "files.autoSave": "off", "files.insertFinalNewline": true, - "gitlens.showWelcomeOnInstall": false, "gitlens.showWhatsNewAfterUpgrades": false, "lldb.consoleMode": "evaluate", - "rust-analyzer.cargo.features": [ - "test" - ], - "rust-analyzer.cargo.noDefaultFeatures": true, "rust-analyzer.check.command": "clippy", + "rust-analyzer.checkOnSave": true, "rust-analyzer.runnables.extraTestBinaryArgs": [ "--nocapture" ], diff --git a/Cargo.toml b/Cargo.toml index 2232bd2c..be17fcef 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -55,6 +55,7 @@ glob = "0.3.1" heck = "0.5.0" # convert bytes to hex strings hex = "0.4.3" +hostname = "0.4.1" # hashmaps that preserve insertion order indexmap = { version = "2.9.0", features = ["serde"] } # utilities for iterables e.g. 
cartesian products @@ -82,7 +83,7 @@ tokio = { version = "1.41.0", features = ["full"] } # utilities for async calls tokio-util = "0.7.13" # automated CFFI + bindings in other languages -uniffi = { version = "0.29.1", features = ["cli", "tokio"] } +uniffi = { version = "0.29.4", features = ["cli", "tokio"] } # shared, distributed memory via communication zenoh = "1.3.4" diff --git a/cspell.json b/cspell.json index 99f8c978..42697206 100644 --- a/cspell.json +++ b/cspell.json @@ -80,7 +80,9 @@ "wasi", "patchelf", "itertools", - "colinianking" + "colinianking", + "itertools", + "pathset", ], "useGitignore": false, "ignorePaths": [ diff --git a/src/core/crypto.rs b/src/core/crypto.rs index 0c606e19..4a109500 100644 --- a/src/core/crypto.rs +++ b/src/core/crypto.rs @@ -108,3 +108,43 @@ pub fn make_random_hash() -> String { rand::rng().fill_bytes(&mut bytes); hex::encode(bytes) } + +#[cfg(test)] +mod tests { + #![expect(clippy::panic_in_result_fn, reason = "OK in tests.")] + use crate::{ + core::crypto::{hash_buffer, hash_dir, hash_file}, + uniffi::error::Result, + }; + use std::fs::read; + + #[test] + fn consistent_hash() -> Result<()> { + let filepath = "./tests/extra/data/images/subject.jpeg"; + assert_eq!( + hash_file(filepath)?, + hash_buffer(&read(filepath)?), + "Checksum not consistent." + ); + Ok(()) + } + + #[test] + fn complex_hash() -> Result<()> { + let dirpath = "./tests/extra/data/images"; + assert_eq!( + hash_dir(dirpath)?, + "6c96a478ea25e34fab045bc82858a2980b2cfb22db32e83c01349a8e7ed3b42c".to_owned(), + "Directory checksum didn't match." + ); + Ok(()) + } + + #[test] + fn internal_invalid_filepath() { + assert!( + hash_file("nonexistent_file.txt").is_err(), + "Did not raise an invalid filepath error." 
+ ); + } +} diff --git a/src/core/error.rs b/src/core/error.rs index d264814b..52d7449e 100644 --- a/src/core/error.rs +++ b/src/core/error.rs @@ -121,10 +121,23 @@ impl fmt::Debug for OrcaError { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match &self.kind { Kind::AgentCommunicationFailure { backtrace, .. } + | Kind::EmptyDir { backtrace, .. } | Kind::FailedToStartPod { backtrace, .. } + | Kind::FailedToExtractRunInfo { backtrace, .. } | Kind::IncompletePacket { backtrace, .. } | Kind::InvalidPath { backtrace, .. } + | Kind::InvalidIndex { backtrace, .. } + | Kind::InvalidInputSpecNodeNotInGraph { backtrace, .. } + | Kind::InvalidOutputSpecKeyNotInNode { backtrace, .. } + | Kind::InvalidOutputSpecNodeNotInGraph { backtrace, .. } + | Kind::KeyMissing { backtrace, .. } | Kind::MissingInfo { backtrace, .. } + | Kind::FailedToGetLabelHashFromFileName { backtrace, .. } + | Kind::FailedToGetPodJobOutput { backtrace, .. } + | Kind::PipelineValidationErrorMissingKeys { backtrace, .. } + | Kind::PodJobProcessingError { backtrace, .. } + | Kind::PodJobSubmissionFailed { backtrace, .. } + | Kind::UnexpectedPathType { backtrace, .. } | Kind::BollardError { backtrace, .. } | Kind::ChronoParseError { backtrace, .. } | Kind::DOTError { backtrace, .. 
} diff --git a/src/core/graph.rs b/src/core/graph.rs index 8772dcc6..4dd9f643 100644 --- a/src/core/graph.rs +++ b/src/core/graph.rs @@ -1,5 +1,5 @@ use crate::{ - core::{pipeline::PipelineNode, util::get}, + core::{model::pipeline::PipelineNode, util::get}, uniffi::{error::Result, model::pipeline::Kernel}, }; use dot_parser::ast::Graph as DOTGraph; @@ -10,7 +10,6 @@ use petgraph::{ use std::collections::HashMap; #[expect( - clippy::needless_pass_by_value, clippy::panic_in_result_fn, clippy::panic, reason = " @@ -20,15 +19,20 @@ use std::collections::HashMap; )] pub fn make_graph( input_dot: &str, - metadata: HashMap, + metadata: &HashMap, ) -> Result> { let graph = DiGraph::::from_dot_graph(DOTGraph::try_from(input_dot)?).map( - |_, node| PipelineNode { - name: node.id.clone(), - kernel: get(&metadata, &node.id) - .unwrap_or_else(|error| panic!("{error}")) - .clone(), + |node_idx, node| { + let node_id_without_quotes = node.id.replace('"', ""); + PipelineNode { + hash: String::new(), + kernel: get(metadata, &node_id_without_quotes) + .unwrap_or_else(|error| panic!("{error}")) + .clone(), + label: node_id_without_quotes.clone(), + node_idx, + } }, |_, _| (), ); diff --git a/src/core/mod.rs b/src/core/mod.rs index e4bf2a84..9faec897 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -1,40 +1,14 @@ -macro_rules! inner_attr_to_each { - { #!$attr:tt $($it:item)* } => { - $( - #$attr - $it - )* - } -} - pub(crate) mod error; pub(crate) mod graph; -pub(crate) mod pipeline; pub(crate) mod store; pub(crate) mod util; pub(crate) mod validation; -inner_attr_to_each! { - #![cfg(feature = "default")] - pub(crate) mod crypto; - pub(crate) mod model; - pub(crate) mod operator; - pub(crate) mod orchestrator; -} +pub(crate) mod crypto; +/// Model definition for orcapod +pub(crate) mod model; +pub(crate) mod operator; +pub(crate) mod orchestrator; -#[cfg(feature = "test")] -inner_attr_to_each! 
{ - #![cfg_attr( - feature = "test", - allow( - missing_docs, - clippy::missing_errors_doc, - clippy::missing_panics_doc, - reason = "Documentation not necessary since private API.", - ), - )] - pub mod crypto; - pub mod model; - pub mod operator; - pub mod orchestrator; -} +/// Pipeline runner module +pub mod pipeline_runner; diff --git a/src/core/model/mod.rs b/src/core/model/mod.rs index b21a1728..94fa83d2 100644 --- a/src/core/model/mod.rs +++ b/src/core/model/mod.rs @@ -5,31 +5,37 @@ use serde::{Serialize, Serializer}; use serde_yaml::{self, Value}; use std::{ collections::{BTreeMap, HashMap}, + fmt::Debug, hash::BuildHasher, result, }; -/// Converts a model instance into a consistent yaml. -/// -/// # Errors -/// -/// Will return `Err` if there is an issue converting an `instance` into YAML (w/o annotation). -pub fn to_yaml(instance: &T) -> Result { - let mapping: IndexMap = serde_yaml::from_str(&serde_yaml::to_string(instance)?)?; // cast to map - let mut yaml = serde_yaml::to_string( - &mapping - .iter() - .filter_map(|(k, v)| match &**k { - "annotation" | "hash" => None, - "pod" | "pod_job" => Some((k, v["hash"].clone())), - _ => Some((k, v.clone())), - }) - .collect::>(), - )?; // skip fields and convert refs to hash pointers - yaml.insert_str( - 0, - &format!("class: {}\n", get_type_name::().to_snake_case()), - ); // replace class at top - Ok(yaml) + +/// Trait to handle serialization to yaml for `OrcaPod` models +pub trait ToYaml: Serialize + Sized + Debug { + /// Serializes the instance to a YAML string. 
+ /// # Errors + /// Will return `Err` if it fail to serialize instance to string + fn to_yaml(&self) -> Result { + let mapping: IndexMap = serde_yaml::from_str(&serde_yaml::to_string(self)?)?; // cast to map + let mut yaml = serde_yaml::to_string( + &mapping + .iter() + .filter_map(|(k, v)| Self::process_field(k, v)) + .collect::>(), + )?; // skip fields and convert refs to hash pointers + yaml.insert_str( + 0, + &format!("class: {}\n", get_type_name::().to_snake_case()), + ); // replace class at top + Ok(yaml) + } + + /// Filter out which field to serialize and which to omit + /// + /// # Returns + /// (`field_name`, `field_value`): to be pass to `to_yaml` for serialization + /// None: to skip + fn process_field(field_name: &str, field_value: &Value) -> Option<(String, Value)>; } pub fn serialize_hashmap( @@ -57,4 +63,5 @@ where sorted.serialize(serializer) } +pub mod pipeline; pub mod pod; diff --git a/src/core/model/pipeline.rs b/src/core/model/pipeline.rs new file mode 100644 index 00000000..4532e2c8 --- /dev/null +++ b/src/core/model/pipeline.rs @@ -0,0 +1,501 @@ +use std::{ + backtrace::Backtrace, + collections::{BTreeMap, BTreeSet, HashMap, HashSet}, + result, +}; + +use crate::{ + core::{crypto::hash_buffer, model::ToYaml}, + uniffi::{ + error::{Kind, OrcaError, Result, selector}, + model::{ + packet::PathSet, + pipeline::{Kernel, NodeURI, Pipeline, PipelineJob}, + }, + }, +}; +use itertools::Itertools as _; +use petgraph::{ + Direction::Incoming, + graph::{self, NodeIndex}, +}; +use serde::{Deserialize, Serialize, ser::SerializeStruct as _}; +use snafu::OptionExt as _; + +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +pub struct PipelineNode { + // Hash that represent the node + pub hash: String, + /// Kernel associated with the node + pub kernel: Kernel, + /// User provided label for the node + pub label: String, + /// This is meant for internal use only to track the node index in the graph + pub node_idx: NodeIndex, +} + +impl Pipeline 
{ + /// Validate the pipeline to ensure that, based on user labels: + /// 1. Each node's `input_spec` is covered by either its parent nodes or the pipeline's `input_spec` + pub(crate) fn validate(&self) -> Result<()> { + // For verification we check that each node has it's input_spec covered by either it's parent or input_spec of the pipeline + // Build a map from input_spec where HashMap, + let mut keys_covered_by_input_spec_lut: HashMap<&String, HashSet<&String>> = + HashMap::<&String, HashSet<&String>>::new(); + for node_uris in self.input_spec.values() { + for node_uri in node_uris { + keys_covered_by_input_spec_lut + .entry(&node_uri.node_id) + .or_default() + .insert(&node_uri.key); + } + } + + // Iterate over each node in the graph and verify that its input spec is met + for node_idx in self.graph.node_indices() { + self.validate_valid_input_spec( + node_idx, + keys_covered_by_input_spec_lut.get(&self.graph[node_idx].hash), + )?; + } + + // Build a LUT for all node_hash to idx + let node_hash_to_idx_lut: HashMap<&String, NodeIndex> = self + .graph + .node_indices() + .map(|idx| (&self.graph[idx].hash, idx)) + .collect(); + + // Validate that all output_keys are valid + self.output_spec.iter().try_for_each(|(_, node_uri)| { + if !self + .get_output_spec_for_node(*node_hash_to_idx_lut.get(&node_uri.node_id).context( + selector::InvalidOutputSpecNodeNotInGraph { + node_name: node_uri.node_id.clone(), + }, + )?) 
+ .contains(&node_uri.key) + { + return Err(OrcaError { + kind: Kind::InvalidOutputSpecKeyNotInNode { + node_name: node_uri.node_id.clone(), + key: node_uri.key.clone(), + backtrace: Some(Backtrace::capture()), + }, + }); + } + Ok(()) + })?; + + Ok(()) + } + + /// Validates that the input spec for a given node is valid based on its parents and the input spec of the pipeline + fn validate_valid_input_spec( + &self, + node_idx: NodeIndex, + keys_covered_by_input_spec: Option<&HashSet<&String>>, + ) -> Result<()> { + // We need to get the input spec of the current node and build the packet based on the + // parent nodes output spec + input + + // Get the parent nodes input specs and combine them into + let incoming_packet_keys = self + .get_parent_node_indices(node_idx) + .flat_map(|parent_idx| self.get_output_spec_for_node(parent_idx)) + .collect::>(); + + // Get this node input_spec + let missing_keys: HashSet<&String> = self + .get_input_spec_for_node(node_idx) + .into_iter() + .filter(|expected_key| { + !(incoming_packet_keys.contains(expected_key) + || keys_covered_by_input_spec.is_some_and(|keys| keys.contains(expected_key))) + }) + .collect(); + + // Verify that there are no missing keys, otherwise return error + if !missing_keys.is_empty() { + return Err(OrcaError { + kind: Kind::PipelineValidationErrorMissingKeys { + node_name: self.graph[node_idx].label.clone(), + missing_keys: missing_keys.into_iter().cloned().collect(), + backtrace: Some(Backtrace::capture()), + }, + }); + } + Ok(()) + } + + fn get_input_spec_for_node(&self, node_idx: NodeIndex) -> HashSet<&String> { + match &self.graph[node_idx].kernel { + Kernel::Pod { pod } => pod.input_spec.keys().collect(), + Kernel::JoinOperator => { + // JoinOperator input_spec is derived from its parents + self.get_parent_node_indices(node_idx) + .flat_map(|parent_idx| self.get_output_spec_for_node(parent_idx)) + .collect() + } + Kernel::MapOperator { mapper } => mapper.map.keys().collect(), + } + } + + fn 
get_output_spec_for_node(&self, node_idx: NodeIndex) -> HashSet<&String> { + match &self.graph[node_idx].kernel { + Kernel::Pod { pod } => pod.output_spec.keys().collect(), + Kernel::JoinOperator => { + // JoinOperator output_spec is derived from its parents + self.get_parent_node_indices(node_idx) + .flat_map(|parent_idx| self.get_output_spec_for_node(parent_idx)) + .collect() + } + Kernel::MapOperator { mapper } => mapper.map.values().collect(), + } + } + + /// Function to get the parents of a node + pub(crate) fn get_node_parents( + &self, + node: &PipelineNode, + ) -> Result> { + // Find the NodeIndex for the given node_key + let node_idx = self + .graph + .node_indices() + .find(|&idx| self.graph[idx] == *node) + .ok_or(OrcaError { + kind: Kind::KeyMissing { + key: node.label.clone(), + backtrace: Some(Backtrace::capture()), + }, + })?; + + Ok(self + .get_parent_node_indices(node_idx) + .map(|parent_idx| &self.graph[parent_idx])) + } + + /// Return a vec of `node_names` that takes in inputs based on the `input_spec` + pub(crate) fn get_input_nodes(&self) -> HashSet<&String> { + let mut input_nodes = HashSet::new(); + + self.input_spec.iter().for_each(|(_, node_uris)| { + for node_uri in node_uris { + input_nodes.insert(&node_uri.node_id); + } + }); + + input_nodes + } + + fn get_parent_node_indices(&self, node_idx: NodeIndex) -> impl Iterator { + self.graph.neighbors_directed(node_idx, Incoming) + } + + /// Find the leaf nodes in the graph (nodes with no outgoing edges) + /// # Returns + /// A vector of `NodeIndex` representing the leaf nodes in the graph + pub fn find_leaf_nodes(&self) -> Vec { + self.graph + .node_indices() + .filter(|&idx| { + self.graph + .neighbors_directed(idx, petgraph::Direction::Outgoing) + .next() + .is_none() + }) + .collect() + } + + /// Compute the hash for each node in the graph which is defined as the hash of its kernel + the hashes of its parents + pub(crate) fn compute_hash_for_node_and_parents( + node_idx: NodeIndex, + 
input_spec: &HashMap>, + graph: &mut graph::Graph, + ) { + // Collect parent indices first to avoid borrowing issues + let parent_indices: Vec = graph.neighbors_directed(node_idx, Incoming).collect(); + + // Sort the parent hashes to ensure consistent ordering + let mut parent_hashes: Vec = if parent_indices.is_empty() { + // This is parent node, thus we will need to use the input_spec to generate a unique hash for the node + // Find all the input keys that map to this node + let input_keys = input_spec.iter().filter_map(|(input_key, node_uris)| { + node_uris.iter().find_map(|node_uri| { + (node_uri.node_id == graph[node_idx].label).then(|| input_key.clone()) + }) + }); + + input_keys.collect() + } else { + parent_indices + .into_iter() + .map(|parent_idx| { + // Check if hash has been computed for this node, if not trigger computation + if graph[parent_idx].hash.is_empty() { + // Recursive call to compute the parent's hash + Self::compute_hash_for_node_and_parents(parent_idx, input_spec, graph); + } + graph[parent_idx].hash.clone() + }) + .collect() + }; + + parent_hashes.sort(); + + // Combine the node's kernel hash + the parent_hashes by concatenation only if there are parents hashes, else it is just the kernel hash + if parent_hashes.is_empty() { + } else { + let hash_for_node = format!( + "{}{}", + &graph[node_idx].kernel.get_hash(), + parent_hashes.into_iter().join("") + ); + graph[node_idx].hash = hash_buffer(hash_for_node.as_bytes()); + } + } + + fn to_dot_lex(&self) -> String { + // Get all the nodes and their children in lexicographical order + let nodes_and_edges = self.graph.node_indices().fold( + BTreeMap::<&String, BTreeSet<&String>>::new(), + |mut acc, node_idx| { + let children = self + .graph + .neighbors_directed(node_idx, petgraph::Direction::Outgoing) + .map(|child_idx| &self.graph[child_idx].hash) + .collect::>(); + acc.insert(&self.graph[node_idx].hash, children); + acc + }, + ); + + // Build the dot representation string by + let mut lines = 
Vec::new(); + for (node, node_children) in nodes_and_edges { + if node_children.is_empty() { + lines.push(format!(" \"{node}\"")); + } else { + for child in node_children { + lines.push(format!(" \"{node}\" -> \"{child}\"")); + } + } + } + + // Convert lines into a single string with a proper new line between each entry + format!("digraph {{\n{}\n}}", lines.join("\n")) + } + + /// Get a `BTreeMap` of <`kernel_hash`, `BTreeSet`<`node_hashes`>> for all nodes in the graph. Mainly use for serialization + pub(crate) fn get_kernel_to_node_lut(&self) -> BTreeMap> { + self.graph.node_indices().fold( + BTreeMap::>::new(), + |mut acc, node_idx| { + acc.entry(self.graph[node_idx].kernel.get_hash().to_owned()) + .or_default() + .insert(self.graph[node_idx].hash.clone()); + acc + }, + ) + } + + pub(crate) fn get_kernel_lut(&self) -> HashSet<&Kernel> { + self.graph + .node_indices() + .fold(HashSet::<&Kernel>::new(), |mut acc, node_idx| { + acc.insert(&self.graph[node_idx].kernel); + acc + }) + } + + /// Get a `HashMap` of <`node_hash`, `node_label`> for all nodes in the graph if label is not empty. 
Mainly use for serialization + pub(crate) fn get_label_lut(&self) -> impl Iterator { + self.graph.node_indices().filter_map(|node_idx| { + let label = &self.graph[node_idx].label; + if label.is_empty() { + None + } else { + Some((&self.graph[node_idx].hash, label)) + } + }) + } +} + +impl Serialize for Pipeline { + fn serialize(&self, serializer: S) -> result::Result + where + S: serde::Serializer, + { + let mut state = serializer.serialize_struct("Pipeline", 4)?; + state.serialize_field("kernel_lut", &self.get_kernel_to_node_lut())?; + state.serialize_field("dot", &self.to_dot_lex())?; + + // Input spec needs to be sorted for consistent serialization + let input_spec_sorted: BTreeMap<_, Vec> = self + .input_spec + .iter() + .map(|(k, v)| { + let mut sorted_v = v.clone(); + sorted_v.sort(); + (k, sorted_v) + }) + .collect(); + state.serialize_field("input_spec", &input_spec_sorted)?; + state.serialize_field("output_spec", &self.output_spec)?; + state.end() + } +} + +impl ToYaml for Pipeline { + fn process_field( + field_name: &str, + field_value: &serde_yaml::Value, + ) -> Option<(String, serde_yaml::Value)> { + match field_name { + "hash" | "annotation" => None, // Skip annotation field + _ => Some((field_name.to_owned(), field_value.clone())), + } + } +} + +impl PipelineJob { + /// Helpful function to get the input packet for input nodes of the pipeline based on the `pipeline_job` an`pipeline_spec`ec + /// # Errors + /// Will return `Err` if there is an issue getting the input packet per node. + /// # Returns + /// A `HashMap` where the key is the node name and the value is a vector of `HashMap` representing the input packets for that node. 
+ pub fn get_input_packet_per_node( + &self, + ) -> Result>>> { + // For each node in the input specification, we will iterate over its mapping + // nodes_input_spec contains > + let mut nodes_input_spec = HashMap::new(); + for (input_key, node_uris) in &self.pipeline.input_spec { + for node_uri in node_uris { + let input_path_sets = self.input_packet.get(input_key).ok_or(OrcaError { + kind: Kind::KeyMissing { + key: input_key.clone(), + backtrace: Some(Backtrace::capture()), + }, + })?; + // There shouldn't be a duplicate key in the input packet as this will be handle by pipeline verify + let input_spec = nodes_input_spec + .entry(&node_uri.node_id) + .or_insert_with(HashMap::new); + input_spec.insert(&node_uri.key, input_path_sets); + } + } + + // For each node, compute the cartesian product of the path_sets for each unique combination of keys + let node_input_packets = nodes_input_spec + .into_iter() + .map(|(node_id, input_node_keys)| { + // We need to pull them out at the same time to ensure the key order is preserve to match the cartesian product + let (keys, values): (Vec<_>, Vec<_>) = input_node_keys.into_iter().unzip(); + + // Covert each combo into a packet + let packets = values + .into_iter() + .multi_cartesian_product() + .map(|combo| { + keys.iter() + .copied() + .zip(combo) + .map(|(key, pathset)| (key.to_owned(), pathset.to_owned())) + .collect::>() + }) + .collect::>>(); + + (node_id.to_owned(), packets) + }) + .collect::>(); + + Ok(node_input_packets) + } +} + +#[cfg(test)] +mod tests { + use crate::{ + core::model::ToYaml as _, + uniffi::{ + error::Result, + model::{ + Annotation, + pipeline::{NodeURI, Pipeline}, + }, + operator::MapOperator, + }, + }; + use indoc::indoc; + use pretty_assertions::assert_eq; + use std::collections::HashMap; + + #[test] + fn to_yaml() -> Result<()> { + let pipeline = Pipeline::new( + indoc! 
{" + digraph { + A -> B -> C + } + "}, + &HashMap::from([ + ( + "A".into(), + MapOperator::new(HashMap::from([("node_key_1".into(), "node_key_2".into())]))? + .into(), + ), + ( + "B".into(), + MapOperator::new(HashMap::from([("node_key_2".into(), "node_key_1".into())]))? + .into(), + ), + ( + "C".into(), + MapOperator::new(HashMap::from([("node_key_1".into(), "node_key_2".into())]))? + .into(), + ), + ]), + HashMap::from([( + "pipeline_key_1".into(), + vec![NodeURI { + node_id: "A".into(), + key: "node_key_1".into(), + }], + )]), + HashMap::new(), + Some(Annotation { + name: "test".into(), + version: "0.1".into(), + description: "Test pipeline".into(), + }), + )?; + + assert_eq!( + pipeline.to_yaml()?, + indoc! {r#" + class: pipeline + kernel_lut: + 2980eb39e3702442cc31656d6ec3995f91680ab042a27160a00ffe33b91419af: + - 4b498582ed57ca6a10809d7480bd3f159542ad139e402698e3f525fc6b0d4dea + c8f036079b69beee914434c1e01be638972ce05cd2e640fc1e9be7bf3d9e76be: + - 368c7a517f3fbdd7cab10c90ebc44e44765fc33f66b4f4f5151a6bee322d8217 + - 6c84111298d0cfe811dff3f10ab444795c0a9d60609ba1b9391c45e642a69afa + dot: |- + digraph { + "368c7a517f3fbdd7cab10c90ebc44e44765fc33f66b4f4f5151a6bee322d8217" + "4b498582ed57ca6a10809d7480bd3f159542ad139e402698e3f525fc6b0d4dea" -> "368c7a517f3fbdd7cab10c90ebc44e44765fc33f66b4f4f5151a6bee322d8217" + "6c84111298d0cfe811dff3f10ab444795c0a9d60609ba1b9391c45e642a69afa" -> "4b498582ed57ca6a10809d7480bd3f159542ad139e402698e3f525fc6b0d4dea" + } + input_spec: + pipeline_key_1: + - node_id: 6c84111298d0cfe811dff3f10ab444795c0a9d60609ba1b9391c45e642a69afa + key: node_key_1 + output_spec: {} + "#}, + ); + + Ok(()) + } +} diff --git a/src/core/model/pod.rs b/src/core/model/pod.rs index a2ebb0fd..fdcafc13 100644 --- a/src/core/model/pod.rs +++ b/src/core/model/pod.rs @@ -48,3 +48,218 @@ where }, ) } + +#[cfg(test)] +mod tests { + #![expect(clippy::unwrap_used, reason = "OK in tests.")] + use indoc::indoc; + use std::sync::{Arc, LazyLock}; + use 
std::{collections::HashMap, path::PathBuf}; + + use crate::core::model::ToYaml as _; + use crate::uniffi::model::packet::{Blob, BlobKind, PathSet, URI}; + use crate::uniffi::model::pod::PodResult; + use crate::uniffi::orchestrator::PodStatus; + use crate::uniffi::{ + error::Result, + model::{ + Annotation, + packet::PathInfo, + pod::{Pod, PodJob, RecommendSpecs}, + }, + }; + + use pretty_assertions::assert_eq; + + static TEST_FILE_NAMESPACE_LOOKUP: LazyLock> = LazyLock::new(|| { + HashMap::from([ + ("input".into(), PathBuf::from("tests/extra/data/input_txt")), + ("output".into(), PathBuf::from("tests/extra/data/output")), + ]) + }); + + fn basic_pod() -> Result { + Pod::new( + Some(Annotation { + name: "test".into(), + version: "0.1".into(), + description: "Basic pod for testing hashing and yaml serialization".into(), + }), + "alpine:3.14".into(), + vec!["cp", "/input/input.txt", "/output/output.txt"] + .into_iter() + .map(String::from) + .collect(), + HashMap::from([( + "input_txt".into(), + PathInfo { + path: "/input/input.txt".into(), + match_pattern: r".*\.txt".into(), + }, + )]), + "/output".into(), + HashMap::from([( + "output_txt".into(), + PathInfo { + path: "output.txt".into(), + match_pattern: r".*\.txt".into(), + }, + )]), + RecommendSpecs { + cpus: 0.20, + memory: 128 << 20, + }, + None, + ) + } + + fn basic_pod_job() -> Result { + let pod = Arc::new(basic_pod()?); + PodJob::new( + Some(Annotation { + name: "test_job".into(), + version: "0.1".into(), + description: "Basic pod job for testing hashing and yaml serialization".into(), + }), + Arc::clone(&pod), + HashMap::from([( + "input_txt".into(), + PathSet::Unary(Blob::new( + BlobKind::File, + URI { + namespace: "input".into(), + path: "cat.txt".into(), + }, + )), + )]), + URI { + namespace: "output".into(), + path: "".into(), + }, + pod.recommend_specs.cpus, + pod.recommend_specs.memory, + Some(HashMap::from([("FAKE_ENV".into(), "FakeValue".into())])), + &TEST_FILE_NAMESPACE_LOOKUP, + ) + } + + fn 
basic_pod_result() -> Result { + PodResult::new( + Some(Annotation { + name: "test".into(), + version: "0.1".into(), + description: "Basic Result for testing hashing and yaml serialization".into(), + }), + basic_pod_job()?.into(), + "randomly_assigned_name".into(), + PodStatus::Completed, + 1_737_922_307, + 1_737_925_907, + &TEST_FILE_NAMESPACE_LOOKUP, + "example_logs".to_owned(), + ) + } + + #[test] + fn pod_hash() { + assert_eq!( + basic_pod().unwrap().hash, + "b5574e2efdf26361e8e8e886389a250cfbfcceed08b29325a78fd738cbb2a1b8", + "Hash didn't match." + ); + } + + #[test] + fn pod_to_yaml() { + assert_eq!( + basic_pod().unwrap().to_yaml().unwrap(), + indoc! {r" + class: pod + image: alpine:3.14 + command: + - cp + - /input/input.txt + - /output/output.txt + input_spec: + input_txt: + path: /input/input.txt + match_pattern: .*\.txt + output_dir: /output + output_spec: + output_txt: + path: output.txt + match_pattern: .*\.txt + gpu_requirements: null + "}, + "YAML serialization didn't match." + ); + } + + #[test] + fn pod_job_hash() { + assert_eq!( + basic_pod_job().unwrap().hash, + "80348a4ef866a9dfc1a5d0a48467a6592ef2ed9e8de67930d64afefbb395f1c6", + "Hash didn't match." + ); + } + + #[test] + fn pod_job_to_yaml() { + assert_eq!( + basic_pod_job().unwrap().to_yaml().unwrap(), + indoc! {" + class: pod_job + pod: b5574e2efdf26361e8e8e886389a250cfbfcceed08b29325a78fd738cbb2a1b8 + input_packet: + input_txt: + kind: File + location: + namespace: input + path: cat.txt + checksum: 175cc6f362b2f75acd08a373e000144fdb8d14a833d4b70fd743f16a7039103f + output_dir: + namespace: output + path: '' + cpu_limit: 0.2 + memory_limit: 134217728 + env_vars: + FAKE_ENV: FakeValue + "}, + "YAML serialization didn't match." + ); + } + + #[test] + fn pod_result_hash() { + assert_eq!( + basic_pod_result().unwrap().hash, + "92809a4ce13b4fe8c8dcdcf2b48dd14a9dd885593fe3ab5d9809d27bc9a16354", + "Hash didn't match." 
+ ); + } + + #[test] + fn pod_result_to_yaml() { + assert_eq!( + basic_pod_result().unwrap().to_yaml().unwrap(), + indoc! {" + class: pod_result + pod_job: 80348a4ef866a9dfc1a5d0a48467a6592ef2ed9e8de67930d64afefbb395f1c6 + output_packet: + output_txt: + kind: File + location: + namespace: output + path: output.txt + checksum: 175cc6f362b2f75acd08a373e000144fdb8d14a833d4b70fd743f16a7039103f + assigned_name: randomly_assigned_name + status: Completed + created: 1737922307 + terminated: 1737925907 + logs: example_logs + "}, + "YAML serialization didn't match." + ); + } +} diff --git a/src/core/operator.rs b/src/core/operator.rs index 64db6206..55f72bed 100644 --- a/src/core/operator.rs +++ b/src/core/operator.rs @@ -1,4 +1,7 @@ -use crate::uniffi::{error::Result, model::packet::Packet}; +use crate::{ + core::model::ToYaml, + uniffi::{error::Result, model::packet::Packet, operator::MapOperator}, +}; use async_trait; use itertools::Itertools as _; use std::{clone::Clone, collections::HashMap, iter::IntoIterator, sync::Arc}; @@ -6,7 +9,7 @@ use tokio::sync::Mutex; #[async_trait::async_trait] pub trait Operator { - async fn next(&self, stream_name: String, packet: Packet) -> Result>; + async fn process_packet(&self, stream_name: String, packet: Packet) -> Result>; } pub struct JoinOperator { @@ -25,7 +28,7 @@ impl JoinOperator { #[async_trait::async_trait] impl Operator for JoinOperator { - async fn next(&self, stream_name: String, packet: Packet) -> Result> { + async fn process_packet(&self, stream_name: String, packet: Packet) -> Result> { let mut received_packets = self.received_packets.lock().await; received_packets .entry(stream_name.clone()) @@ -61,19 +64,9 @@ impl Operator for JoinOperator { } } -pub struct MapOperator { - map: HashMap, -} - -impl MapOperator { - pub fn new(map: &HashMap) -> Self { - Self { map: map.clone() } - } -} - #[async_trait::async_trait] impl Operator for MapOperator { - async fn next(&self, _: String, packet: Packet) -> Result> { + async 
fn process_packet(&self, _: String, packet: Packet) -> Result> { Ok(vec![ packet .iter() @@ -89,3 +82,232 @@ impl Operator for MapOperator { ]) } } + +impl ToYaml for MapOperator { + fn process_field( + field_name: &str, + field_value: &serde_yaml::Value, + ) -> Option<(String, serde_yaml::Value)> { + match field_name { + "hash" => None, + _ => Some((field_name.to_owned(), field_value.clone())), + } + } +} + +#[cfg(test)] +mod tests { + #![expect(clippy::panic_in_result_fn, reason = "OK in tests.")] + + use crate::{ + core::operator::{JoinOperator, MapOperator, Operator}, + uniffi::{ + error::Result, + model::packet::{Blob, BlobKind, Packet, PathSet, URI}, + }, + }; + use std::{collections::HashMap, path::PathBuf}; + + fn make_packet_key(key_name: String, filepath: String) -> (String, PathSet) { + ( + key_name, + PathSet::Unary(Blob { + kind: BlobKind::File, + location: URI { + namespace: "default".into(), + path: PathBuf::from(filepath), + }, + checksum: String::new(), + }), + ) + } + + async fn next_batch( + operator: impl Operator, + packets: Vec<(String, Packet)>, + ) -> Result> { + let mut next_packets = vec![]; + for (stream_name, packet) in packets { + next_packets.extend(operator.process_packet(stream_name, packet).await?); + } + Ok(next_packets) + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn join_once() -> Result<()> { + let operator = JoinOperator::new(2); + + let left_stream = (0..3) + .map(|i| { + ( + "left".into(), + Packet::from([make_packet_key( + "subject".into(), + format!("left/subject{i}.png"), + )]), + ) + }) + .collect::>(); + + let right_stream = (0..2) + .map(|i| { + ( + "right".into(), + Packet::from([make_packet_key( + "style".into(), + format!("right/style{i}.t7"), + )]), + ) + }) + .collect::>(); + + let mut input_streams = left_stream; + input_streams.extend(right_stream); + + assert_eq!( + next_batch(operator, input_streams).await?, + vec![ + Packet::from([ + make_packet_key("subject".into(), 
"left/subject0.png".into()), + make_packet_key("style".into(), "right/style0.t7".into()), + ]), + Packet::from([ + make_packet_key("subject".into(), "left/subject1.png".into()), + make_packet_key("style".into(), "right/style0.t7".into()), + ]), + Packet::from([ + make_packet_key("subject".into(), "left/subject2.png".into()), + make_packet_key("style".into(), "right/style0.t7".into()), + ]), + Packet::from([ + make_packet_key("subject".into(), "left/subject0.png".into()), + make_packet_key("style".into(), "right/style1.t7".into()), + ]), + Packet::from([ + make_packet_key("subject".into(), "left/subject1.png".into()), + make_packet_key("style".into(), "right/style1.t7".into()), + ]), + Packet::from([ + make_packet_key("subject".into(), "left/subject2.png".into()), + make_packet_key("style".into(), "right/style1.t7".into()), + ]), + ], + "Unexpected streams." + ); + + Ok(()) + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn join_spotty() -> Result<()> { + let operator = JoinOperator::new(2); + + assert_eq!( + operator + .process_packet( + "right".into(), + Packet::from([make_packet_key("style".into(), "right/style0.t7".into())]) + ) + .await?, + vec![], + "Unexpected streams." + ); + + assert_eq!( + operator + .process_packet( + "right".into(), + Packet::from([make_packet_key("style".into(), "right/style1.t7".into())]) + ) + .await?, + vec![], + "Unexpected streams." + ); + + assert_eq!( + operator + .process_packet( + "left".into(), + Packet::from([make_packet_key( + "subject".into(), + "left/subject0.png".into() + )]) + ) + .await?, + vec![ + Packet::from([ + make_packet_key("subject".into(), "left/subject0.png".into()), + make_packet_key("style".into(), "right/style0.t7".into()), + ]), + Packet::from([ + make_packet_key("subject".into(), "left/subject0.png".into()), + make_packet_key("style".into(), "right/style1.t7".into()), + ]), + ], + "Unexpected streams." 
+ ); + + assert_eq!( + next_batch( + operator, + (1..3) + .map(|i| { + ( + "left".into(), + Packet::from([make_packet_key( + "subject".into(), + format!("left/subject{i}.png"), + )]), + ) + }) + .collect::>() + ) + .await?, + vec![ + Packet::from([ + make_packet_key("subject".into(), "left/subject1.png".into()), + make_packet_key("style".into(), "right/style0.t7".into()), + ]), + Packet::from([ + make_packet_key("subject".into(), "left/subject1.png".into()), + make_packet_key("style".into(), "right/style1.t7".into()), + ]), + Packet::from([ + make_packet_key("subject".into(), "left/subject2.png".into()), + make_packet_key("style".into(), "right/style0.t7".into()), + ]), + Packet::from([ + make_packet_key("subject".into(), "left/subject2.png".into()), + make_packet_key("style".into(), "right/style1.t7".into()), + ]), + ], + "Unexpected streams." + ); + + Ok(()) + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn map_once() -> Result<()> { + let operator = MapOperator::new(HashMap::from([("key_old".into(), "key_new".into())]))?; + + assert_eq!( + operator + .process_packet( + "parent".into(), + Packet::from([ + make_packet_key("key_old".into(), "some/key.txt".into()), + make_packet_key("subject".into(), "some/subject.txt".into()), + ]), + ) + .await?, + vec![Packet::from([ + make_packet_key("key_new".into(), "some/key.txt".into()), + make_packet_key("subject".into(), "some/subject.txt".into()), + ]),], + "Unexpected packet." 
+ ); + + Ok(()) + } +} diff --git a/src/core/pipeline.rs b/src/core/pipeline.rs deleted file mode 100644 index 494e4919..00000000 --- a/src/core/pipeline.rs +++ /dev/null @@ -1,8 +0,0 @@ -use crate::uniffi::model::pipeline::Kernel; -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct PipelineNode { - pub name: String, - pub kernel: Kernel, -} diff --git a/src/core/pipeline_runner.rs b/src/core/pipeline_runner.rs new file mode 100644 index 00000000..4f836290 --- /dev/null +++ b/src/core/pipeline_runner.rs @@ -0,0 +1,905 @@ +use crate::{ + core::{ + crypto::hash_buffer, + model::{pipeline::PipelineNode, serialize_hashmap}, + operator::{JoinOperator, Operator}, + util::{get, make_key_expr}, + }, + uniffi::{ + error::{ + Kind, OrcaError, Result, + selector::{self}, + }, + model::{ + packet::{Packet, PathSet, URI}, + pipeline::{Kernel, PipelineJob, PipelineResult, PipelineStatus}, + pod::{Pod, PodJob, PodResult}, + }, + orchestrator::{ + PodStatus, + agent::{Agent, AgentClient, Response}, + }, + }, +}; +use async_trait::async_trait; +use names::{Generator, Name}; +use serde_yaml::Serializer; +use snafu::{OptionExt as _, ResultExt as _}; +use std::{ + collections::{BTreeMap, HashMap}, + path::PathBuf, + sync::Arc, +}; +use tokio::{ + sync::{Mutex, RwLock}, + task::JoinSet, +}; + +static NODE_OUTPUT_KEY_EXPR: &str = "output"; +static FAILURE_KEY_EXP: &str = "failure"; + +/// Internal representation of a pipeline run, which should not be made public due to the fact that it contains +#[derive(Debug)] +struct PipelineRunInternal { + /// `PipelineJob` that this run is associated with + assigned_name: String, + session: Arc, // Zenoh session for communication + agent_client: Arc, // Zenoh agent client for communication with docker orchestrators + pipeline_job: Arc, // The pipeline job that this run is associated with + node_tasks: Arc>>>, // JoinSet of tasks for each node in the pipeline + outputs: Arc>>>, // String is the node 
key, while hash + failure_logs: Arc>>, // Logs of processing failures + failure_logging_task: Arc>>>, // JoinSet of tasks for logging failures + namespace: String, + namespace_lookup: HashMap, +} + +impl PipelineRunInternal { + fn make_key_expr(&self, node_id: &str, event: &str) -> String { + make_key_expr( + &self.agent_client.group, + &self.agent_client.host, + "pipeline_run", + &BTreeMap::from([ + ("event".to_owned(), event.to_owned()), + ("node_id".to_owned(), node_id.to_owned()), + ("pipeline_run_id".to_owned(), self.assigned_name.clone()), + ]), + ) + } + + fn make_abort_request_key_exp(&self) -> String { + make_key_expr( + &self.agent_client.group, + &self.agent_client.host, + "pipeline_run", + &BTreeMap::from([ + ("event".to_owned(), "abort".to_owned()), + ("pipeline_run_id".to_owned(), self.assigned_name.clone()), + ]), + ) + } + + // Utils functions + async fn send_packets(&self, node_id: &str, output_packets: &Vec) -> Result<()> { + Ok(self + .session + .put( + self.make_key_expr(node_id, NODE_OUTPUT_KEY_EXPR), + serde_json::to_string(output_packets)?, + ) + .await + .context(selector::AgentCommunicationFailure {})?) + } + + async fn send_err_msg(&self, node_id: &str, err: OrcaError) { + self.session + .put( + &self.make_key_expr(node_id, FAILURE_KEY_EXP), + format!("Node {node_id}: {err}"), + ) + .await + .context(selector::AgentCommunicationFailure {}) + .unwrap_or_else(|send_err| { + eprintln!("Failed to send error message for node {node_id}: {send_err}"); + }); + } + + async fn send_abort_request(&self) -> Result<()> { + Ok(self + .session + .put(&self.make_abort_request_key_exp(), vec![]) + .await + .context(selector::AgentCommunicationFailure {})?) 
+ } + + async fn get_status(&self) -> PipelineStatus { + if !self.node_tasks.lock().await.is_empty() { + PipelineStatus::Running + } else if self.outputs.read().await.is_empty() { + PipelineStatus::Failed + } else if self.failure_logs.read().await.is_empty() { + PipelineStatus::Succeeded + } else { + PipelineStatus::PartiallySucceeded + } + } +} + +/// Runner that uses a docker agent to run pipelines +#[derive(Debug, Clone)] +pub struct DockerPipelineRunner { + agent: Arc, + pipeline_runs: HashMap>, +} + +/// This is an implementation of a pipeline runner that uses Zenoh to communicate between the tasks +/// The runtime is tokio +/// +/// These are the key expressions of the components of the pipeline: +/// Input Node: `pipeline_job_hash/input_node/outputs` (This is where the `pipeline_job` packets get fed to) +/// Nodes: `pipeline_job_hash/node_id/outputs/(success|failure)` (This is where the node outputs are sent to) +/// +impl DockerPipelineRunner { + /// Create a new Docker pipeline runner + /// # Errors + /// Will error out if the environment variable `HOSTNAME` is not set + pub fn new(agent: Arc) -> Self { + Self { + agent, + pipeline_runs: HashMap::new(), + } + } + + /// Will start a new pipeline run with the given `PipelineJob` + /// This will start the async tasks for each node in the pipeline + /// including the one that captures the outputs from the leaf nodes + /// + /// Upon receiving the ready message from all the nodes, it will send the input packets to the input node + /// + /// # Errors + /// Will error out if the pipeline job fails to start + pub async fn start( + &mut self, + pipeline_job: PipelineJob, + namespace: &str, // Name space to save pod_results to + namespace_lookup: &HashMap, + ) -> Result { + // Create a new pipeline run + let pipeline_run = Arc::new(PipelineRunInternal { + pipeline_job: pipeline_job.into(), + outputs: Arc::new(RwLock::new(HashMap::new())), + node_tasks: Arc::new(Mutex::new(JoinSet::new())), + failure_logs: 
Arc::new(RwLock::new(Vec::new())), + failure_logging_task: Arc::new(Mutex::new(JoinSet::new())), + assigned_name: Generator::with_naming(Name::Plain).next().context( + selector::MissingInfo { + details: "unable to generate a random name", + }, + )?, + session: Arc::clone(&self.agent.client.session), + agent_client: Arc::clone(&self.agent.client), + namespace: namespace.to_owned(), + namespace_lookup: namespace_lookup.clone(), + }); + + // Create failure logging task + pipeline_run + .failure_logging_task + .lock() + .await + .spawn(Self::failure_capture_task(Arc::clone(&pipeline_run))); + + // Create the processor task for each node + // The id for the pipeline_run is the pipeline_job hash + let pipeline_run_id = pipeline_run.pipeline_job.hash.clone(); + + let graph = &pipeline_run.pipeline_job.pipeline.graph; + + // Create the subscriber that listen for ready messages + let subscriber = pipeline_run + .session + .declare_subscriber(pipeline_run.make_key_expr("*", "node_ready")) + .await + .context(selector::AgentCommunicationFailure {})?; + + // Get the set of input_nodes + let input_nodes = pipeline_run.pipeline_job.pipeline.get_input_nodes(); + + // Iterate through each node in the graph and spawn a task for each + for node_idx in graph.node_indices() { + let node = &graph[node_idx]; + + // Spawn the task + pipeline_run + .node_tasks + .lock() + .await + .spawn(Self::spawn_node_processing_task( + graph[node_idx].clone(), + Arc::clone(&pipeline_run), + input_nodes.contains(&node.hash), + )); + } + + // Spawn the task that captures the outputs based on the output_spec + let mut node_output_spec = HashMap::new(); + // Group the output spec by node + for (output_key, node_uri) in &pipeline_run.pipeline_job.pipeline.output_spec { + node_output_spec + .entry(node_uri.node_id.clone()) + .or_insert_with(HashMap::new) + .insert(output_key.clone(), node_uri.key.clone()); + } + + for (node_id, key_mapping) in node_output_spec { + // Spawn the task that captures the outputs 
+ pipeline_run + .node_tasks + .lock() + .await + .spawn(Self::create_output_capture_task_for_node( + key_mapping, + Arc::clone(&pipeline_run), + node_id.clone(), + )); + } + + // Wait for all nodes to be ready before sending inputs + let num_of_nodes = graph.node_count(); + let mut ready_nodes = 0; + + while (subscriber.recv_async().await).is_ok() { + // Message is empty, just increment the counter + ready_nodes += 1; + if ready_nodes == num_of_nodes { + break; // All nodes are ready, we can start sending inputs + } + } + + // For each node send all the packets associate with it + for (node_id, input_packets) in pipeline_run.pipeline_job.get_input_packet_per_node()? { + // Send the packet to the input node key_exp + pipeline_run + .send_packets(&format!("input_node_{node_id}"), &input_packets) + .await?; + + // Packets are sent, thus we can send the empty vec which signify processing is done + pipeline_run + .send_packets(&format!("input_node_{node_id}"), &Vec::new()) + .await?; + } + + // Insert into the list of pipeline runs + self.pipeline_runs + .insert(pipeline_run_id.clone(), pipeline_run); + + // Return the pipeline run id + Ok(pipeline_run_id) + } + + /// Given a pipeline run, wait for all its tasks to complete and return the `PipelineResult` + /// + /// # Errors + /// Will error out if any of the pipeline tasks failed to join + pub async fn get_result(&mut self, pipeline_run_id: &str) -> Result { + // To get the result, the pipeline execution must be complete, so we need to await on the tasks + + let pipeline_run = + self.pipeline_runs + .get_mut(pipeline_run_id) + .context(selector::KeyMissing { + key: pipeline_run_id.to_owned(), + })?; + + // Wait for all the tasks to complete + while let Some(result) = pipeline_run.node_tasks.lock().await.join_next().await { + result??; + } + + Ok(PipelineResult { + pipeline_job: Arc::clone(&pipeline_run.pipeline_job), + failure_logs: pipeline_run.failure_logs.read().await.clone(), + status: 
pipeline_run.get_status().await, + output_packets: pipeline_run.outputs.read().await.clone(), + }) + } + + /// Stop the pipeline run and all its tasks + /// This will send a stop message to a channel that all node manager task are subscribed to. + /// Upon receiving the stop message, each node manager will force abort all of its task and exit. + /// + /// # Errors + /// Will error out if the pipeline run is not found or if any of the tasks fail to stop correctly + pub async fn stop(&mut self, pipeline_run_id: &str) -> Result<()> { + // Get the pipeline run first then broadcast the abort request signal + let pipeline_run = + self.pipeline_runs + .get_mut(pipeline_run_id) + .context(selector::KeyMissing { + key: pipeline_run_id.to_owned(), + })?; + + // Send the abort request signal + pipeline_run.send_abort_request().await?; + + while pipeline_run + .node_tasks + .lock() + .await + .join_next() + .await + .is_some() + {} + Ok(()) + } + + /// This will capture the outputs of the given nodes and store it in the `outputs` map + async fn create_output_capture_task_for_node( + // + key_mapping: HashMap, + pipeline_run: Arc, + node_id: String, + ) -> Result<()> { + // Create a zenoh session + let subscriber = pipeline_run + .session + .declare_subscriber(pipeline_run.make_key_expr(&node_id, NODE_OUTPUT_KEY_EXPR)) + .await + .context(selector::AgentCommunicationFailure {})?; + + while let Ok(payload) = subscriber.recv_async().await { + // Extract the message from the payload + let packets: Vec = serde_json::from_slice(&payload.payload().to_bytes())?; + + if packets.is_empty() { + // Output node exited, thus we can exit the capture task too + break; + } + let mut outputs_lock = pipeline_run.outputs.write().await; + + for packet in packets { + for (output_key, node_key) in &key_mapping { + outputs_lock + .entry(output_key.to_owned()) + .or_default() + .push(get(&packet, node_key)?.clone()); + } + } + } + Ok(()) + } + + async fn failure_capture_task(pipeline_run: Arc) -> 
Result<()> { + let sub = pipeline_run + .session + .declare_subscriber(pipeline_run.make_key_expr("*", FAILURE_KEY_EXP)) + .await + .context(selector::AgentCommunicationFailure {})?; + + // Listen to any failure messages and write it the logs + while let Ok(payload) = sub.recv_async().await { + // Extract the message from the payload + let failure_msg: String = serde_json::from_slice(&payload.payload().to_bytes())?; + // Store the failure message in the logs + pipeline_run.failure_logs.write().await.push(failure_msg); + } + + Ok(()) + } + + /// Function to start tasks associated with the node + /// Steps: + /// - Create the node processor based on the kernel type + /// - Create the zenoh session + /// - Create a join set to spawn and handle incoming messages tasks + /// - Create a subscriber for each of the parent nodes (Should only be 1, unless it is a joiner node) + /// - Create an abort listener task that will listen for stop requests + /// - For each subscriber, handle the incoming message appropriately + /// + /// # Errors + /// Will error out if the kernel for the node is not found or if the + async fn spawn_node_processing_task( + node: PipelineNode, + pipeline_run: Arc, + is_input_node: bool, + ) -> Result<()> { + // Get the node parents + let parent_nodes = pipeline_run + .pipeline_job + .pipeline + .get_node_parents(&node)? 
+ .collect::>(); + + // Create the correct processor for the node based on the kernel type + let node_processor: Arc>> = + Arc::new(Mutex::new(match &node.kernel { + Kernel::Pod { pod } => Box::new(PodProcessor::new( + Arc::clone(&pipeline_run), + node.hash.clone(), + Arc::clone(pod), + )), + Kernel::MapOperator { mapper } => Box::new(OperatorProcessor::new( + Arc::clone(&pipeline_run), + node.hash.clone(), + Arc::clone(mapper), + parent_nodes.len(), + )), + Kernel::JoinOperator => Box::new(OperatorProcessor::new( + Arc::clone(&pipeline_run), + node.hash.clone(), + JoinOperator::new(parent_nodes.len()).into(), + parent_nodes.len(), + )), + })); + + // Create a join set to spawn and handle incoming messages tasks + let mut listener_tasks = JoinSet::new(); + + // Create a list of node_ids that this node should listen to + let mut nodes_to_sub_to = parent_nodes + .iter() + .map(|parent_node| parent_node.hash.clone()) + .collect::>(); + + if is_input_node { + // If the node is an input node, we need to add the input node key expression + nodes_to_sub_to.push(format!("input_node_{}", node.hash)); + } + + // For each node in nodes_to_subscribe_to, call the event handler func + for node_to_sub in &nodes_to_sub_to { + listener_tasks.spawn(Self::event_handler( + Arc::clone(&pipeline_run), + node.hash.clone(), + node_to_sub.to_owned(), + Arc::clone(&node_processor), + )); + } + + // Create the listener task for the stop request + let abort_request_handler_task = tokio::spawn(Self::abort_request_event_handler( + node_processor, + Arc::clone(&pipeline_run), + )); + + // Wait for all tasks to be spawned and reply with ready message + // This is to ensure that the pipeline run knows when all tasks are ready to receive inputs + let mut num_of_ready_event_handler: usize = 0; + // Build the subscriber + let status_subscriber = pipeline_run + .session + .declare_subscriber(pipeline_run.make_key_expr(&node.hash, "event_handler_ready")) + .await + 
.context(selector::AgentCommunicationFailure {})?; + + println!( + "Waiting for all event handlers for node {} to be ready... with hash {}", + node.label, node.hash + ); + while status_subscriber.recv_async().await.is_ok() { + num_of_ready_event_handler += 1; + if num_of_ready_event_handler == nodes_to_sub_to.len() { + // +1 for the stop request task + break; // All tasks are ready, we can start sending inputs + } + } + + println!("Node {} is ready with hash {}", node.label, node.hash); + // Send a ready message so the pipeline knows when to start sending inputs + pipeline_run + .session + .put(pipeline_run.make_key_expr(&node.hash, "node_ready"), vec![]) + .await + .context(selector::AgentCommunicationFailure {})?; + + // Wait for all task to complete + while let Some(result) = listener_tasks.join_next().await { + match result { + Ok(Ok(())) => {} // Task completed successfully + Ok(Err(err)) => { + pipeline_run.send_err_msg(&node.hash, err).await; + } + Err(err) => { + pipeline_run + .send_err_msg(&node.hash, OrcaError::from(err)) + .await; + } + } + } + + // Abort the stop listener task since we don't need it anymore + abort_request_handler_task.abort(); + + Ok(()) + } + + /// This is the actual handler for incoming messages for the node + async fn event_handler( + pipeline_run: Arc, + node_id: String, + node_to_sub_to: String, + processor: Arc>>, + ) -> Result<()> { + // Create the subscriber + let subscriber = pipeline_run + .session + .declare_subscriber(pipeline_run.make_key_expr(&node_to_sub_to, NODE_OUTPUT_KEY_EXPR)) + .await + .context(selector::AgentCommunicationFailure {})?; + + // Send out ready signal + pipeline_run + .session + .put( + pipeline_run.make_key_expr(&node_id, "event_handler_ready"), + vec![], + ) + .await + .context(selector::AgentCommunicationFailure {})?; + + // Listen to the key + loop { + let sample = subscriber + .recv_async() + .await + .context(selector::AgentCommunicationFailure)?; + + // Extract out the packets + let packets: 
Vec = serde_json::from_slice(&sample.payload().to_bytes())?; + + // Check if the packets are empty, if so that means the node is finished processing + if packets.is_empty() { + processor + .lock() + .await + .mark_parent_as_complete(&node_to_sub_to) + .await; + break; + } + + // For each packet, we need to process it + for packet in packets { + processor + .lock() + .await + .process_incoming_packet(&node_to_sub_to, &packet) + .await; + } + } + Ok::<(), OrcaError>(()) + } + + /// This task will listen for stop requests on the given key expression + async fn abort_request_event_handler( + node_processor: Arc>>, + pipeline_run: Arc, + ) -> Result<()> { + let subscriber = pipeline_run + .session + .declare_subscriber(pipeline_run.make_abort_request_key_exp()) + .await + .context(selector::AgentCommunicationFailure {})?; + while subscriber.recv_async().await.is_ok() { + // Received a request to stop, therefore we need to tell the node_processor to shutdown + node_processor.lock().await.stop(); + } + Ok::<(), OrcaError>(()) + } +} + +/// Unify the interface for node processors and provide a common way to handle processing of incoming messages +/// This trait defines the methods that all node processors should implement +/// +/// Main purpose was to reduce the amount of code duplication between different node processors +/// As a result, each processor only needs to worry about writing their own function to process the msg +#[async_trait] +trait NodeProcessor: Send + Sync { + async fn process_incoming_packet(&mut self, sender_node_hash: &str, incoming_packet: &Packet); + + /// Notifies the processor that the parent node has completed processing + /// If it is the last parent to complete, it will wait for all processing task to finish + /// Then send a completion signal + async fn mark_parent_as_complete(&mut self, parent_node_hash: &str); + + fn stop(&mut self); +} + +/// Processor for Pods +/// Currently missing implementation to call agents for actual pod processing 
+struct PodProcessor { + pipeline_run: Arc, + node_hash: String, + pod: Arc, + processing_tasks: JoinSet<()>, +} + +impl PodProcessor { + fn new(pipeline_run: Arc, node_hash: String, pod: Arc) -> Self { + Self { + pipeline_run, + node_hash, + pod, + processing_tasks: JoinSet::new(), + } + } +} + +impl PodProcessor { + /// Will handle the creation of the pod job, submission to the agent, listening for completion, and extracting the `output_packet` if successful + async fn process_packet( + pipeline_run: Arc, + node_hash: String, + pod: Arc, + incoming_packet: HashMap, + ) -> Result { + // Hash the input_packet to create a unique identifier for the pod job + let input_packet_hash = { + let mut buf = Vec::new(); + let mut serializer = Serializer::new(&mut buf); + serialize_hashmap(&incoming_packet, &mut serializer)?; + hash_buffer(buf) + }; + + // Create the pod job + let pod_job = PodJob::new( + None, + Arc::clone(&pod), + incoming_packet, + URI { + namespace: pipeline_run.namespace.clone(), + path: format!( + "pipeline_outputs/{}/{node_hash}/{input_packet_hash}", + pipeline_run.assigned_name + ) + .into(), + }, + pod.recommend_specs.cpus, + pod.recommend_specs.memory, + None, + &pipeline_run.namespace_lookup, + )?; + + // Create listener for pod_job + // Create the subscriber + let pod_job_subscriber = pipeline_run + .session + .declare_subscriber(pipeline_run.agent_client.make_key_expr( + true, + "pod_job", + BTreeMap::from([("hash", pod_job.hash.clone()), ("event", "*".to_owned())]), + )) + .await + .context(selector::AgentCommunicationFailure {})?; + + // Create the async task to listen for the pod job completion + let pod_job_listener_task = tokio::spawn(async move { + // Wait for the pod job to complete and extract the result + let sample = pod_job_subscriber + .recv_async() + .await + .context(selector::AgentCommunicationFailure {})?; + // Extract the pod_result from the payload + let pod_result: PodResult = 
serde_json::from_slice(&sample.payload().to_bytes())?; + Ok::<_, OrcaError>(pod_result) + }); + + // Submit it to the client and get the response to make sure it was successful + let responses = pipeline_run + .agent_client + .start_pod_jobs(vec![pod_job.clone().into()]) + .await; + let response = responses + .first() + .context(selector::InvalidIndex { idx: 0_usize })?; + + match response { + Response::Ok => (), + Response::Err(err) => { + return Err(OrcaError { + kind: Kind::PodJobProcessingError { + hash: pod_job.hash, + reason: err.clone(), + backtrace: Some(snafu::Backtrace::capture()), + }, + }); + } + } + + // Get the pod result from the listener task + let pod_result = pod_job_listener_task.await??; + // Get the output packet for the pod result + Ok(match pod_result.status { + PodStatus::Completed => { + // Get the output packet + pod_result.output_packet + } + PodStatus::Failed(exit_code) => { + // Processing failed, thus return the error + return Err(OrcaError { + kind: Kind::PodJobProcessingError { + hash: pod_result.pod_job.hash.clone(), + reason: format!("Pod processing failed with exit code {exit_code}"), + backtrace: Some(snafu::Backtrace::capture()), + }, + }); + } + PodStatus::Running | PodStatus::Unset | PodStatus::Undefined => { + // This should not happen, but if it does, we will return an error + return Err(OrcaError { + kind: Kind::PodJobProcessingError { + hash: pod_result.pod_job.hash.clone(), + reason: "Pod result status is running or unset".to_owned(), + backtrace: Some(snafu::Backtrace::capture()), + }, + }); + } + }) + } +} + +#[async_trait] +impl NodeProcessor for PodProcessor { + async fn process_incoming_packet( + &mut self, + _sender_node_hash: &str, + incoming_packet: &HashMap, + ) { + // Clone all necessary fields from self to move into the async block + let pipeline_run = Arc::clone(&self.pipeline_run); + let node_hash = self.node_hash.clone(); + let pod = Arc::clone(&self.pod); + + let incoming_packet_inner = 
incoming_packet.clone(); + + self.processing_tasks.spawn(async move { + let result = match Self::process_packet( + Arc::clone(&pipeline_run), + node_hash.clone(), + Arc::clone(&pod), + incoming_packet_inner.clone(), + ) + .await + { + Ok(output_packet) => { + match pipeline_run + .send_packets(&node_hash, &vec![output_packet]) + .await + { + Ok(()) => Ok(()), + Err(err) => Err(err), + } + } + Err(err) => Err(err), + }; + + match result { + Ok(()) => { + // Successfully processed the packet, nothing to do + } + Err(err) => { + pipeline_run.send_err_msg(&node_hash, err).await; + } + } + }); + } + + async fn mark_parent_as_complete(&mut self, _parent_node_id: &str) { + // For pod we only have one parent, thus execute the exit case + while self.processing_tasks.join_next().await.is_some() {} + // Send out completion signal + match self + .pipeline_run + .send_packets(&self.node_hash, &Vec::new()) + .await + { + Ok(()) => {} + Err(err) => { + self.pipeline_run.send_err_msg(&self.node_hash, err).await; + } + } + } + + fn stop(&mut self) { + self.processing_tasks.abort_all(); + } +} + +struct OperatorProcessor { + pipeline_run: Arc, + node_id: String, + operator: Arc, + num_of_parents: usize, + num_of_completed_parents: usize, + processing_tasks: JoinSet<()>, +} + +impl OperatorProcessor { + /// Create a new operator processor + pub fn new( + pipeline_run: Arc, + node_id: String, + operator: Arc, + num_of_parents: usize, + ) -> Self { + Self { + pipeline_run, + node_id, + operator, + num_of_parents, + num_of_completed_parents: 0, + processing_tasks: JoinSet::new(), + } + } +} + +#[allow( + clippy::excessive_nesting, + reason = "Nesting manageable and mute github action error" +)] +#[async_trait] +impl NodeProcessor for OperatorProcessor { + async fn process_incoming_packet( + &mut self, + sender_node_hash: &str, + incoming_packet: &HashMap, + ) { + // Clone all necessary fields from self to move into the async block + let operator = Arc::clone(&self.operator); + let 
pipeline_run = Arc::clone(&self.pipeline_run); + let node_id = self.node_id.clone(); + + let sender_node_id_inner = sender_node_hash.to_owned(); + let incoming_packet_inner = incoming_packet.clone(); + + self.processing_tasks.spawn(async move { + let processing_result = operator + .process_packet(sender_node_id_inner, incoming_packet_inner) + .await; + + match processing_result { + Ok(output_packets) => { + if !output_packets.is_empty() { + // Send out all the packets + match pipeline_run.send_packets(&node_id, &output_packets).await { + Ok(()) => {} + Err(err) => { + pipeline_run.send_err_msg(&node_id, err).await; + } + } + } + } + Err(err) => { + pipeline_run.send_err_msg(&node_id, err).await; + } + } + }); + } + + async fn mark_parent_as_complete(&mut self, _parent_node_hash: &str) { + // Figure out if this is the last parent or not + self.num_of_completed_parents += 1; + + if self.num_of_completed_parents == self.num_of_parents { + // All parents are complete, thus we need to wait on all processing tasks then exit + while (self.processing_tasks.join_next().await).is_some() { + // Wait for all tasks to complete + } + // Send out completion signal which is same as success but it is an empty vec of packets + match self + .pipeline_run + .send_packets(&self.node_id, &Vec::new()) + .await + { + Ok(()) => {} + Err(err) => { + self.pipeline_run.send_err_msg(&self.node_id, err).await; + } + } + } + } + + fn stop(&mut self) { + self.processing_tasks.abort_all(); + } +} diff --git a/src/core/store/filestore.rs b/src/core/store/filestore.rs index 622a3a82..c21d2254 100644 --- a/src/core/store/filestore.rs +++ b/src/core/store/filestore.rs @@ -1,12 +1,8 @@ use crate::{ - core::{ - model::to_yaml, - store::MODEL_NAMESPACE, - util::{get_type_name, parse_debug_name}, - }, + core::{model::ToYaml, store::MODEL_NAMESPACE, util::get_type_name}, uniffi::{ - error::{Result, selector}, - model::Annotation, + error::{OrcaError, Result, selector}, + model::{Annotation, 
pipeline::Pipeline}, store::{ModelID, ModelInfo, filestore::LocalFileStore}, }, }; @@ -16,7 +12,7 @@ use heck::ToSnakeCase as _; use regex::Regex; use serde::{Serialize, de::DeserializeOwned}; use serde_yaml; -use snafu::{OptionExt as _, ResultExt as _}; +use snafu::OptionExt as _; use std::{ fmt, fs, path::{Path, PathBuf}, @@ -55,17 +51,12 @@ impl LocalFileStore { PathBuf::from(format!("annotation/{name}-{version}.yaml")) } /// Build the storage path with the model directory (`hash`) and a file's relative path. - pub fn make_path( - &self, - model: &T, - hash: &str, - relpath: impl AsRef, - ) -> PathBuf { + pub fn make_path(&self, hash: &str, relpath: impl AsRef) -> PathBuf { PathBuf::from(format!( "{}/{}/{}/{}", self.directory.to_string_lossy(), MODEL_NAMESPACE, - parse_debug_name(model).to_snake_case(), + get_type_name::().to_snake_case(), hash )) .join(relpath) @@ -90,28 +81,21 @@ impl LocalFileStore { /// # Errors /// /// Will return error if unable to find. - pub(crate) fn lookup_hash( - &self, - model: &T, - name: &str, - version: &str, - ) -> Result { - let model_info = Self::find_model_metadata(&self.make_path( - model, - "*", - Self::make_annotation_relpath(name, version), - ))? + pub(crate) fn lookup_hash(&self, name: &str, version: &str) -> Result { + let model_info = Self::find_model_metadata( + &self.make_path::("*", Self::make_annotation_relpath(name, version)), + )? 
.next() .context(selector::MissingInfo { details: format!( "annotation where class = {}, name = {name}, version = {version}", - parse_debug_name(model).to_snake_case() + get_type_name::().to_snake_case() ), })?; Ok(model_info.hash) } - fn save_file(file: impl AsRef, content: impl AsRef<[u8]>) -> Result<()> { + pub(crate) fn save_file(file: impl AsRef, content: impl AsRef<[u8]>) -> Result<()> { if let Some(parent) = file.as_ref().parent() { fs::create_dir_all(parent)?; } @@ -123,7 +107,7 @@ impl LocalFileStore { /// # Errors /// /// Will return `Err` if there is an issue storing the model. - pub(crate) fn save_model( + pub(crate) fn save_model( &self, model: &T, hash: &str, @@ -137,7 +121,7 @@ impl LocalFileStore { &provided_annotation.version, ); if let Some((found_hash, found_name, found_version)) = - Self::find_model_metadata(&self.make_path(model, "*", relpath))? + Self::find_model_metadata(&self.make_path::("*", relpath))? .next() .and_then(|model_info| { Some((model_info.hash, model_info.name?, model_info.version?)) @@ -156,13 +140,13 @@ impl LocalFileStore { ); } else { Self::save_file( - self.make_path(model, hash, relpath), + self.make_path::(hash, relpath), serde_yaml::to_string(provided_annotation)?, )?; } } // Save model specification and skip if it already exist e.g. 
on new annotations - let spec_file = &self.make_path(model, hash, Self::SPEC_RELPATH); + let spec_file = &self.make_path::(hash, Self::SPEC_RELPATH); if spec_file.exists() { println!( "{}", @@ -174,7 +158,7 @@ impl LocalFileStore { .yellow(), ); } else { - Self::save_file(spec_file, to_yaml(model)?)?; + Self::save_file(spec_file, model.to_yaml()?)?; } Ok(()) } @@ -188,33 +172,30 @@ impl LocalFileStore { &self, model_id: &ModelID, ) -> Result<(T, Option, String)> { + let (hash, annotation) = self.decode_model_id::(model_id)?; + + Ok(( + serde_yaml::from_str(&fs::read_to_string( + self.make_path::(&hash, Self::SPEC_RELPATH), + )?)?, + annotation, + hash, + )) + } + + pub(crate) fn decode_model_id( + &self, + model_id: &ModelID, + ) -> Result<(String, Option)> { match model_id { - ModelID::Hash(hash) => { - let path = self.make_path(&T::default(), hash, Self::SPEC_RELPATH); - Ok(( - serde_yaml::from_str( - &fs::read_to_string(path.clone()) - .context(selector::InvalidPath { path })?, - )?, - None, - hash.to_owned(), - )) - } + ModelID::Hash(hash) => Ok((hash.to_owned(), None)), ModelID::Annotation(name, version) => { - let hash = self.lookup_hash(&T::default(), name, version)?; - Ok(( - serde_yaml::from_str(&fs::read_to_string(self.make_path( - &T::default(), - &hash, - Self::SPEC_RELPATH, - ))?)?, - serde_yaml::from_str(&fs::read_to_string(self.make_path( - &T::default(), - &hash, - Self::make_annotation_relpath(name, version), - ))?)?, - hash, - )) + let hash = self.lookup_hash::(name, version)?; + let annotation_str = fs::read_to_string( + self.make_path::(&hash, Self::make_annotation_relpath(name, version)), + )?; + let annotation: Annotation = serde_yaml::from_str(&annotation_str)?; + Ok((hash, Some(annotation))) } } } @@ -224,7 +205,7 @@ impl LocalFileStore { /// /// Will return `Err` if there is an issue querying metadata from existing models in the store. 
pub(crate) fn list_model(&self) -> Result> { - Ok(Self::find_model_metadata(&self.make_path(&T::default(), "**", "*"))?.collect()) + Ok(Self::find_model_metadata(&self.make_path::("**", "*"))?.collect()) } /// How to explicitly delete any stored model and all associated annotations (does not propagate). /// @@ -236,13 +217,32 @@ impl LocalFileStore { // assumes propagate = false let hash = match model_id { ModelID::Hash(hash) => hash, - ModelID::Annotation(name, version) => { - &self.lookup_hash(&T::default(), name, version)? - } + ModelID::Annotation(name, version) => &self.lookup_hash::(name, version)?, }; - let spec_dir = self.make_path(&T::default(), hash, ""); + let spec_dir = self.make_path::(hash, ""); fs::remove_dir_all(spec_dir)?; Ok(()) } + + pub(crate) fn get_latest_pipeline_labels_file_name( + &self, + pipeline_hash: &str, + ) -> Result> { + let existing_labels_path = self.make_path::(pipeline_hash, "labels/"); + Ok(if existing_labels_path.exists() { + let mut label_file_names = fs::read_dir(&existing_labels_path)? + .map(|entry| Ok::<_, OrcaError>(entry?.file_name())) + .collect::, _>>()?; + + // Sort and get the latest one + label_file_names.sort(); + + label_file_names + .last() + .map(|os_str| os_str.to_string_lossy().to_string()) + } else { + None + }) + } } diff --git a/src/core/util.rs b/src/core/util.rs index f41e4356..8b033c80 100644 --- a/src/core/util.rs +++ b/src/core/util.rs @@ -1,29 +1,19 @@ use crate::uniffi::error::{Result, selector}; +use heck::ToSnakeCase as _; use snafu::OptionExt as _; -use std::{any::type_name, borrow::Borrow, collections::HashMap, fmt, hash}; - -#[expect( - clippy::unwrap_used, - reason = "Cannot return `None` since `type_name` always returns `&str`." 
-)] -pub fn get_type_name() -> String { - type_name::() - .split("::") - .map(str::to_owned) - .last() - .unwrap() -} +use std::{ + any::type_name, + borrow::Borrow, + collections::{BTreeMap, HashMap}, + fmt, hash, +}; #[expect( clippy::unwrap_used, reason = "Cannot return `None` since debug format always returns `String`." )] -pub fn parse_debug_name(instance: &T) -> String { - format!("{instance:?}") - .split(' ') - .map(str::to_owned) - .next() - .unwrap() +pub fn get_type_name() -> String { + type_name::().split("::").last().unwrap().to_snake_case() } pub fn get<'map, K, V, Q>(map: &'map HashMap, key: &Q) -> Result<&'map V> where @@ -35,3 +25,19 @@ where details: format!("key = {key:?}"), })?) } + +pub fn make_key_expr( + group: &str, + host: &str, + topic: &str, + content: &BTreeMap, +) -> String { + // For each key-value pair in the content, we format it as "key/value" and join them with "/". + // The final format will be "group/host/topic/key1/value1/key2/value2". + let content_converted = content + .iter() + .map(|(k, v)| format!("{k}/{v}")) + .collect::>() + .join("/"); + format!("{group}/{host}/{topic}/{content_converted}") +} diff --git a/src/uniffi/error.rs b/src/uniffi/error.rs index 00ed5c2e..5cc2e5db 100644 --- a/src/uniffi/error.rs +++ b/src/uniffi/error.rs @@ -11,6 +11,7 @@ use serde_yaml; use snafu::prelude::Snafu; use std::{ backtrace::Backtrace, + collections::HashSet, error::Error, io, path::{self, PathBuf}, @@ -18,6 +19,7 @@ use std::{ }; use tokio::task; use uniffi; + /// Shorthand for a Result that returns an [`OrcaError`]. pub type Result = result::Result; /// Possible errors you may encounter. @@ -30,6 +32,35 @@ pub(crate) enum Kind { source: Box, backtrace: Option, }, + #[snafu(display("Empty directory: {dir:?}, where they should be files"))] + EmptyDir { + dir: PathBuf, + backtrace: Option, + }, + #[snafu(display( + "Failed to extract run info from the container image file: {container_name}."
+ ))] + FailedToExtractRunInfo { + container_name: String, + backtrace: Option, + }, + #[snafu(display( + "Missing expected output file or dir with key {packet_key} at path {path:?} for pod job (hash: {pod_job_hash})." + ))] + FailedToGetPodJobOutput { + pod_job_hash: String, + packet_key: String, + path: Box, + io_error: Box, + backtrace: Option, + }, + #[snafu(display( + "Failed to get label hash from file name: {file_name}. Split result by \"-\" didn't return hash" + ))] + FailedToGetLabelHashFromFileName { + file_name: String, + backtrace: Option, + }, #[snafu(display("Incomplete {kind} packet. Missing `{missing_keys:?}` keys."))] IncompletePacket { kind: String, @@ -50,11 +81,67 @@ pub(crate) enum Kind { source: io::Error, backtrace: Option, }, + #[snafu(display("Failed to get items at idx {idx}."))] + InvalidIndex { + idx: usize, + backtrace: Option, + }, + + #[snafu(display("Key '{key}' was not found in map."))] + KeyMissing { + key: String, + backtrace: Option, + }, #[snafu(display("Missing info. Details: {details}."))] MissingInfo { details: String, backtrace: Option, }, + #[snafu(display( + "Node '{node_name}' was referenced in input_spec, but is not a node in the graph." + ))] + InvalidInputSpecNodeNotInGraph { + node_name: String, + backtrace: Option, + }, + #[snafu(display( + "Key '{key}' was referenced in input_spec for node '{node_name}', but is not a key in that node's input spec." + ))] + InvalidOutputSpecKeyNotInNode { + node_name: String, + key: String, + backtrace: Option, + }, + #[snafu(display( + "Node '{node_name}' was referenced in output_spec, but is not a node in the graph." 
+ ))] + InvalidOutputSpecNodeNotInGraph { + node_name: String, + backtrace: Option, + }, + #[snafu(display("Node '{node_name}' is missing required keys: {missing_keys:?}."))] + PipelineValidationErrorMissingKeys { + node_name: String, + missing_keys: HashSet, + backtrace: Option, + }, + + #[snafu(display("Pod job submission failed with reason: {reason}."))] + PodJobSubmissionFailed { + reason: String, + backtrace: Option, + }, + #[snafu(display("Pod job {hash} failed to process with reason: {reason}."))] + PodJobProcessingError { + hash: String, + reason: String, + backtrace: Option, + }, + #[snafu(display("Unexpected path type: {path:?}. Only support files and directories."))] + UnexpectedPathType { + path: PathBuf, + backtrace: Option, + }, #[snafu(transparent)] BollardError { source: Box, diff --git a/src/uniffi/mod.rs b/src/uniffi/mod.rs index e02fd6c9..2443ce00 100644 --- a/src/uniffi/mod.rs +++ b/src/uniffi/mod.rs @@ -2,6 +2,8 @@ pub mod error; /// Components of the data model. pub mod model; +/// Operators for pipeline +pub mod operator; /// Interface into container orchestration engine. pub mod orchestrator; /// Data persistence provided by a store backend. diff --git a/src/uniffi/model/packet.rs b/src/uniffi/model/packet.rs index 595a3f92..c1f9c013 100644 --- a/src/uniffi/model/packet.rs +++ b/src/uniffi/model/packet.rs @@ -2,6 +2,9 @@ use serde::{Deserialize, Serialize}; use std::{collections::HashMap, path::PathBuf}; use uniffi; +use crate::core::util::get; +use crate::uniffi::error::Result; + /// Path sets are named and represent an abstraction for the file(s) that represent some particular /// data within a compute environment. #[derive(uniffi::Record, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] @@ -12,6 +15,9 @@ pub struct PathInfo { pub match_pattern: String, } +#[uniffi::export] +impl PathInfo {} + /// File or directory options for BLOBs. 
#[derive(uniffi::Enum, Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)] pub enum BlobKind { @@ -39,9 +45,21 @@ pub struct Blob { /// BLOB location. pub location: URI, /// BLOB contents checksum. + #[uniffi(default = "")] pub checksum: String, } +impl Blob { + /// Create a new `Blob` + pub const fn new(kind: BlobKind, location: URI) -> Self { + Self { + kind, + location, + checksum: String::new(), + } + } +} + /// A single BLOB or a collection of BLOBs. #[derive(uniffi::Enum, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] #[serde(untagged)] @@ -52,5 +70,28 @@ pub enum PathSet { Collection(Vec), } +impl PathSet { + /// Util function to convert ``PathSet`` to ``PathBuf`` given a namespace lookup table + /// + /// # Errors + /// Will error out if namespace is missing in namespace lookup + pub fn to_path_buf(&self, namespace_lookup: &HashMap) -> Result> { + match self { + Self::Unary(blob) => { + let base_path = get(namespace_lookup, &blob.location.namespace)?; + Ok(vec![base_path.join(&blob.location.path)]) + } + Self::Collection(blobs) => { + let mut paths = Vec::with_capacity(blobs.len()); + for blob in blobs { + let base_path = get(namespace_lookup, &blob.location.namespace)?; + paths.push(base_path.join(&blob.location.path)); + } + Ok(paths) + } + } + } +} + /// A complete set of inputs to be provided to a computational unit. 
pub type Packet = HashMap; diff --git a/src/uniffi/model/pipeline.rs b/src/uniffi/model/pipeline.rs index 223f73cf..07788648 100644 --- a/src/uniffi/model/pipeline.rs +++ b/src/uniffi/model/pipeline.rs @@ -1,38 +1,57 @@ use crate::{ core::{ - crypto::{hash_blob, make_random_hash}, + crypto::{hash_blob, hash_buffer, make_random_hash}, graph::make_graph, - pipeline::PipelineNode, + model::{ToYaml as _, pipeline::PipelineNode}, validation::validate_packet, }, uniffi::{ - error::Result, + error::{OrcaError, Result, selector}, model::{ + Annotation, packet::{PathSet, URI}, pod::Pod, }, + operator::MapOperator, }, }; use derive_more::Display; use getset::CloneGetters; use petgraph::graph::DiGraph; use serde::{Deserialize, Serialize}; -use std::{collections::HashMap, path::PathBuf, sync::Arc}; +use snafu::OptionExt as _; +use std::{ + collections::{HashMap, HashSet}, + hash::Hash, + path::PathBuf, + sync::Arc, +}; +use std::{hash::Hasher, sync::LazyLock}; use uniffi; +pub(crate) static JOIN_OPERATOR_HASH: LazyLock = + LazyLock::new(|| hash_buffer(b"join_operator")); + /// Computational dependencies as a [DAG](https://en.wikipedia.org/wiki/Directed_acyclic_graph). -#[derive(uniffi::Object, Debug, Display, CloneGetters, Clone, Deserialize, Serialize)] +#[derive(uniffi::Object, Debug, Display, CloneGetters, Clone, Deserialize, Default)] #[getset(get_clone, impl_attrs = "#[uniffi::export]")] #[display("{self:#?}")] #[uniffi::export(Display)] pub struct Pipeline { + /// Hash for pipeline + #[serde(default)] + pub hash: String, + /// Annotations for the pipeline. + #[serde(default)] + pub annotation: Option, /// Computational DAG in-memory. #[getset(skip)] + #[serde(skip_deserializing)] pub graph: DiGraph, /// Exposed, internal input specification. Each input may be fed into more than one node/key if desired. - pub input_spec: HashMap>, + pub input_spec: HashMap>, /// Exposed, internal output specification. Each output is associated with only one node/key. 
- pub output_spec: HashMap, + pub output_spec: HashMap, } #[uniffi::export] @@ -45,19 +64,84 @@ impl Pipeline { #[uniffi::constructor] pub fn new( graph_dot: &str, - metadata: HashMap, - input_spec: &HashMap>, - output_spec: &HashMap, + metadata: &HashMap, + mut input_spec: HashMap>, + mut output_spec: HashMap, + annotation: Option, ) -> Result { - let graph = make_graph(graph_dot, metadata)?; - Ok(Self { + // Note this gives us the graph, but the nodes do not have their hashes computed yet. + let mut graph = make_graph(graph_dot, metadata)?; + + // Run preprocessing to compute the hash for each node + for node_idx in graph.node_indices() { + Self::compute_hash_for_node_and_parents(node_idx, &input_spec, &mut graph); + } + + // Build LUT for node_label -> node_hash + let label_to_hash_lut = + graph + .node_indices() + .fold(HashMap::<&String, &String>::new(), |mut acc, node_idx| { + let node = &graph[node_idx]; + acc.insert(&node.label, &node.hash); + acc + }); + + // Build the new input_spec to refer to the hash instead of label + input_spec.iter_mut().try_for_each(|(_, node_uris)| { + node_uris.iter_mut().try_for_each(|node_uri| { + node_uri.node_id = (*label_to_hash_lut.get(&node_uri.node_id).context( + selector::InvalidInputSpecNodeNotInGraph { + node_name: node_uri.node_id.clone(), + }, + )?) + .clone(); + Ok::<(), OrcaError>(()) + }) + })?; + + // Update the output_spec to refer to the hash instead of label + output_spec.iter_mut().try_for_each(|(_, node_uri)| { + node_uri.node_id = (*label_to_hash_lut.get(&node_uri.node_id).context( + selector::InvalidOutputSpecNodeNotInGraph { + node_name: node_uri.node_id.clone(), + }, + )?) 
+ .clone(); + + Ok::<(), OrcaError>(()) + })?; + + let pipeline_no_hash = Self { + hash: String::new(), graph, - input_spec: input_spec.clone(), - output_spec: output_spec.clone(), + input_spec, + output_spec, + annotation, + }; + + // Run verification on the pipeline first before computing hash + pipeline_no_hash.validate()?; + + Ok(Self { + hash: hash_buffer(pipeline_no_hash.to_yaml()?.as_bytes()), + ..pipeline_no_hash }) } } +impl PartialEq for Pipeline { + fn eq(&self, other: &Self) -> bool { + self.hash == other.hash + && self.annotation == other.annotation + && self.input_spec.keys().collect::>() + == other.input_spec.keys().collect::>() + && self.input_spec.values().collect::>() + == other.input_spec.values().collect::>() + && self.output_spec == other.output_spec + } +} + /// A compute pipeline job that supplies input/output targets. #[expect( clippy::field_scoped_visibility_modifiers, @@ -79,7 +163,6 @@ pub struct PipelineJob { pub output_dir: URI, } -#[expect(clippy::excessive_nesting, reason = "Nesting manageable.")] #[uniffi::export] impl PipelineJob { /// Construct a new pipeline job instance. @@ -91,7 +174,7 @@ impl PipelineJob { pub fn new( pipeline: Arc, input_packet: &HashMap>, - output_dir: &URI, + output_dir: URI, namespace_lookup: &HashMap, ) -> Result { validate_packet("input".into(), &pipeline.input_spec, input_packet)?; @@ -124,33 +207,109 @@ impl PipelineJob { hash: make_random_hash(), pipeline, input_packet: input_packet_with_checksum, - output_dir: output_dir.clone(), + output_dir, }) } } +/// Struct to hold the result of a pipeline execution. +#[derive(uniffi::Object, Debug, Clone, Deserialize, Serialize, Display, CloneGetters)] +#[getset(get_clone, impl_attrs = "#[uniffi::export]")] +#[display("{self:#?}")] +#[uniffi::export(Display)] +pub struct PipelineResult { + /// The pipeline job that was executed. + pub pipeline_job: Arc, + /// The result of the pipeline execution. 
+ pub output_packets: HashMap>, + /// Logs of any failures that occurred during the pipeline execution. + pub failure_logs: Vec, + /// The status of the pipeline execution. + pub status: PipelineStatus, +} + +/// The status of a pipeline execution. +#[derive(uniffi::Enum, Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +pub enum PipelineStatus { + /// The pipeline is currently running. + Running, + /// The pipeline has completed successfully. + Succeeded, + /// The pipeline has failed. + Failed, + /// The pipeline has partially succeeded. There should be some failure logs + PartiallySucceeded, +} /// A node in a computational pipeline. #[derive(uniffi::Enum, Debug, Clone, Deserialize, Serialize)] pub enum Kernel { /// Pod reference. Pod { /// See [`Pod`](crate::uniffi::model::pod::Pod). - r#ref: Arc, + pod: Arc, }, /// Cartesian product operation. See [`JoinOperator`](crate::core::operator::JoinOperator). JoinOperator, /// Rename a path set key operation. MapOperator { /// See [`MapOperator`](crate::core::operator::MapOperator). - map: HashMap, + mapper: Arc, }, } +impl From for Kernel { + fn from(mapper: MapOperator) -> Self { + Self::MapOperator { + mapper: Arc::new(mapper), + } + } +} + +impl From for Kernel { + fn from(pod: Pod) -> Self { + Self::Pod { pod: Arc::new(pod) } + } +} + +impl From> for Kernel { + fn from(pod: Arc) -> Self { + Self::Pod { pod } + } +} + +impl Kernel { + /// Get a unique hash that represents the kernel. + /// The exception here is the `JoinOperator` doesn't have any pre-execution configuration, since its logic is completely dependent on what is fed to it during execution.
+ pub fn get_hash(&self) -> &str { + match self { + Self::Pod { pod } => &pod.hash, + Self::JoinOperator => &JOIN_OPERATOR_HASH, + Self::MapOperator { mapper } => &mapper.hash, + } + } +} + +impl PartialEq for Kernel { + fn eq(&self, other: &Self) -> bool { + self.get_hash() == other.get_hash() + } +} + +impl Eq for Kernel {} + +impl Hash for Kernel { + fn hash(&self, state: &mut H) { + self.get_hash().hash(state); + } +} + /// Index from pipeline node into pod specification. -#[derive(uniffi::Record, Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] -pub struct SpecURI { +#[derive( + uniffi::Record, Debug, Clone, Deserialize, Serialize, PartialEq, Eq, Hash, PartialOrd, Ord, +)] +pub struct NodeURI { /// Node reference name in pipeline. - pub node: String, + pub node_id: String, /// Specification key. pub key: String, } diff --git a/src/uniffi/model/pod.rs b/src/uniffi/model/pod.rs index 62cd006f..13ae9f1e 100644 --- a/src/uniffi/model/pod.rs +++ b/src/uniffi/model/pod.rs @@ -2,14 +2,15 @@ use crate::{ core::{ crypto::{hash_blob, hash_buffer}, model::{ + ToYaml, pod::{deserialize_pod, deserialize_pod_job}, - serialize_hashmap, serialize_hashmap_option, to_yaml, + serialize_hashmap, serialize_hashmap_option, }, util::get, validation::validate_packet, }, uniffi::{ - error::{OrcaError, Result}, + error::{Kind, OrcaError, Result}, model::{ Annotation, packet::{Blob, BlobKind, Packet, PathInfo, PathSet, URI}, @@ -20,7 +21,7 @@ use crate::{ use derive_more::Display; use getset::CloneGetters; use serde::{Deserialize, Serialize}; -use std::{collections::HashMap, path::PathBuf, sync::Arc}; +use std::{backtrace::Backtrace, collections::HashMap, path::PathBuf, sync::Arc}; use uniffi; /// A reusable, containerized computational unit. @@ -49,14 +50,11 @@ pub struct Pod { /// Exposed, internal output specification. #[serde(serialize_with = "serialize_hashmap")] pub output_spec: HashMap, - /// Link to source associated with image binary. 
- pub source_commit_url: String, - /// Recommendation for CPU in fractional cores. - pub recommended_cpus: f32, - /// Recommendation for memory in bytes. - pub recommended_memory: u64, - /// If applicable, recommendation for GPU configuration. - pub required_gpu: Option, + /// Execution requirements for the pod. + #[serde(default)] + pub recommend_specs: RecommendSpecs, + /// Optional GPU requirements for the pod. If set, then the running system needs a GPU that meets the requirements. + pub gpu_requirements: Option, } #[uniffi::export] @@ -74,10 +72,8 @@ impl Pod { input_spec: HashMap, output_dir: PathBuf, output_spec: HashMap, - source_commit_url: String, - recommended_cpus: f32, - recommended_memory: u64, - required_gpu: Option, + recommend_specs: RecommendSpecs, + gpu_requirements: Option, ) -> Result { let pod_no_hash = Self { annotation, @@ -87,19 +83,72 @@ impl Pod { input_spec, output_dir, output_spec, - source_commit_url, - recommended_cpus, - recommended_memory, - required_gpu, + recommend_specs, + gpu_requirements, }; Ok(Self { - hash: hash_buffer(to_yaml(&pod_no_hash)?), + hash: hash_buffer(pod_no_hash.to_yaml()?), ..pod_no_hash }) } } +impl ToYaml for Pod { + fn process_field( + field_name: &str, + field_value: &serde_yaml::Value, + ) -> Option<(String, serde_yaml::Value)> { + match field_name { + "annotation" | "hash" | "recommend_specs" => None, + _ => Some((field_name.to_owned(), field_value.clone())), + } + } +} + +/// Execution recommendations for a pod, since it doesn't impact the actual reproducibility +/// it shouldn't be hashed along with the pod +#[derive(uniffi::Record, Serialize, Deserialize, Debug, PartialEq, Default, Clone)] +pub struct RecommendSpecs { + /// Optimal number of CPU cores needed to run the pod provided by the user + pub cpus: f32, + /// Optimal amount of memory needed to run the pod provided by the user, code can probably run with less but may hit OOM + pub memory: u64, +} + +impl ToYaml for RecommendSpecs { + fn 
process_field( + field_name: &str, + field_value: &serde_yaml::Value, + ) -> Option<(String, serde_yaml::Value)> { + Some((field_name.to_owned(), field_value.clone())) + } +} + +/// Specification for GPU requirements in computation. +#[derive(uniffi::Record, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub struct GPURequirement { + /// GPU model specification. + pub model: GPUModel, + /// Manufacturer recommended memory. + pub recommended_memory: u64, + /// Number of GPU cards required. + pub count: u16, +} + +/// GPU model specification. +#[derive(uniffi::Enum, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub enum GPUModel { + /// NVIDIA-manufactured card where `String` is the specific minimum CUDA version in X.XX + NVIDIA(String), + /// Any GPU architecture, code is generic enough + Any, +} + /// A compute job that specifies resource requests and input/output targets. +/// +/// `PodJob` represents a specific execution instance of a [`Pod`] with concrete +/// input data, resource limits, and output specifications. It includes all the +/// information needed to run a containerized computation job. #[derive( uniffi::Object, Serialize, Deserialize, Debug, PartialEq, Clone, Default, Display, CloneGetters, )] @@ -177,12 +226,25 @@ impl PodJob { env_vars, }; Ok(Self { - hash: hash_buffer(to_yaml(&pod_job_no_hash)?), + hash: hash_buffer(pod_job_no_hash.to_yaml()?), ..pod_job_no_hash }) } } +impl ToYaml for PodJob { + fn process_field( + field_name: &str, + field_value: &serde_yaml::Value, + ) -> Option<(String, serde_yaml::Value)> { + match field_name { + "annotation" | "hash" => None, + "pod" => Some((field_name.to_owned(), field_value["hash"].clone())), + _ => Some((field_name.to_owned(), field_value.clone())), + } + } +} + /// Result from a compute job run. 
#[derive(uniffi::Record, Serialize, Deserialize, Debug, Clone, PartialEq, Default)] pub struct PodResult { @@ -205,6 +267,8 @@ pub struct PodResult { pub created: u64, /// Time in epoch when terminated in seconds. pub terminated: u64, + /// Logs about stdout and stderr, where stderr is appended at the end + pub logs: String, } impl PodResult { @@ -221,6 +285,7 @@ impl PodResult { created: u64, terminated: u64, namespace_lookup: &HashMap, + logs: String, ) -> Result { let output_packet = pod_job .pod @@ -239,7 +304,13 @@ match local_location.try_exists() { Ok(false) => None, - Err(error) => Some(Err(OrcaError::from(error))), + Err(error) => Some(Err(OrcaError { + kind: Kind::InvalidPath { + path: local_location.clone(), + source: error, + backtrace: Some(Backtrace::capture()), + }, + })), Ok(true) => Some(Ok(( packet_key, Blob { @@ -263,6 +334,7 @@ }) .collect::>()?; + // If packet is completed, the output packet must meet the output spec if matches!(status, PodStatus::Completed) { validate_packet("output".into(), &pod_job.pod.output_spec, &output_packet)?; } @@ -276,30 +348,24 @@ status, created, terminated, + logs, }; Ok(Self { - hash: hash_buffer(to_yaml(&pod_result_no_hash)?), + hash: hash_buffer(pod_result_no_hash.to_yaml()?), ..pod_result_no_hash }) } } -/// Specification for GPU requirements in computation. -#[derive(uniffi::Record, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] -pub struct GPURequirement { - /// GPU model specification. - pub model: GPUModel, - /// Manufacturer recommended memory. - pub recommended_memory: u64, - /// Number of GPU cards required. - pub count: u16, -} - -/// GPU model specification. -#[derive(uniffi::Enum, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] -pub enum GPUModel { - /// NVIDIA-manufactured card where `String` is the specific model e.g. ??? - NVIDIA(String), - /// AMD-manufactured card where `String` is the specific model e.g. ???
- AMD(String), +impl ToYaml for PodResult { + fn process_field( + field_name: &str, + field_value: &serde_yaml::Value, + ) -> Option<(String, serde_yaml::Value)> { + match field_name { + "annotation" | "hash" => None, + "pod_job" => Some((field_name.to_owned(), field_value["hash"].clone())), + _ => Some((field_name.to_owned(), field_value.clone())), + } + } } diff --git a/src/uniffi/operator.rs b/src/uniffi/operator.rs new file mode 100644 index 00000000..28fbd250 --- /dev/null +++ b/src/uniffi/operator.rs @@ -0,0 +1,39 @@ +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; + +use crate::core::model::ToYaml as _; +use crate::core::{crypto::hash_buffer, model::serialize_hashmap}; +use crate::uniffi::error::Result; + +/// Operator class that maps `input_keys` to `output_key`, effectively renaming it +/// For use in pipelines +#[derive(uniffi::Object, Debug, Clone, Deserialize, Serialize, PartialEq, Eq, Default)] +pub struct MapOperator { + /// Unique hash of the map operator + #[serde(skip)] + pub hash: String, + /// Mapping of input keys to output keys + #[serde(serialize_with = "serialize_hashmap")] + pub map: HashMap, +} + +#[uniffi::export] +impl MapOperator { + #[uniffi::constructor] + /// Create a new `MapOperator` + /// + /// # Errors + /// Will error if there are issues converting the map to yaml for hashing + pub fn new(map: HashMap) -> Result { + let no_hash = Self { + map, + hash: String::new(), + }; + + Ok(Self { + hash: hash_buffer(no_hash.to_yaml()?), + ..no_hash + }) + } +} diff --git a/src/uniffi/orchestrator/agent.rs b/src/uniffi/orchestrator/agent.rs index 6cab47b5..1d1df2e7 100644 --- a/src/uniffi/orchestrator/agent.rs +++ b/src/uniffi/orchestrator/agent.rs @@ -51,7 +51,7 @@ pub struct AgentClient { /// Connecting agent's assigned name used for reference.
pub host: String, #[getset(skip)] - pub(crate) session: zenoh::Session, + pub(crate) session: Arc, } #[uniffi::export] @@ -72,7 +72,8 @@ impl AgentClient { .await .context(selector::AgentCommunicationFailure {})?, ) - })?, + })? + .into(), }) } /// Start many pod jobs to be processed in parallel. @@ -155,15 +156,15 @@ impl Agent { /// # Errors /// /// Will stop and return an error if encounters an error while processing any pod job request. - #[expect(clippy::excessive_nesting, reason = "Nesting manageable.")] pub async fn start( &self, namespace_lookup: &HashMap, available_store: Option>, ) -> Result<()> { let mut services = JoinSet::new(); + let self_ref = Arc::new(self.clone()); services.spawn(start_service( - Arc::new(self.clone()), + Arc::clone(&self_ref), "pod_job", BTreeMap::from([("action", "request".to_owned())]), namespace_lookup.clone(), @@ -185,7 +186,7 @@ impl Agent { "pod_job", BTreeMap::from([ ( - "action", + "event", match &pod_result.status { PodStatus::Completed => "success", PodStatus::Running @@ -204,23 +205,9 @@ impl Agent { )); if let Some(store) = available_store { services.spawn(start_service( - Arc::new(self.clone()), + Arc::clone(&self_ref), "pod_job", - BTreeMap::from([("action", "success".to_owned())]), - namespace_lookup.clone(), - { - let inner_store = Arc::clone(&store); - async move |_, _, _, pod_result| { - inner_store.save_pod_result(&pod_result)?; - Ok(()) - } - }, - async |_, ()| Ok(()), - )); - services.spawn(start_service( - Arc::new(self.clone()), - "pod_job", - BTreeMap::from([("action", "failure".to_owned())]), + BTreeMap::from([("event", "*".to_owned())]), namespace_lookup.clone(), async move |_, _, _, pod_result| { store.save_pod_result(&pod_result)?; diff --git a/src/uniffi/orchestrator/docker.rs b/src/uniffi/orchestrator/docker.rs index a8eb5d22..58114577 100644 --- a/src/uniffi/orchestrator/docker.rs +++ b/src/uniffi/orchestrator/docker.rs @@ -12,7 +12,9 @@ use crate::{ use async_trait; use bollard::{ Docker, - 
container::{RemoveContainerOptions, StartContainerOptions, WaitContainerOptions}, + container::{ + LogOutput, LogsOptions, RemoveContainerOptions, StartContainerOptions, WaitContainerOptions, + }, errors::Error::DockerContainerWaitError, image::{CreateImageOptions, ImportImageOptions}, }; @@ -70,6 +72,9 @@ impl Orchestrator for LocalDockerOrchestrator { ) -> Result { ASYNC_RUNTIME.block_on(self.get_result(pod_run, namespace_lookup)) } + fn get_logs_blocking(&self, pod_run: &PodRun) -> Result { + ASYNC_RUNTIME.block_on(self.get_logs(pod_run)) + } #[expect( clippy::try_err, reason = r#" @@ -263,8 +268,65 @@ impl Orchestrator for LocalDockerOrchestrator { ), })?, namespace_lookup, + self.get_logs(pod_run).await?, ) } + + async fn get_logs(&self, pod_run: &PodRun) -> Result { + let mut std_out = Vec::new(); + let mut std_err = Vec::new(); + + self.api + .logs::( + &pod_run.assigned_name, + Some(LogsOptions { + stdout: true, + stderr: true, + ..Default::default() + }), + ) + .try_collect::>() + .await? + .iter() + .for_each(|log_output| match log_output { + LogOutput::StdOut { message } => { + std_out.extend(message.to_vec()); + } + LogOutput::StdErr { message } => { + std_err.extend(message.to_vec()); + } + LogOutput::StdIn { .. } | LogOutput::Console { .. } => { + // Ignore stdin logs, as they are not relevant for our use case + } + }); + + let mut logs = String::from_utf8_lossy(&std_out).to_string(); + if !std_err.is_empty() { + logs.push_str("\nSTDERR:\n"); + logs.push_str(&String::from_utf8_lossy(&std_err)); + } + + // Check for errors in the docker state, if exist, attach it to logs + // This is for when the container exits immediately due to a bad command or similar + let error = self + .api + .inspect_container(&pod_run.assigned_name, None) + .await? + .state + .context(selector::FailedToExtractRunInfo { + container_name: &pod_run.assigned_name, + })? 
+ .error + .context(selector::FailedToExtractRunInfo { + container_name: &pod_run.assigned_name, + })?; + + if !error.is_empty() { + logs.push_str(&error); + } + + Ok(logs) + } } #[uniffi::export] diff --git a/src/uniffi/orchestrator/mod.rs b/src/uniffi/orchestrator/mod.rs index 00358ec3..04d2412f 100644 --- a/src/uniffi/orchestrator/mod.rs +++ b/src/uniffi/orchestrator/mod.rs @@ -58,7 +58,7 @@ pub struct PodRunInfo { pub memory_limit: u64, } /// Current computation managed by orchestrator. -#[derive(uniffi::Record, Debug, PartialEq)] +#[derive(uniffi::Record, Debug, PartialEq, Clone)] pub struct PodRun { /// Original compute request. pub pod_job: Arc, @@ -121,6 +121,10 @@ pub trait Orchestrator: Send + Sync + fmt::Debug { pod_run: &PodRun, namespace_lookup: &HashMap, ) -> Result; + /// Get the logs for a specific pod run. + /// # Errors + /// Will return `Err` if there is an issue getting logs. + fn get_logs_blocking(&self, pod_run: &PodRun) -> Result; /// How to asynchronously start containers with an alternate image. /// /// # Errors @@ -170,6 +174,9 @@ pub trait Orchestrator: Send + Sync + fmt::Debug { pod_run: &PodRun, namespace_lookup: &HashMap, ) -> Result; + + /// Get the logs for a specific pod run. + async fn get_logs(&self, pod_run: &PodRun) -> Result; } /// Orchestration execution agent daemon and client. 
pub mod agent; diff --git a/src/uniffi/store/filestore.rs b/src/uniffi/store/filestore.rs index db089dfa..d5b58b6e 100644 --- a/src/uniffi/store/filestore.rs +++ b/src/uniffi/store/filestore.rs @@ -1,14 +1,27 @@ -use crate::uniffi::{ - error::Result, - model::{ - ModelType, - pod::{Pod, PodJob, PodResult}, +use crate::{ + core::{crypto::hash_buffer, model::ToYaml as _}, + uniffi::{ + error::{Kind, OrcaError, Result, selector}, + model::{ + ModelType, + pipeline::{JOIN_OPERATOR_HASH, Kernel, NodeURI, Pipeline}, + pod::{Pod, PodJob, PodResult}, + }, + operator::MapOperator, + store::{ModelID, ModelInfo, Store}, }, - store::{ModelID, ModelInfo, Store}, }; +use chrono::Utc; use derive_more::Display; use getset::CloneGetters; -use std::{fs, path::PathBuf}; +use serde::Deserialize; +use snafu::OptionExt as _; +use std::{ + backtrace::Backtrace, + collections::{HashMap, HashSet}, + fs, + path::PathBuf, +}; use uniffi; /// Support for a storage backend on a local filesystem directory. #[derive(uniffi::Object, Debug, Display, CloneGetters, Clone)] @@ -23,24 +36,68 @@ pub struct LocalFileStore { #[uniffi::export] impl Store for LocalFileStore { fn save_pod(&self, pod: &Pod) -> Result<()> { - self.save_model(pod, &pod.hash, pod.annotation.as_ref()) + self.save_model(pod, &pod.hash, pod.annotation.as_ref())?; + // Deal with saving the recommended_specs + // Since we are going with a no modify scheme for saving, we will save the latest version as year-month-day-hour-min-second UTC + Self::save_file( + self.make_path::( + &pod.hash, + format!( + "recommended_specs/{}", + Utc::now().format("%Y-%m-%d-%H-%M-%S-%f") + ), + ), + &pod.recommend_specs.to_yaml()?, + ) } + fn load_pod(&self, model_id: &ModelID) -> Result { let (mut pod, annotation, hash) = self.load_model::(model_id)?; pod.annotation = annotation; pod.hash = hash; + // Deal with the recommended_specs by selecting the last saved spec + // List all files in the dir + let folder_path = self.make_path::(&pod.hash, 
"recommended_specs"); + let mut recommended_specs = fs::read_dir(&folder_path)?; + + let mut latest_spec_file_name = recommended_specs + .next() + .ok_or(OrcaError { + kind: Kind::EmptyDir { + dir: folder_path.clone(), + backtrace: Some(Backtrace::capture()), + }, + })?? + .file_name(); + + for entry in recommended_specs { + let file_name = entry?.file_name(); + if file_name > latest_spec_file_name { + latest_spec_file_name = file_name; + } + } + + // Read the latest_spec and loaded back in + pod.recommend_specs = serde_yaml::from_str(&fs::read_to_string( + folder_path.join(latest_spec_file_name), + )?)?; + Ok(pod) } + fn list_pod(&self) -> Result> { self.list_model::() } + fn delete_pod(&self, model_id: &ModelID) -> Result<()> { self.delete_model::(model_id) } + fn save_pod_job(&self, pod_job: &PodJob) -> Result<()> { self.save_pod(&pod_job.pod)?; self.save_model(pod_job, &pod_job.hash, pod_job.annotation.as_ref()) } + fn load_pod_job(&self, model_id: &ModelID) -> Result { let (mut pod_job, annotation, hash) = self.load_model::(model_id)?; pod_job.annotation = annotation; @@ -50,16 +107,20 @@ impl Store for LocalFileStore { .into(); Ok(pod_job) } + fn list_pod_job(&self) -> Result> { self.list_model::() } + fn delete_pod_job(&self, model_id: &ModelID) -> Result<()> { self.delete_model::(model_id) } + fn save_pod_result(&self, pod_result: &PodResult) -> Result<()> { self.save_pod_job(&pod_result.pod_job)?; self.save_model(pod_result, &pod_result.hash, pod_result.annotation.as_ref()) } + fn load_pod_result(&self, model_id: &ModelID) -> Result { let (mut pod_result, annotation, hash) = self.load_model::(model_id)?; pod_result.annotation = annotation; @@ -69,22 +130,202 @@ impl Store for LocalFileStore { .into(); Ok(pod_result) } + fn list_pod_result(&self) -> Result> { self.list_model::() } + fn delete_pod_result(&self, model_id: &ModelID) -> Result<()> { self.delete_model::(model_id) } + fn delete_annotation(&self, model_type: &ModelType, name: &str, version: &str) 
-> Result<()> { - let annotation_file = self.make_path( - model_type, - &self.lookup_hash(model_type, name, version)?, - Self::make_annotation_relpath(name, version), - ); - fs::remove_file(&annotation_file)?; + let annotation_file_path = match model_type { + ModelType::Pod => self.make_path::( + &self.lookup_hash::(name, version)?, + Self::make_annotation_relpath(name, version), + ), + ModelType::PodJob => self.make_path::( + &self.lookup_hash::(name, version)?, + Self::make_annotation_relpath(name, version), + ), + ModelType::PodResult => self.make_path::( + &self.lookup_hash::(name, version)?, + Self::make_annotation_relpath(name, version), + ), + }; + fs::remove_file(&annotation_file_path)?; Ok(()) } + + fn save_map_operator(&self, map_operator: &MapOperator) -> Result<()> { + self.save_model(map_operator, &map_operator.hash, None) + } + + fn load_map_operator(&self, hash: &str) -> Result { + let (mut map_operator, _, _) = + self.load_model::(&ModelID::Hash(hash.to_owned()))?; + hash.clone_into(&mut map_operator.hash); + Ok(map_operator) + } + + fn list_map_operator(&self) -> Result> { + self.list_model::().map(|infos| { + infos + .into_iter() + .map(|info| info.hash) + .collect::>() + }) + } + + fn delete_map_operator(&self, hash: &str) -> Result<()> { + self.delete_model::(&ModelID::Hash(hash.to_owned())) + } + + fn save_pipeline(&self, pipeline: &Pipeline) -> Result<()> { + // Save all the kernels first + for kernel in pipeline.get_kernel_lut() { + match kernel { + Kernel::Pod { pod } => self.save_pod(pod)?, + Kernel::JoinOperator => (), // Skip since it's a constant + Kernel::MapOperator { mapper } => self.save_map_operator(mapper)?, + } + } + + // Save the pipeline + self.save_model(pipeline, &pipeline.hash, pipeline.annotation.as_ref())?; + + // Get label mapping and hash it + let labels_lut = pipeline.get_label_lut().collect::>(); + if !labels_lut.is_empty() { + let labels_lut_yaml = serde_yaml::to_string(&labels_lut)?; + let labels_lut_hash = 
hash_buffer(labels_lut_yaml.as_bytes()); + // Get the latest label file name if exists + if if let Some(latest_label_file_name) = + self.get_latest_pipeline_labels_file_name(&pipeline.hash)? + { + // Check if the hash is the same + if latest_label_file_name.split('-').next_back().context( + selector::FailedToGetLabelHashFromFileName { + file_name: latest_label_file_name.clone(), + }, + )? == labels_lut_hash + { + false + } else { + // Hash is different, thus we need to save the new file + true + } + } else { + // No existing label file, we need to save the new one + true + } { + Self::save_file( + self.make_path::( + &pipeline.hash, + format!( + "labels/{}-{}", + Utc::now().timestamp_millis(), + labels_lut_hash + ), + ), + &labels_lut_yaml, + )?; + } + } + + Ok(()) + } + + fn load_pipeline(&self, model_id: &ModelID) -> Result { + // Load the file into a temp struct + #[derive(Deserialize)] + struct PipelineYaml { + kernel_lut: HashMap>, + dot: String, + input_spec: HashMap>, + output_spec: HashMap, + } + + let (hash, annotation) = self.decode_model_id::(model_id)?; + + // Load it from the file + let pipeline_yaml: PipelineYaml = serde_yaml::from_str(&fs::read_to_string( + self.make_path::(&hash, Self::SPEC_RELPATH), + )?)?; + + // Get all the kernels for the pipeline + let pod_list = self.list_pod()?; + let map_operator_list = self.list_map_operator()?; + let kernels = pipeline_yaml.kernel_lut.into_iter().try_fold( + HashMap::new(), + |mut kernels, (kernel_hash, node_hashes)| { + let kernel = if pod_list.iter().any(|pod_info| pod_info.hash == kernel_hash) { + Kernel::Pod { + pod: self.load_pod(&ModelID::Hash(kernel_hash))?.into(), + } + } else if map_operator_list + .iter() + .any(|map_hash| map_hash == &kernel_hash) + { + Kernel::MapOperator { + mapper: self.load_map_operator(&kernel_hash)?.into(), + } + } else if kernel_hash == *JOIN_OPERATOR_HASH { + Kernel::JoinOperator + } else { + return Err(OrcaError { + kind: Kind::MissingInfo { + details: format!("Unable 
to find kernel hash {kernel_hash} in store"), + backtrace: Some(Backtrace::capture()), + }, + }); + }; + for node_hash in node_hashes { + kernels.insert(node_hash, kernel.clone()); + } + Ok(kernels) + }, + )?; + + // Create Pipeline from loaded data + let mut pipeline = Pipeline::new( + &pipeline_yaml.dot, + &kernels, + pipeline_yaml.input_spec, + pipeline_yaml.output_spec, + annotation, + )?; + + // Get the latest labels + if let Some(latest_labels_file_name) = + self.get_latest_pipeline_labels_file_name(&pipeline.hash)? + { + let labels_lut: HashMap = + serde_yaml::from_str(&fs::read_to_string(self.make_path::( + &pipeline.hash, + format!("labels/{latest_labels_file_name}"), + ))?)?; + + // Update the nodes with the labels + pipeline.graph.node_indices().for_each(|node_idx| { + if let Some(label) = labels_lut.get(&pipeline.graph[node_idx].hash) { + pipeline.graph[node_idx].label.clone_from(label); + } + }); + } + // Load the labels LUT + Ok(pipeline) + } + + fn list_pipeline(&self) -> Result> { + self.list_model::() + } + + fn delete_pipeline(&self, model_id: &ModelID) -> Result<()> { + self.delete_model::(model_id) + } } #[uniffi::export] diff --git a/src/uniffi/store/mod.rs b/src/uniffi/store/mod.rs index 449da225..ad8cc036 100644 --- a/src/uniffi/store/mod.rs +++ b/src/uniffi/store/mod.rs @@ -2,8 +2,10 @@ use crate::uniffi::{ error::Result, model::{ ModelType, + pipeline::Pipeline, pod::{Pod, PodJob, PodResult}, }, + operator::MapOperator, }; use uniffi; /// Options for identifying a model. @@ -16,7 +18,7 @@ pub enum ModelID { } /// Metadata for a model. -#[derive(uniffi::Record, Debug, PartialEq, Eq)] +#[derive(uniffi::Record, Debug, PartialEq, Eq, Hash)] pub struct ModelInfo { /// A model's name. pub name: Option, @@ -35,6 +37,7 @@ pub trait Store: Send + Sync { /// /// Will return `Err` if there is an issue storing `pod`. fn save_pod(&self, pod: &Pod) -> Result<()>; + /// How to load a stored pod into a model instance. 
/// /// # Errors @@ -42,82 +45,135 @@ pub trait Store: Send + Sync { /// Will return `Err` if there is an issue loading a pod from the store using `name` and /// `version`. fn load_pod(&self, model_id: &ModelID) -> Result; + /// How to query stored pods. /// /// # Errors - /// /// Will return `Err` if there is an issue querying metadata from existing pods in the store. fn list_pod(&self) -> Result>; + /// How to explicitly delete a stored pod and all associated annotations (does not propagate). /// /// # Errors - /// /// Will return `Err` if there is an issue deleting a pod from the store using `name` and /// `version`. fn delete_pod(&self, model_id: &ModelID) -> Result<()>; + /// How a pod job is stored. /// /// # Errors - /// /// Will return `Err` if there is an issue storing `pod_job`. fn save_pod_job(&self, pod_job: &PodJob) -> Result<()>; + /// How to load a stored pod job into a model instance. /// /// # Errors - /// /// Will return `Err` if there is an issue loading a pod job from the store using `name` and /// `version`. fn load_pod_job(&self, model_id: &ModelID) -> Result; + /// How to query stored pod jobs. /// /// # Errors - /// /// Will return `Err` if there is an issue querying metadata from existing pod jobs in the /// store. fn list_pod_job(&self) -> Result>; + /// How to explicitly delete a stored pod job and all associated annotations (does not /// propagate). /// /// # Errors - /// /// Will return `Err` if there is an issue deleting a pod job from the store using `name` and /// `version`. fn delete_pod_job(&self, model_id: &ModelID) -> Result<()>; + /// How a pod result is stored. /// /// # Errors - /// /// Will return `Err` if there is an issue storing `pod_result`. fn save_pod_result(&self, pod_result: &PodResult) -> Result<()>; + /// How to load a stored pod result into a model instance. /// /// # Errors - /// /// Will return `Err` if there is an issue loading a pod result from the store using `name` and /// `version`. 
fn load_pod_result(&self, model_id: &ModelID) -> Result; + /// How to query stored pod results. /// /// # Errors - /// /// Will return `Err` if there is an issue querying metadata from existing pod results in the /// store. fn list_pod_result(&self) -> Result>; + /// How to explicitly delete a stored pod result and all associated annotations (does not /// propagate). /// /// # Errors - /// /// Will return `Err` if there is an issue deleting a pod result from the store using `name` and /// `version`. fn delete_pod_result(&self, model_id: &ModelID) -> Result<()>; + /// How to explicitly delete an annotation. /// /// # Errors - /// /// Will return `Err` if there is an issue deleting an annotation from the store using `name` /// and `version`. fn delete_annotation(&self, model_type: &ModelType, name: &str, version: &str) -> Result<()>; + + /// How to save a Mapping + /// + /// # Errors + /// Will return `Err` if there is an issue saving `mapping`. + fn save_map_operator(&self, map_operator: &MapOperator) -> Result<()>; + + /// How to load a stored `MapOperator` + /// + /// # Errors + /// Will return `Err` if there is an issue loading a `MapOperator` from the store + fn load_map_operator(&self, hash: &str) -> Result; + + /// How to query stored `MapOperators` + /// + /// # Errors + /// Will return `Err` if there is an issue querying metadata from existing `MapOperators` in the store. + fn list_map_operator(&self) -> Result>; + + /// How to explicitly delete a stored `MapOperator` + /// + /// # Errors + /// Will return `Err` if there is an issue deleting a `MapOperator` from the store using `hash`. + /// + fn delete_map_operator(&self, hash: &str) -> Result<()>; + + /// How to save a pipeline + /// + /// # Errors + /// Will return `Err` if there is an issue saving `pipeline`. 
+ fn save_pipeline(&self, pipeline: &Pipeline) -> Result<()>; + + /// How to load a stored `Pipeline` + /// + /// # Errors + /// Will return `Err` if there is an issue loading a `Pipeline` from the store + fn load_pipeline(&self, model_id: &ModelID) -> Result; + + /// How to query stored pipelines. + /// + /// # Errors + /// Will return `Err` if there is an issue querying metadata from existing pipelines in the + /// store. + fn list_pipeline(&self) -> Result>; + + /// How to explicitly delete a stored pipeline and all associated annotations (does not + /// propagate). + /// + /// # Errors + /// Will return `Err` if there is an issue deleting a pipeline from the store using `name` and + /// `version`. + fn delete_pipeline(&self, model_id: &ModelID) -> Result<()>; } + /// Store implementation on a local filesystem. pub mod filestore; diff --git a/tests/agent.rs b/tests/agent.rs index 0fa82a0d..59469796 100644 --- a/tests/agent.rs +++ b/tests/agent.rs @@ -9,17 +9,15 @@ pub mod fixture; use fixture::{NAMESPACE_LOOKUP_READ_ONLY, TestDirs, pod_jobs_stresser, pull_image}; -use orcapod::{ - core::orchestrator::agent::extract_metadata, - uniffi::{ - error::Result, - model::pod::PodResult, - orchestrator::{ - agent::{Agent, AgentClient}, - docker::LocalDockerOrchestrator, - }, - store::{ModelID, Store as _, filestore::LocalFileStore}, +use itertools::Itertools as _; +use orcapod::uniffi::{ + error::Result, + model::pod::PodResult, + orchestrator::{ + agent::{Agent, AgentClient}, + docker::LocalDockerOrchestrator, }, + store::{ModelID, Store as _, filestore::LocalFileStore}, }; use std::{ collections::HashMap, @@ -42,7 +40,6 @@ fn simple() -> Result<()> { Ok(()) } -#[expect(clippy::excessive_nesting, reason = "Nesting is manageable")] #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn parallel_four_cores() -> Result<()> { let test_dirs = TestDirs::new(&HashMap::from([("default".to_owned(), None::)]))?; @@ -50,7 +47,7 @@ async fn parallel_four_cores() -> 
Result<()> { // config let image_reference = "ghcr.io/colinianking/stress-ng:e2f96874f951a72c1c83ff49098661f0e013ac40"; pull_image(image_reference)?; - let margin_millis = 2000; + let margin_millis = 5000; let run_duration_secs = 5; let (group, host) = ("agent_parallel-four-cores", "host"); // api @@ -101,8 +98,14 @@ async fn parallel_four_cores() -> Result<()> { .recv_async() .await .expect("All senders have dropped."); - let metadata = extract_metadata(sample.key_expr().as_str()); - let topic_kind = metadata["action"].as_str(); + let metadata: HashMap = sample + .key_expr() + .as_str() + .split('/') + .map(ToOwned::to_owned) + .tuples() + .collect(); + let topic_kind = metadata["event"].as_str(); if ["success", "failure"].contains(&topic_kind) { let pod_result = serde_json::from_slice::(&sample.payload().to_bytes())?; assert!( diff --git a/tests/crypto.rs b/tests/crypto.rs deleted file mode 100644 index a63dc18e..00000000 --- a/tests/crypto.rs +++ /dev/null @@ -1,30 +0,0 @@ -#![expect(missing_docs, clippy::panic_in_result_fn, reason = "OK in tests.")] -pub mod fixture; - -use orcapod::{ - core::crypto::{hash_buffer, hash_dir, hash_file}, - uniffi::error::Result, -}; -use std::fs::read; - -#[test] -fn consistent_hash() -> Result<()> { - let filepath = "./tests/extra/data/images/subject.jpeg"; - assert_eq!( - hash_file(filepath)?, - hash_buffer(&read(filepath)?), - "Checksum not consistent." - ); - Ok(()) -} - -#[test] -fn complex_hash() -> Result<()> { - let dirpath = "./tests/extra/data/images"; - assert_eq!( - hash_dir(dirpath)?, - "6c96a478ea25e34fab045bc82858a2980b2cfb22db32e83c01349a8e7ed3b42c".to_owned(), - "Directory checksum didn't match." 
- ); - Ok(()) -} diff --git a/tests/error.rs b/tests/error.rs index 7eb89524..4564e162 100644 --- a/tests/error.rs +++ b/tests/error.rs @@ -10,16 +10,13 @@ use chrono::DateTime; use dot_parser::ast::Graph as DOTGraph; use fixture::{NAMESPACE_LOOKUP_READ_ONLY, pod_custom, pod_job_custom, pod_job_style, str_to_vec}; use glob::glob; -use orcapod::{ - core::crypto::hash_file, - uniffi::{ - error::{OrcaError, Result}, - model::packet::PathInfo, - orchestrator::{ - Orchestrator as _, - agent::{AgentClient, Response}, - docker::LocalDockerOrchestrator, - }, +use orcapod::uniffi::{ + error::{OrcaError, Result}, + model::packet::PathInfo, + orchestrator::{ + Orchestrator as _, + agent::{AgentClient, Response}, + docker::LocalDockerOrchestrator, }, }; use serde_json; @@ -154,14 +151,6 @@ fn internal_incomplete_packet() -> Result<()> { Ok(()) } -#[test] -fn internal_invalid_filepath() { - assert!( - hash_file("nonexistent_file.txt").is_err_and(contains_debug), - "Did not raise an invalid filepath error." 
- ); -} - #[test] fn internal_key_missing() { assert!( diff --git a/tests/extra/data/input_txt/Where.txt b/tests/extra/data/input_txt/Where.txt new file mode 100644 index 00000000..2891a132 --- /dev/null +++ b/tests/extra/data/input_txt/Where.txt @@ -0,0 +1 @@ +Where diff --git a/tests/extra/data/input_txt/black.txt b/tests/extra/data/input_txt/black.txt new file mode 100644 index 00000000..7e66a17d --- /dev/null +++ b/tests/extra/data/input_txt/black.txt @@ -0,0 +1 @@ +black diff --git a/tests/extra/data/input_txt/cat.txt b/tests/extra/data/input_txt/cat.txt new file mode 100644 index 00000000..ef07ddcd --- /dev/null +++ b/tests/extra/data/input_txt/cat.txt @@ -0,0 +1 @@ +cat diff --git a/tests/extra/data/input_txt/hiding.txt b/tests/extra/data/input_txt/hiding.txt new file mode 100644 index 00000000..56e64f05 --- /dev/null +++ b/tests/extra/data/input_txt/hiding.txt @@ -0,0 +1 @@ +hiding diff --git a/tests/extra/data/input_txt/is.txt b/tests/extra/data/input_txt/is.txt new file mode 100644 index 00000000..f5cb1322 --- /dev/null +++ b/tests/extra/data/input_txt/is.txt @@ -0,0 +1 @@ +is diff --git a/tests/extra/data/input_txt/playing.txt b/tests/extra/data/input_txt/playing.txt new file mode 100644 index 00000000..0395b790 --- /dev/null +++ b/tests/extra/data/input_txt/playing.txt @@ -0,0 +1 @@ +playing diff --git a/tests/extra/data/input_txt/tabby.txt b/tests/extra/data/input_txt/tabby.txt new file mode 100644 index 00000000..3de6015d --- /dev/null +++ b/tests/extra/data/input_txt/tabby.txt @@ -0,0 +1 @@ +tabby diff --git a/tests/extra/data/input_txt/the.txt b/tests/extra/data/input_txt/the.txt new file mode 100644 index 00000000..41d25f51 --- /dev/null +++ b/tests/extra/data/input_txt/the.txt @@ -0,0 +1 @@ +the diff --git a/tests/extra/data/output/output.txt b/tests/extra/data/output/output.txt new file mode 100644 index 00000000..ef07ddcd --- /dev/null +++ b/tests/extra/data/output/output.txt @@ -0,0 +1 @@ +cat diff --git a/tests/extra/python/agent_test.py 
b/tests/extra/python/agent.py similarity index 88% rename from tests/extra/python/agent_test.py rename to tests/extra/python/agent.py index 7e35ea79..8405d8d7 100644 --- a/tests/extra/python/agent_test.py +++ b/tests/extra/python/agent.py @@ -17,6 +17,7 @@ Uri, Pod, Annotation, + RecommendSpecs, ) @@ -29,7 +30,7 @@ def count(sample): with zenoh.open(zenoh.Config()) as session: with session.declare_subscriber( - f"**/action/success/**/group/{group}/**/topic/pod_job/**", count + f"**/event/success/**/group/{group}/**/topic/pod_job/**", count ) as subscriber: await asyncio.sleep(20) # wait for results @@ -45,7 +46,7 @@ async def main(client, agent, test_dir, namespace_lookup, pod_jobs): available_store=LocalFileStore(directory=f"{test_dir}/store"), ), ) - await asyncio.sleep(5) # ensure service ready + await asyncio.sleep(1) # ensure service ready try: await client.start_pod_jobs(pod_jobs=pod_jobs) @@ -88,10 +89,11 @@ async def main(client, agent, test_dir, namespace_lookup, pod_jobs): input_spec={}, output_dir="/tmp/output", output_spec={}, - source_commit_url="https://github.com/user/simple", - recommended_cpus=0.1, - recommended_memory=10 << 20, - required_gpu=None, + recommend_specs=RecommendSpecs( + cpus=0.1, + memory=128 << 20, + ), + gpu_requirements=None, ), input_packet={}, output_dir=Uri( @@ -99,7 +101,7 @@ async def main(client, agent, test_dir, namespace_lookup, pod_jobs): path=".", ), cpu_limit=1, - memory_limit=10 << 20, + memory_limit=128 << 20, env_vars=None, namespace_lookup=namespace_lookup, ) diff --git a/tests/extra/python/smoke_test.py b/tests/extra/python/model.py similarity index 96% rename from tests/extra/python/smoke_test.py rename to tests/extra/python/model.py index 4e111023..03c76d8c 100755 --- a/tests/extra/python/smoke_test.py +++ b/tests/extra/python/model.py @@ -16,6 +16,7 @@ LocalFileStore, ModelId, ModelType, + RecommendSpecs, OrcaError, ) @@ -32,10 +33,11 @@ def create_pod(data, _): input_spec={}, output_dir="/tmp/output", 
output_spec={}, - source_commit_url="https://github.com/user/simple", - recommended_cpus=0.1, - recommended_memory=10 << 20, - required_gpu=None, + recommend_specs=RecommendSpecs( + cpus=0.1, + memory=10 << 20, + ), + gpu_requirements=None, ) return data["pod"], data diff --git a/tests/fixture/mod.rs b/tests/fixture/mod.rs index a9ae1090..9692ed4a 100644 --- a/tests/fixture/mod.rs +++ b/tests/fixture/mod.rs @@ -13,8 +13,10 @@ use orcapod::uniffi::{ model::{ Annotation, packet::{Blob, BlobKind, Packet, PathInfo, PathSet, URI}, - pod::{Pod, PodJob, PodResult}, + pipeline::{Kernel, NodeURI, Pipeline, PipelineJob}, + pod::{Pod, PodJob, PodResult, RecommendSpecs}, }, + operator::MapOperator, orchestrator::PodStatus, store::{ModelID, ModelInfo, Store}, }; @@ -75,9 +77,10 @@ pub fn pod_style() -> Result { }, ), ]), - "https://github.com/user/style-transfer/tree/1.0.0".to_owned(), - 0.25, // 250 millicores as frac cores - 1_u64 << 30, // 1GiB in bytes + RecommendSpecs { + cpus: 0.25, + memory: 1_u64 << 30, + }, None, ) } @@ -153,6 +156,7 @@ pub fn pod_result_style( 1_737_922_307, 1_737_925_907, namespace_lookup, + "Example logs".to_owned(), ) } @@ -168,9 +172,10 @@ pub fn pod_custom( input_spec, PathBuf::from("/tmp/output"), HashMap::new(), - "https://github.com/place/holder".to_owned(), - 0.1, // 100 millicores as frac cores - 50_u64 << 20, // 10 MiB in bytes + RecommendSpecs { + cpus: 0.1, + memory: 50_u64 << 20, + }, None, ) } @@ -280,6 +285,283 @@ pub fn pull_image(reference: &str) -> Result<()> { Ok(()) } +// Pipeline Fixture +pub fn combine_txt_pod(pod_name: &str) -> Result { + Pod::new( + Some(Annotation { + name: pod_name.to_owned(), + description: "Takes two input files, remove the final next line and combine them" + .to_owned(), + version: "1.0.0".to_owned(), + }), + "alpine:3.14".to_owned(), + vec![ + "sh".into(), + "-c".into(), + format!( + "printf '%s %s\\n' \"$(cat input/input_1.txt | head -c -1)\" \"$(cat input/input_2.txt | head -c -1)\" > 
/output/output.txt" + ), + ], + HashMap::from([ + ( + "input_1".to_owned(), + PathInfo { + path: PathBuf::from("/input/input_1.txt"), + match_pattern: r".*\.txt".to_owned(), + }, + ), + ( + "input_2".into(), + PathInfo { + path: PathBuf::from("/input/input_2.txt"), + match_pattern: r".*\.txt".to_owned(), + }, + ), + ]), + PathBuf::from("/output"), + HashMap::from([( + "output".to_owned(), + PathInfo { + path: PathBuf::from("output.txt"), + match_pattern: r".*\.txt".to_owned(), + }, + )]), + RecommendSpecs { + cpus: 0.25, + memory: 128_u64 << 20, + }, + None, + ) +} + +#[expect(clippy::too_many_lines, reason = "OK in tests.")] +pub fn pipeline() -> Result { + // Create a simple pipeline where the functions job is to add append their name into the input file + // Structure: A -> Mapper -> Joiner -> B -> Mapper -> C, D -> Mapper -> Joiner + + // Create the kernel map + let mut kernel_map = HashMap::new(); + + // Insert the pod into the kernel map + for pod_name in ["A", "B", "C", "D", "E"] { + kernel_map.insert(pod_name.into(), combine_txt_pod(pod_name)?.into()); + } + + let output_to_input_1 = Arc::new(MapOperator::new(HashMap::from([( + "output".to_owned(), + "input_1".to_owned(), + )]))?); + + let output_to_input_2 = Arc::new(MapOperator::new(HashMap::from([( + "output".to_owned(), + "input_2".to_owned(), + )]))?); + + // Create a mapper for A, B, and C + kernel_map.insert( + "pod_a_mapper".into(), + Kernel::MapOperator { + mapper: Arc::clone(&output_to_input_1), + }, + ); + kernel_map.insert( + "pod_b_mapper".into(), + Kernel::MapOperator { + mapper: Arc::clone(&output_to_input_2), + }, + ); + kernel_map.insert( + "pod_c_mapper".into(), + Kernel::MapOperator { + mapper: Arc::clone(&output_to_input_1), + }, + ); + kernel_map.insert( + "pod_d_mapper".into(), + Kernel::MapOperator { + mapper: Arc::clone(&output_to_input_2), + }, + ); + + for joiner_name in ['c', 'd', 'e'] { + kernel_map.insert(format!("pod_{joiner_name}_joiner"), Kernel::JoinOperator); + } + + // 
Write all the edges in DOT format + let dot = " + digraph { + A -> pod_a_mapper -> pod_c_joiner; + B -> pod_b_mapper -> pod_c_joiner; + pod_c_joiner -> C -> pod_c_mapper-> pod_e_joiner; + D -> pod_d_mapper -> pod_e_joiner; + pod_e_joiner -> E; + } + "; + + Pipeline::new( + dot, + &kernel_map, + HashMap::from([ + ( + "where".into(), + vec![NodeURI { + node_id: "A".into(), + key: "input_1".into(), + }], + ), + ( + "is".into(), + vec![NodeURI { + node_id: "A".into(), + key: "input_2".into(), + }], + ), + ( + "the".into(), + vec![NodeURI { + node_id: "B".into(), + key: "input_1".into(), + }], + ), + ( + "cat_color".into(), + vec![NodeURI { + node_id: "B".into(), + key: "input_2".into(), + }], + ), + ( + "cat".into(), + vec![NodeURI { + node_id: "D".into(), + key: "input_1".into(), + }], + ), + ( + "action".into(), + vec![NodeURI { + node_id: "D".into(), + key: "input_2".into(), + }], + ), + ]), + HashMap::from([( + "output".to_owned(), + NodeURI { + node_id: "E".into(), + key: "output".into(), + }, + )]), + Some(Annotation { + name: "test".into(), + version: "0.0.0".into(), + description: "Test pipeline".into(), + }), + ) +} + +#[expect(clippy::implicit_hasher, reason = "Could be a false positive?")] +pub fn pipeline_job(namespace_lookup: &HashMap) -> Result { + // Create a simple pipeline_job + let namespace: String = "default".into(); + PipelineJob::new( + pipeline()?.into(), + &HashMap::from([ + ( + "where".into(), + vec![PathSet::Unary(Blob { + kind: BlobKind::File, + location: URI { + namespace: namespace.clone(), + path: "input_txt/Where.txt".into(), + }, + checksum: String::new(), + })], + ), + ( + "is".into(), + vec![PathSet::Unary(Blob { + kind: BlobKind::File, + location: URI { + namespace: namespace.clone(), + path: "input_txt/is.txt".into(), + }, + checksum: String::new(), + })], + ), + ( + "the".into(), + vec![PathSet::Unary(Blob { + kind: BlobKind::File, + location: URI { + namespace: namespace.clone(), + path: "input_txt/the.txt".into(), + }, + checksum: 
String::new(), + })], + ), + ( + "cat_color".into(), + vec![ + PathSet::Unary(Blob { + kind: BlobKind::File, + location: URI { + namespace: namespace.clone(), + path: "input_txt/black.txt".into(), + }, + checksum: String::new(), + }), + PathSet::Unary(Blob { + kind: BlobKind::File, + location: URI { + namespace: namespace.clone(), + path: "input_txt/tabby.txt".into(), + }, + checksum: String::new(), + }), + ], + ), + ( + "cat".into(), + vec![PathSet::Unary(Blob { + kind: BlobKind::File, + location: URI { + namespace: namespace.clone(), + path: "input_txt/cat.txt".into(), + }, + checksum: String::new(), + })], + ), + ( + "action".into(), + vec![ + PathSet::Unary(Blob { + kind: BlobKind::File, + location: URI { + namespace: namespace.clone(), + path: "input_txt/hiding.txt".into(), + }, + checksum: String::new(), + }), + PathSet::Unary(Blob { + kind: BlobKind::File, + location: URI { + namespace, + path: "input_txt/playing.txt".into(), + }, + checksum: String::new(), + }), + ], + ), + ]), + URI { + namespace: "default".to_owned(), + path: PathBuf::from("pipeline_output"), + }, + namespace_lookup, + ) +} + // --- util --- pub fn str_to_vec(v: &str) -> Vec { @@ -290,6 +572,8 @@ pub struct TestDirs(pub HashMap); impl TestDirs { pub fn new(config: &HashMap>>) -> Result { + // Check if .tmp exists if not create it + fs::create_dir_all("tests/.tmp")?; Ok(Self( config .iter() @@ -418,3 +702,28 @@ impl TestSetup for PodResult { store.list_pod_result() } } + +impl TestSetup for Pipeline { + fn save(&self, store: &impl Store) -> Result<()> { + store.save_pipeline(self) + } + fn delete(&self, store: &impl Store) -> Result<()> { + store.delete_pipeline(&ModelID::Hash(self.hash.clone())) + } + fn load(&self, store: &impl Store) -> Result { + let annotation = self.annotation.as_ref().expect("Annotation missing."); + store.load_pipeline(&ModelID::Annotation( + annotation.name.clone(), + annotation.version.clone(), + )) + } + fn get_annotation(&self) -> Option<&Annotation> { + 
self.annotation.as_ref() + } + fn get_hash(&self) -> &str { + &self.hash + } + fn list(&self, store: &impl Store) -> Result> { + store.list_pipeline() + } +} diff --git a/tests/model.rs b/tests/model.rs deleted file mode 100644 index 31d0ce8c..00000000 --- a/tests/model.rs +++ /dev/null @@ -1,141 +0,0 @@ -#![expect(missing_docs, clippy::panic_in_result_fn, reason = "OK in tests.")] - -pub mod fixture; -use fixture::{NAMESPACE_LOOKUP_READ_ONLY, pod_job_style, pod_result_style, pod_style}; -use indoc::indoc; -use orcapod::{core::model::to_yaml, uniffi::error::Result}; -use pretty_assertions::assert_eq as pretty_assert_eq; - -#[test] -fn hash_pod() -> Result<()> { - assert_eq!( - pod_style()?.hash, - "0e993f645fbb36f0635e2c9140975997cf4ca723d0b49cf4ee4963b76e6424d7", - "Hash didn't match." - ); - Ok(()) -} - -#[test] -fn pod_to_yaml() -> Result<()> { - pretty_assert_eq!( - to_yaml(&pod_style()?)?, - indoc! {r" - class: pod - image: example.server.com/user/style-transfer:1.0.0 - command: - - python - - /run.py - input_spec: - base-input: - path: /input - match_pattern: input/.* - extra-style: - path: /extra_styles/style2.t7 - match_pattern: .*\.t7 - output_dir: /output - output_spec: - result1: - path: result1.jpeg - match_pattern: .*\.jpeg - result2: - path: result2.jpeg - match_pattern: .*\.jpeg - source_commit_url: https://github.com/user/style-transfer/tree/1.0.0 - recommended_cpus: 0.25 - recommended_memory: 1073741824 - required_gpu: null - "}, - "YAML serialization didn't match." - ); - Ok(()) -} - -#[test] -fn hash_pod_job() -> Result<()> { - assert_eq!( - pod_job_style(&NAMESPACE_LOOKUP_READ_ONLY)?.hash, - "ba1c4693f9186ccb1b6e63625085d8fd95552b28b7a60fe9b1b47f68a9ba8880", - "Hash didn't match." - ); - Ok(()) -} - -#[test] -fn pod_job_to_yaml() -> Result<()> { - pretty_assert_eq!( - to_yaml(&pod_job_style(&NAMESPACE_LOOKUP_READ_ONLY)?)?, - indoc! 
{" - class: pod_job - pod: 0e993f645fbb36f0635e2c9140975997cf4ca723d0b49cf4ee4963b76e6424d7 - input_packet: - base-input: - - kind: File - location: - namespace: default - path: styles/style1.t7 - checksum: 69e709c1697e290994d2da75ddfb2097bf801a9436a3727a282e0230e703da2b - - kind: File - location: - namespace: default - path: images/subject.jpeg - checksum: 8b44b8ea83b1f5eec3ac16cf941767e629896c465803fb69c21adbbf984516bd - extra-style: - kind: File - location: - namespace: default - path: styles/mosaic.t7 - checksum: fbd7d882e9e02aafb57366e726762025ff6b2e12cd41abd44b874542b7693771 - output_dir: - namespace: default - path: output - cpu_limit: 0.5 - memory_limit: 2147483648 - env_vars: - AAA: SORT - ZZZ: PLEASE - "}, - "YAML serialization didn't match." - ); - Ok(()) -} - -#[test] -fn hash_pod_result() -> Result<()> { - assert_eq!( - pod_result_style(&NAMESPACE_LOOKUP_READ_ONLY)?.hash, - "e752d86d4fc5435bfa4564ba951851530f5cf2228586c2f488c2ca9e7bcc7ed1", - "Hash didn't match." - ); - Ok(()) -} - -#[test] -fn pod_result_to_yaml() -> Result<()> { - pretty_assert_eq!( - to_yaml(&pod_result_style(&NAMESPACE_LOOKUP_READ_ONLY)?)?, - indoc! {" - class: pod_result - pod_job: ba1c4693f9186ccb1b6e63625085d8fd95552b28b7a60fe9b1b47f68a9ba8880 - output_packet: - result1: - kind: File - location: - namespace: default - path: output/result1.jpeg - checksum: 5898ca5bed67147680c6489056cbf2e90074bc51d8ca2645453742580ce74b7a - result2: - kind: File - location: - namespace: default - path: output/result2.jpeg - checksum: a1458fc7d7d9d23a66feae88b5a89f1756055bdbb6be02fdf672f7d31ed92735 - assigned_name: simple-endeavour - status: Completed - created: 1737922307 - terminated: 1737925907 - "}, - "YAML serialization didn't match." 
- ); - Ok(()) -} diff --git a/tests/operator.rs b/tests/operator.rs deleted file mode 100644 index 0389f069..00000000 --- a/tests/operator.rs +++ /dev/null @@ -1,213 +0,0 @@ -#![expect(missing_docs, clippy::panic_in_result_fn, reason = "OK in tests.")] - -use orcapod::{ - core::operator::{JoinOperator, MapOperator, Operator}, - uniffi::{ - error::Result, - model::packet::{Blob, BlobKind, Packet, PathSet, URI}, - }, -}; -use std::{collections::HashMap, path::PathBuf}; - -fn make_packet_key(key_name: String, filepath: String) -> (String, PathSet) { - ( - key_name, - PathSet::Unary(Blob { - kind: BlobKind::File, - location: URI { - namespace: "default".into(), - path: PathBuf::from(filepath), - }, - checksum: String::new(), - }), - ) -} - -async fn next_batch( - operator: impl Operator, - packets: Vec<(String, Packet)>, -) -> Result> { - let mut next_packets = vec![]; - for (stream_name, packet) in packets { - next_packets.extend(operator.next(stream_name, packet).await?); - } - Ok(next_packets) -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 1)] -async fn join_once() -> Result<()> { - let operator = JoinOperator::new(2); - - let left_stream = (0..3) - .map(|i| { - ( - "left".into(), - Packet::from([make_packet_key( - "subject".into(), - format!("left/subject{i}.png"), - )]), - ) - }) - .collect::>(); - - let right_stream = (0..2) - .map(|i| { - ( - "right".into(), - Packet::from([make_packet_key( - "style".into(), - format!("right/style{i}.t7"), - )]), - ) - }) - .collect::>(); - - let mut input_streams = left_stream; - input_streams.extend(right_stream); - - assert_eq!( - next_batch(operator, input_streams).await?, - vec![ - Packet::from([ - make_packet_key("subject".into(), "left/subject0.png".into()), - make_packet_key("style".into(), "right/style0.t7".into()), - ]), - Packet::from([ - make_packet_key("subject".into(), "left/subject1.png".into()), - make_packet_key("style".into(), "right/style0.t7".into()), - ]), - Packet::from([ - 
make_packet_key("subject".into(), "left/subject2.png".into()), - make_packet_key("style".into(), "right/style0.t7".into()), - ]), - Packet::from([ - make_packet_key("subject".into(), "left/subject0.png".into()), - make_packet_key("style".into(), "right/style1.t7".into()), - ]), - Packet::from([ - make_packet_key("subject".into(), "left/subject1.png".into()), - make_packet_key("style".into(), "right/style1.t7".into()), - ]), - Packet::from([ - make_packet_key("subject".into(), "left/subject2.png".into()), - make_packet_key("style".into(), "right/style1.t7".into()), - ]), - ], - "Unexpected streams." - ); - - Ok(()) -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 1)] -async fn join_spotty() -> Result<()> { - let operator = JoinOperator::new(2); - - assert_eq!( - operator - .next( - "right".into(), - Packet::from([make_packet_key("style".into(), "right/style0.t7".into())]) - ) - .await?, - vec![], - "Unexpected streams." - ); - - assert_eq!( - operator - .next( - "right".into(), - Packet::from([make_packet_key("style".into(), "right/style1.t7".into())]) - ) - .await?, - vec![], - "Unexpected streams." - ); - - assert_eq!( - operator - .next( - "left".into(), - Packet::from([make_packet_key( - "subject".into(), - "left/subject0.png".into() - )]) - ) - .await?, - vec![ - Packet::from([ - make_packet_key("subject".into(), "left/subject0.png".into()), - make_packet_key("style".into(), "right/style0.t7".into()), - ]), - Packet::from([ - make_packet_key("subject".into(), "left/subject0.png".into()), - make_packet_key("style".into(), "right/style1.t7".into()), - ]), - ], - "Unexpected streams." 
- ); - - assert_eq!( - next_batch( - operator, - (1..3) - .map(|i| { - ( - "left".into(), - Packet::from([make_packet_key( - "subject".into(), - format!("left/subject{i}.png"), - )]), - ) - }) - .collect::>() - ) - .await?, - vec![ - Packet::from([ - make_packet_key("subject".into(), "left/subject1.png".into()), - make_packet_key("style".into(), "right/style0.t7".into()), - ]), - Packet::from([ - make_packet_key("subject".into(), "left/subject1.png".into()), - make_packet_key("style".into(), "right/style1.t7".into()), - ]), - Packet::from([ - make_packet_key("subject".into(), "left/subject2.png".into()), - make_packet_key("style".into(), "right/style0.t7".into()), - ]), - Packet::from([ - make_packet_key("subject".into(), "left/subject2.png".into()), - make_packet_key("style".into(), "right/style1.t7".into()), - ]), - ], - "Unexpected streams." - ); - - Ok(()) -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 1)] -async fn map_once() -> Result<()> { - let operator = MapOperator::new(&HashMap::from([("key_old".into(), "key_new".into())])); - - assert_eq!( - operator - .next( - "parent".into(), - Packet::from([ - make_packet_key("key_old".into(), "some/key.txt".into()), - make_packet_key("subject".into(), "some/subject.txt".into()), - ]), - ) - .await?, - vec![Packet::from([ - make_packet_key("key_new".into(), "some/key.txt".into()), - make_packet_key("subject".into(), "some/subject.txt".into()), - ]),], - "Unexpected packet." - ); - - Ok(()) -} diff --git a/tests/orchestrator.rs b/tests/orchestrator.rs index d129928b..a1ba528c 100644 --- a/tests/orchestrator.rs +++ b/tests/orchestrator.rs @@ -4,7 +4,6 @@ clippy::panic, clippy::expect_used, clippy::unwrap_used, - clippy::indexing_slicing, reason = "OK in tests." 
)] @@ -16,6 +15,7 @@ use orcapod::uniffi::{ model::packet::URI, orchestrator::{ImageKind, Orchestrator, PodRun, PodStatus, docker::LocalDockerOrchestrator}, }; +use pretty_assertions::assert_eq; use std::{collections::HashMap, path::PathBuf}; use crate::fixture::{ @@ -49,12 +49,12 @@ fn basic_test( assert_eq!( orchestrator .list_blocking()? - .iter() - .filter(|run| **run == *pod_run) - .map(|run| Ok(orchestrator.get_info_blocking(run)?.command)) - .collect::>>()?[0], - expected_command, - "List return a pod_run with a different command." + .into_iter() + .filter(|pod_run_from_list| pod_run_from_list == pod_run) + .map(|run| Ok(orchestrator.get_info_blocking(&run)?.command)) + .collect::>>()?, + vec![expected_command], + "Unexpected list." ); // await result let pod_result_1 = orchestrator.get_result_blocking(pod_run, namespace_lookup)?; @@ -68,11 +68,11 @@ fn basic_test( orchestrator .list_blocking()? .into_iter() - .filter(|run| *run == *pod_run) + .filter(|pod_run_from_list| pod_run_from_list == pod_run) .map(|run| Ok(orchestrator.get_info_blocking(&run)?.command)) - .collect::>>()?[0], - expected_command, - "List return a pod_run with a different command." + .collect::>>()?, + vec![expected_command], + "Unexpected list." 
); assert_eq!( pod_result_1.assigned_name, pod_run.assigned_name, @@ -296,3 +296,44 @@ async fn verify_pod_result_not_running() -> Result<()> { ); Ok(()) } + +#[test] +fn logs() -> Result<()> { + execute_wrapper(|orchestrator, namespace_lookup| { + let pod_job = pod_job_custom( + pod_custom( + "alpine:3.14", + vec!["bin/sh".into(), "-c".into(), "echo \"hi\"".into()], + HashMap::new(), + )?, + HashMap::new(), + namespace_lookup, + )?; + + let pod_run = orchestrator.start_blocking(&pod_job, namespace_lookup)?; + let pod_result = orchestrator.get_result_blocking(&pod_run, namespace_lookup)?; + + assert_eq!( + pod_result.status, + PodStatus::Completed, + "Pod status is not completed" + ); + + assert_eq!(orchestrator.get_logs_blocking(&pod_run)?, "hi\n"); + assert_eq!( + orchestrator + .get_result_blocking(&pod_run, namespace_lookup)? + .logs, + "hi\n" + ); + + orchestrator.delete_blocking(&pod_run)?; + + assert!( + !orchestrator.list_blocking()?.contains(&pod_run), + "Unexpected container remains." + ); + + Ok(()) + }) +} diff --git a/tests/pipeline.rs b/tests/pipeline.rs index caa02d4a..54f512a6 100644 --- a/tests/pipeline.rs +++ b/tests/pipeline.rs @@ -3,6 +3,7 @@ clippy::panic_in_result_fn, clippy::indexing_slicing, clippy::panic, + clippy::type_complexity, reason = "OK in tests." 
)] @@ -13,11 +14,154 @@ use orcapod::uniffi::{ error::Result, model::{ packet::{Blob, BlobKind, PathInfo, PathSet, URI}, - pipeline::{Kernel, Pipeline, PipelineJob, SpecURI}, + pipeline::{Kernel, NodeURI, Pipeline, PipelineJob}, }, }; +use pretty_assertions::assert_eq; use std::collections::HashMap; +use crate::fixture::{combine_txt_pod, pipeline}; + +#[expect(clippy::too_many_lines, reason = "Test code")] +#[test] +fn preprocessing() -> Result<()> { + let pipeline = pipeline()?; + + // Assert that every node has a non-empty hash + let node_hashes = pipeline + .graph + .node_indices() + .map(|idx| { + ( + pipeline.graph[idx].label.as_str(), + pipeline.graph[idx].hash.as_str(), + ) + }) + .collect::>(); + + assert_eq!( + node_hashes, + HashMap::from([ + ( + "pod_c_joiner", + "d2141ce0c203a8b556d7dbbbc6268ac4bbfa444748f92baff42235787f2b7550" + ), + ( + "B", + "964ebb9ddd6bb7db56e53c19e9ac34dfd08779a656295b01e70b5973adc61103" + ), + ( + "C", + "96b30227e0243f282f7a898bd85a246127e664635a3969577932d7653cfb79cb" + ), + ( + "pod_a_mapper", + "83bd3d17026c882db6b6cca7ccca0173f478c11449cfa8bfb13a0518a7e5e32a" + ), + ( + "pod_b_mapper", + "dd73cd3ab345917b25fc028131d83da7ce1c53702fcbabdd19b86a8bdde158b3" + ), + ( + "pod_d_mapper", + "d37f595093e8f7235f97213b3f7ff88b12786e48ec4f22275018cc7d22c113f8" + ), + ( + "A", + "8e43dbc9fd55fa7d1a36fc4a6c036f4113b7aa7fcf38646a2f2472bac6774962" + ), + ( + "E", + "6ec68cc43ea15472731a318584cc8792fb2ff93c96fed6f3f998849b75976694" + ), + ( + "D", + "04cb341a09eeb771846377405a5f33d011f99a7dfa4739fd7876a7e70c994e4e" + ), + ( + "pod_c_mapper", + "240c8e7fa5e0bd88239aba625387ea495fc5323a5d4b6b519946b8f8b907ddf6" + ), + ( + "pod_e_joiner", + "36f3e88889ecf89183205f340043de61f3c6a254026aae5aa1ce587a666e8c30" + ), + ]), + "Node hashes did not match" + ); + + // Check if the input spec contains the correct node hashes + assert_eq!( + pipeline.input_spec, + HashMap::from([ + ( + "the".into(), + vec![NodeURI { + node_id: 
"964ebb9ddd6bb7db56e53c19e9ac34dfd08779a656295b01e70b5973adc61103" + .into(), + key: "input_1".into(), + },] + ), + ( + "where".into(), + vec![NodeURI { + node_id: "8e43dbc9fd55fa7d1a36fc4a6c036f4113b7aa7fcf38646a2f2472bac6774962" + .into(), + key: "input_1".into(), + },] + ), + ( + "cat_color".into(), + vec![NodeURI { + node_id: "964ebb9ddd6bb7db56e53c19e9ac34dfd08779a656295b01e70b5973adc61103" + .into(), + key: "input_2".into(), + },] + ), + ( + "is".into(), + vec![NodeURI { + node_id: "8e43dbc9fd55fa7d1a36fc4a6c036f4113b7aa7fcf38646a2f2472bac6774962" + .into(), + key: "input_2".into(), + },] + ), + ( + "cat".into(), + vec![NodeURI { + node_id: "04cb341a09eeb771846377405a5f33d011f99a7dfa4739fd7876a7e70c994e4e" + .into(), + key: "input_1".into(), + },] + ), + ( + "action".into(), + vec![NodeURI { + node_id: "04cb341a09eeb771846377405a5f33d011f99a7dfa4739fd7876a7e70c994e4e" + .into(), + key: "input_2".into(), + },] + ), + ]), + "Input spec did not match" + ); + + // Check if the output spec contain the correct node hashes + assert_eq!( + pipeline.output_spec, + HashMap::from([( + "output".into(), + NodeURI { + node_id: "6ec68cc43ea15472731a318584cc8792fb2ff93c96fed6f3f998849b75976694".into(), + key: "output".into(), + } + ),]), + "Output spec did not match" + ); + + Ok(()) +} + #[test] fn input_packet_checksum() -> Result<()> { let pipeline = Pipeline::new( @@ -26,10 +170,10 @@ fn input_packet_checksum() -> Result<()> { A } "}, - HashMap::from([( + &HashMap::from([( "A".into(), Kernel::Pod { - r#ref: pod_custom( + pod: pod_custom( "alpine:3.14", vec!["echo".into()], HashMap::from([( @@ -43,14 +187,15 @@ fn input_packet_checksum() -> Result<()> { .into(), }, )]), - &HashMap::from([( + HashMap::from([( "pipeline_key_1".into(), - vec![SpecURI { - node: "A".into(), + vec![NodeURI { + node_id: "A".into(), key: "node_key_1".into(), }], )]), - &HashMap::new(), + HashMap::new(), + None, )?; let pipeline_job = PipelineJob::new( @@ -66,7 +211,7 @@ fn input_packet_checksum() 
-> Result<()> { checksum: String::new(), }])], )]), - &URI { + URI { namespace: "default".into(), path: "output/pipeline".into(), }, @@ -83,5 +228,139 @@ fn input_packet_checksum() -> Result<()> { "8b44b8ea83b1f5eec3ac16cf941767e629896c465803fb69c21adbbf984516bd".to_owned(), "Incorrect checksum" ); + + Ok(()) +} + +/// Testing invalid conditions to make sure validation works +fn basic_pipeline_components() -> Result<( + String, + HashMap, + HashMap>, + HashMap, +)> { + let dot = indoc! {" + digraph { + A + } + "}; + + let metadata = HashMap::from([("A".into(), combine_txt_pod("A")?.into())]); + + let input_spec = HashMap::from([ + ( + "input_1".into(), + vec![NodeURI { + node_id: "A".into(), + key: "input_1".into(), + }], + ), + ( + "input_2".into(), + vec![NodeURI { + node_id: "A".into(), + key: "input_2".into(), + }], + ), + ]); + + let output_spec = HashMap::from([( + "output".into(), + NodeURI { + node_id: "A".into(), + key: "output".into(), + }, + )]); + + Ok((dot.to_owned(), metadata, input_spec, output_spec)) +} + +#[test] +fn invalid_input_spec() -> Result<()> { + let (dot, metadata, _, output_spec) = basic_pipeline_components()?; + + // Test invalid node reference in input_spec + assert!( + Pipeline::new( + &dot, + &metadata, + HashMap::from([( + "input_1".into(), + vec![NodeURI { + node_id: "B".into(), + key: "input_1".into(), + }], + )]), + output_spec.clone(), + None + ) + .is_err(), + "Pipeline creation should have failed due to invalid input_spec" + ); + + // Test invalid key reference in input_spec + assert!( + Pipeline::new( + &dot, + &metadata, + HashMap::from([( + "input_1".into(), + vec![NodeURI { + node_id: "A".into(), + key: "input_3".into(), + }], + )]), + output_spec, + None + ) + .is_err(), + "Pipeline creation should have failed due to invalid input_spec" + ); + + Ok(()) +} + +#[test] +fn invalid_output_spec() -> Result<()> { + let (dot, metadata, input_spec, _) = basic_pipeline_components()?; + + // Test invalid output_spec node reference 
+ assert!( + Pipeline::new( + &dot, + &metadata, + input_spec.clone(), + HashMap::from([( + "A".into(), + NodeURI { + node_id: "B".into(), + key: "output".into(), + } + )]), + None + ) + .is_err(), + "Pipeline creation should have failed due to invalid output_spec" + ); + + // Test invalid output_spec key reference + assert!( + Pipeline::new( + &dot, + &metadata, + input_spec, + HashMap::from([( + "A".into(), + NodeURI { + node_id: "A".into(), + key: "output_dne".into(), + } + )]), + None + ) + .is_err(), + "Pipeline creation should have failed due to invalid output_spec" + ); + Ok(()) } diff --git a/tests/pipeline_runner.rs b/tests/pipeline_runner.rs new file mode 100644 index 00000000..f8f47248 --- /dev/null +++ b/tests/pipeline_runner.rs @@ -0,0 +1,91 @@ +#![expect( + missing_docs, + clippy::panic_in_result_fn, + clippy::indexing_slicing, + clippy::unwrap_used, + reason = "OK in tests." +)] +pub mod fixture; + +// Example for a local module: +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; + +use crate::fixture::TestDirs; +use fixture::pipeline_job; +use orcapod::{ + core::pipeline_runner::DockerPipelineRunner, + uniffi::{ + error::Result, + model::pipeline::PipelineStatus, + orchestrator::{agent::Agent, docker::LocalDockerOrchestrator}, + }, +}; +use tokio::fs::read_to_string; + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn basic_run() -> Result<()> { + // Create the test_dir and get the namespace lookup + let test_dirs = TestDirs::new(&HashMap::from([( + "default".to_owned(), + Some("./tests/extra/data/"), + )]))?; + let namespace_lookup = test_dirs.namespace_lookup(); + + // Create and agent and start it (temporary for now, will be merge later) + let agent = Arc::new(Agent::new( + "test:basic_run".to_owned(), + "localhost".to_owned(), + Arc::new(LocalDockerOrchestrator::new().unwrap()), + )?); + + let agent_inner = Arc::clone(&agent); + let namespace_lookup_inner = namespace_lookup.clone(); + tokio::spawn(async move { + 
agent_inner + .start(&namespace_lookup_inner, None) + .await + .unwrap(); + }); + + let pipeline_job = pipeline_job(&namespace_lookup)?; + + // Create the runner + let mut runner = DockerPipelineRunner::new(agent); + + let pipeline_run = runner + .start(pipeline_job, "default", &namespace_lookup) + .await?; + + // Wait for the pipeline run to complete + let pipeline_result = runner.get_result(&pipeline_run).await?; + + // Check the output packet content + assert_eq!(pipeline_result.output_packets["output"].len(), 4); + + // Check the status + assert_eq!(pipeline_result.status, PipelineStatus::Succeeded); + + // Get all the output file content and read them in + let mut output_content = HashSet::new(); + + for output_packet in &pipeline_result.output_packets["output"] { + output_content + .insert(read_to_string(&output_packet.to_path_buf(&namespace_lookup)?[0]).await?); + } + + // Check if the output_content matches + assert_eq!( + output_content, + HashSet::from([ + "Where is the black cat playing\n".to_owned(), + "Where is the black cat hiding\n".to_owned(), + "Where is the tabby cat playing\n".to_owned(), + "Where is the tabby cat hiding\n".to_owned(), + ]) + ); + + Ok(()) +} diff --git a/tests/store.rs b/tests/store.rs index 3a6c66f8..4552c03e 100644 --- a/tests/store.rs +++ b/tests/store.rs @@ -3,6 +3,7 @@ missing_docs, clippy::panic_in_result_fn, clippy::indexing_slicing, + clippy::unwrap_used, reason = "OK in tests." 
)] @@ -10,16 +11,34 @@ pub mod fixture; use fixture::{ NAMESPACE_LOOKUP_READ_ONLY, TestDirs, TestSetup, pod_job_style, pod_result_style, pod_style, }; -use orcapod::{ - core::{crypto::hash_buffer, model::to_yaml}, - uniffi::{ - error::Result, - model::{Annotation, ModelType}, - store::{ModelID, ModelInfo, Store as _, filestore::LocalFileStore}, +use orcapod::uniffi::{ + error::Result, + model::{ + Annotation, ModelType, + packet::PathInfo, + pod::{Pod, RecommendSpecs}, }, + operator::MapOperator, + store::{ModelID, ModelInfo, Store as _, filestore::LocalFileStore}, }; -use pretty_assertions::assert_eq as pretty_assert_eq; -use std::{collections::HashMap, fmt::Debug, ops::Deref as _, path::Path, sync::Arc}; +use pretty_assertions::assert_eq; +use std::{ + collections::{HashMap, HashSet}, + fmt::Debug, + ops::Deref as _, + path::{Path, PathBuf}, + sync::Arc, + vec, +}; + +use crate::fixture::{pipeline, str_to_vec}; + +fn get_store_fixtures() -> (TestDirs, LocalFileStore) { + let test_dirs = TestDirs::new(&HashMap::from([("default".to_owned(), None::)])) + .expect("Failed to create test directories."); + let store = LocalFileStore::new(test_dirs.0["default"].path().to_path_buf()); + (test_dirs, store) +} fn is_dir_empty(file: &Path, levels_up: usize) -> Option { Some( @@ -33,11 +52,11 @@ fn is_dir_empty(file: &Path, levels_up: usize) -> Option { } fn basic_test(model: &T, expected_model: &T) -> Result<()> { - let test_dirs = TestDirs::new(&HashMap::from([("default".to_owned(), None::)]))?; - let store = LocalFileStore::new(test_dirs.0["default"].path().to_path_buf()); + let (_, store) = get_store_fixtures(); model.save(&store)?; let annotation = model.get_annotation().expect("Annotation missing."); - pretty_assert_eq!( + + assert_eq!( model.list(&store)?, vec![ ModelInfo { @@ -53,7 +72,7 @@ fn basic_test(model: &T, expected_model: &T) - ], "List didn't match." ); - pretty_assert_eq!( + assert_eq!( &model.load(&store)?, expected_model, "Loaded model doesn't match." 
@@ -74,6 +93,7 @@ fn pod_basic() -> Result<()> { fn pod_job_basic() -> Result<()> { let mut expected_model = pod_job_style(&NAMESPACE_LOOKUP_READ_ONLY)?; let mut pod = expected_model.pod.deref().clone(); + pod.annotation = None; expected_model.pod = Arc::new(pod); basic_test( @@ -101,19 +121,17 @@ fn pod_result_basic() -> Result<()> { #[test] fn pod_files() -> Result<()> { - let test_dirs = TestDirs::new(&HashMap::from([("default".to_owned(), None::)]))?; - let store = LocalFileStore::new(test_dirs.0["default"].path().to_path_buf()); + let (_, store) = get_store_fixtures(); let pod_style = pod_style()?; let annotation = pod_style .annotation .as_ref() .expect("Annotation missing from `pod_style`"); - let annotation_file = store.make_path( - &pod_style, + let annotation_file = store.make_path::( &pod_style.hash, LocalFileStore::make_annotation_relpath(&annotation.name, &annotation.version), ); - let spec_file = store.make_path(&pod_style, &pod_style.hash, LocalFileStore::SPEC_RELPATH); + let spec_file = store.make_path::(&pod_style.hash, LocalFileStore::SPEC_RELPATH); store.save_pod(&pod_style)?; assert!(spec_file.exists(), "Spec file missing."); @@ -139,20 +157,18 @@ fn pod_files() -> Result<()> { #[test] fn pod_list_empty() -> Result<()> { - let test_dirs = TestDirs::new(&HashMap::from([("default".to_owned(), None::)]))?; - let store = LocalFileStore::new(test_dirs.0["default"].path().to_path_buf()); + let (_, store) = get_store_fixtures(); assert_eq!(store.list_pod()?, vec![], "Pod list is not empty."); Ok(()) } #[test] fn pod_load_from_hash() -> Result<()> { - let test_dirs = TestDirs::new(&HashMap::from([("default".to_owned(), None::)]))?; - let store = LocalFileStore::new(test_dirs.0["default"].path().to_path_buf()); + let (_, store) = get_store_fixtures(); let mut pod = pod_style()?; store.save_pod(&pod)?; pod.annotation = None; - pretty_assert_eq!( + assert_eq!( store.load_pod(&ModelID::Hash(pod.hash.clone()))?, pod, "Loaded model from hash doesn't match." 
@@ -162,8 +178,7 @@ fn pod_load_from_hash() -> Result<()> { #[test] fn pod_annotation_delete() -> Result<()> { - let test_dirs = TestDirs::new(&HashMap::from([("default".to_owned(), None::)]))?; - let store = LocalFileStore::new(test_dirs.0["default"].path().to_path_buf()); + let (_, store) = get_store_fixtures(); let mut pod = pod_style()?; store.save_pod(&pod)?; let model_version = &pod.annotation.as_ref().map(|x| x.version.clone()); @@ -175,7 +190,7 @@ fn pod_annotation_delete() -> Result<()> { description: String::new(), }); store.save_pod(&pod)?; - pretty_assert_eq!( + assert_eq!( store.list_pod()?, vec![ ModelInfo { @@ -198,7 +213,7 @@ fn pod_annotation_delete() -> Result<()> { ); // case 2: delete new annotation, assert list gives 2 entries: hash, annotation (original). store.delete_annotation(&ModelType::Pod, "new-name", "0.5.0")?; - pretty_assert_eq!( + assert_eq!( store.list_pod()?, vec![ ModelInfo { @@ -222,7 +237,7 @@ fn pod_annotation_delete() -> Result<()> { .to_owned() .expect("Version missing from `pod_style`"), )?; - pretty_assert_eq!( + assert_eq!( store.list_pod()?, vec![ModelInfo { name: None, @@ -241,86 +256,170 @@ fn pod_annotation_delete() -> Result<()> { Ok(()) } +#[expect(clippy::too_many_lines, reason = "Okay because of creating pods")] #[test] fn pod_annotation_unique() -> Result<()> { - let test_dirs = TestDirs::new(&HashMap::from([("default".to_owned(), None::)]))?; - let store = LocalFileStore::new(test_dirs.0["default"].path().to_path_buf()); - let original_annotation = Annotation { - name: "example".to_owned(), + let (_, store) = get_store_fixtures(); + + // Pod values + let annotation = Annotation { + name: "style-transfer".to_owned(), + description: "This is an example pod.".to_owned(), version: "1.0.0".to_owned(), - description: "original".to_owned(), }; - let mut pod = pod_style()?; - pod.annotation = Some(original_annotation.clone()); - store.save_pod(&pod)?; - let original_hash = pod.hash.clone(); - // case 1: Only change 
description, should skip saving model and annotation - pod.annotation = Some(Annotation { - description: "new".to_owned(), - ..original_annotation.clone() - }); + let image = "example.server.com/user/style-transfer:1.0.0".to_owned(); + let command = str_to_vec("python /run.py"); + let input_spec = HashMap::from([( + "input_key_1".to_owned(), + PathInfo { + path: PathBuf::from("/input"), + match_pattern: "input/.*".to_owned(), + }, + )]); + let output_spec = HashMap::from([( + "output_key_1".to_owned(), + PathInfo { + path: PathBuf::from("/output"), + match_pattern: "output/.*".to_owned(), + }, + )]); + let output_dir: PathBuf = "/output".into(); + let exec_requirements = RecommendSpecs { + cpus: 0.25, + memory: 1_u64 << 30, + }; + + let gpu_requirements = None; + + let pod = Pod::new( + Some(annotation.clone()), + image.clone(), + command.clone(), + input_spec.clone(), + output_dir.clone(), + output_spec.clone(), + exec_requirements.clone(), + gpu_requirements.clone(), + )?; + + // Save pod above store.save_pod(&pod)?; - pretty_assert_eq!( - store.list_pod()?, - vec![ + // case 1: Only change description, should skip saving model and annotation since overriding annotation is not allowed + let pod_with_new_annotation = Pod::new( + Some(Annotation { + description: "new description".into(), + ..pod.annotation.as_ref().unwrap().clone() + }), + "example.server.com/user/style-transfer:1.0.0".to_owned(), + command, + input_spec.clone(), + "/output".into(), + output_spec.clone(), + exec_requirements.clone(), + gpu_requirements.clone(), + )?; + + store.save_pod(&pod_with_new_annotation)?; + assert_eq!( + HashSet::from_iter(store.list_pod()?), + HashSet::from([ ModelInfo { - name: Some(original_annotation.name.clone()), - version: Some(original_annotation.version.clone()), - hash: original_hash.clone(), + name: Some(annotation.name.clone()), + version: Some(annotation.version.clone()), + hash: pod_with_new_annotation.hash, }, ModelInfo { name: None, version: None, - hash: 
original_hash.clone(), + hash: pod.hash.clone(), }, - ], + ]), "Pod list didn't return 2 expected entries." ); - pretty_assert_eq!( + assert_eq!( store .load_pod(&ModelID::Annotation( - original_annotation.name.clone(), - original_annotation.version.clone() + annotation.name.clone(), + annotation.version.clone() ))? .annotation, - Some(original_annotation.clone()), + Some(annotation.clone()), "Pod annotation unexpected." ); // case 2: Change description + model, should save model but skip annotation - pod.output_dir = "/output_2".into(); - pod.hash = hash_buffer(to_yaml(&pod)?); - let new_hash = pod.hash.clone(); - store.save_pod(&pod)?; - pretty_assert_eq!( - store.list_pod()?, - vec![ + let pod_with_updated_command = Pod::new( + Some(annotation.clone()), + image, + str_to_vec("python new_run.py"), + input_spec, + output_dir, + output_spec, + exec_requirements, + gpu_requirements, + )?; + store.save_pod(&pod_with_updated_command)?; + assert_eq!( + HashSet::from_iter(store.list_pod()?), + HashSet::from([ ModelInfo { - name: Some(original_annotation.name.clone()), - version: Some(original_annotation.version.clone()), - hash: original_hash.clone(), + name: Some(annotation.name.clone()), + version: Some(annotation.version.clone()), + hash: pod.hash.clone(), }, ModelInfo { name: None, version: None, - hash: original_hash, + hash: pod.hash, }, ModelInfo { name: None, version: None, - hash: new_hash, + hash: pod_with_updated_command.hash, }, - ], + ]), "Pod list didn't return 3 expected entries." ); - pretty_assert_eq!( + assert_eq!( store .load_pod(&ModelID::Annotation( - original_annotation.name.clone(), - original_annotation.version.clone() + annotation.name.clone(), + annotation.version.clone() ))? .annotation, - Some(original_annotation), + Some(annotation), "Pod annotation unexpected." 
); Ok(()) } + +#[test] +fn map_operator_basic() -> Result<()> { + let map_operator = + MapOperator::new(HashMap::from([("input_key".into(), "output_key".into())]))?; + + let (_, store) = get_store_fixtures(); + + println!("Saved map operator with hash: {}", &map_operator.hash); + store.save_map_operator(&map_operator)?; + + assert!(store.list_map_operator()?.contains(&map_operator.hash)); + + // Load and compare + assert_eq!( + &store.load_map_operator(&map_operator.hash)?, + &map_operator, + "Loaded map operator doesn't match." + ); + + // Delete and assert not found + store.delete_map_operator(&map_operator.hash)?; + assert!(!store.list_map_operator()?.contains(&map_operator.hash)); + + Ok(()) +} + +#[test] +fn pipeline_basic() -> Result<()> { + let pipeline = pipeline()?; + basic_test(&pipeline, &pipeline) +}