diff --git a/assemblyline-models/src/datastore/tagging.rs b/assemblyline-models/src/datastore/tagging.rs index 1535faf..ee73894 100644 --- a/assemblyline-models/src/datastore/tagging.rs +++ b/assemblyline-models/src/datastore/tagging.rs @@ -46,7 +46,7 @@ impl From for TagValue { // MARK: Tag Processors #[derive(Debug)] -enum TagProcessor { +pub enum TagProcessor { // Generic strings String, Uppercase, @@ -935,7 +935,25 @@ fn network_tag_parsing() { assert_eq!(proc.apply(json!("www.google.com")), Err(json!("www.google.com"))); assert_eq!(proc.apply(json!("www.GooGle.com")), Err(json!("www.GooGle.com"))); assert_eq!(proc.apply(json!("172.0.0.1")), Ok(json!("172.0.0.1"))); + assert_eq!(proc.apply(json!("0.0.0.0")), Ok(json!("0.0.0.0"))); + assert_eq!(proc.apply(json!("127.0.0.0")), Ok(json!("127.0.0.0"))); + assert_eq!(proc.apply(json!("127.0.10.200")), Ok(json!("127.0.10.200"))); + assert_eq!(proc.apply(json!("255.255.255.255")), Ok(json!("255.255.255.255"))); assert_eq!(proc.apply(json!("1234:5678:9ABC:0000:0000:1234:5678:9abc")), Ok(json!("1234:5678:9ABC:0000:0000:1234:5678:9ABC"))); + // we don't want abbrivated, octal, hex, or padded ip addresses + assert_eq!(proc.apply(json!("172.1")), Err(json!("172.1"))); + assert_eq!(proc.apply(json!("256.0.0.1")), Err(json!("256.0.0.1"))); + assert_eq!(proc.apply(json!("0.256.0.1")), Err(json!("0.256.0.1"))); + assert_eq!(proc.apply(json!("0.0.256.1")), Err(json!("0.0.256.1"))); + assert_eq!(proc.apply(json!("0.0.0.256")), Err(json!("0.0.0.256"))); + assert_eq!(proc.apply(json!("172.0x1.0.1")), Err(json!("172.0x1.0.1"))); + assert_eq!(proc.apply(json!("172.01.0.1")), Err(json!("172.01.0.1"))); + assert_eq!(proc.apply(json!("172.1.0.00000000001")), Err(json!("172.1.0.00000000001"))); + assert_eq!(proc.apply(json!("0.0.0.")), Err(json!("0.0.0."))); + assert_eq!(proc.apply(json!("0.0.0.0.")), Err(json!("0.0.0.0."))); + assert_eq!(proc.apply(json!(".0.0.0")), Err(json!(".0.0.0"))); + assert_eq!(proc.apply(json!(".0.0.0.0")), Err(json!(".0.0.0.0"))); + let proc = TagProcessor::UNCPath; assert_eq!(proc.apply(json!("www.google.com")), Err(json!("www.google.com"))); diff --git a/assemblyline-models/src/types/strings.rs b/assemblyline-models/src/types/strings.rs index 9a33921..3db7dee 100644 --- a/assemblyline-models/src/types/strings.rs +++ b/assemblyline-models/src/types/strings.rs @@ -40,7 +40,7 @@ impl std::ops::Deref for ServiceName { impl<'de> Deserialize<'de> for ServiceName { fn deserialize(deserializer: D) -> Result where - D: serde::Deserializer<'de> + D: serde::Deserializer<'de> { struct Visitor {} @@ -81,17 +81,17 @@ impl From<&str> for ServiceName { } /// A string that maps to a keyword field in elasticsearch. -/// +/// /// This is the default behaviour for a String in a mapped struct, the only reason /// to use this over a standard String is cases where the 'mapping' field has been overwritten /// by a container and the more explicit 'mapping' this provided is needed to reassert /// the keyword type. -/// +/// /// Example: /// #[metadata(store=false, mapping="flattenedobject")] /// pub safelisted_tags: HashMap>, -/// -/// In that example, if the inner Keyword was String the entire HashMap would have its +/// +/// In that example, if the inner Keyword was String the entire HashMap would have its /// mapping set to 'flattenedobject', the inner Keyword more explicitly overrides this. #[derive(Debug, Serialize, Deserialize, Described, Clone, PartialEq, Eq, PartialOrd, Ord)] #[metadata_type(ElasticMeta)] @@ -238,8 +238,8 @@ impl Text { #[derive(Debug, thiserror::Error)] #[error("Could not process {original} as a {name}: {error}")] pub struct ValidationError { - original: String, - name: &'static str, + original: String, + name: &'static str, error: String } @@ -371,11 +371,11 @@ pub fn check_domain(data: &str) -> Result { return Err(DomainError::IlligalCharacter) // raise ValueError(f"[{self.name or self.parent_name}] '{segment}' in '{value}' " // f"includes a Unicode character that can not be normalized to '{segment_norm}'.") - } + } normalized_parts.push(segment_norm); } } - + let mut domain = normalized_parts.join("."); static VALIDATION_REGEX: LazyLock = LazyLock::new(||{ @@ -391,7 +391,7 @@ pub fn check_domain(data: &str) -> Result { } if domain.contains("@") { - return Err(DomainError::IlligalCharacter) + return Err(DomainError::IlligalCharacter) } if let Some((_, tld)) = domain.rsplit_once(".") { @@ -411,7 +411,7 @@ pub fn check_domain(data: &str) -> Result { } Err(DomainError::InvalidTLD) -} +} // def is_valid_domain(domain: str) -> bool: // if "@" in domain: @@ -461,14 +461,14 @@ fn system_local_tld() -> Vec { fn find_top_level_domains() -> &'static HashSet { static TLDS: LazyLock> = LazyLock::new(|| { use super::net_static::TLDS_ALPHA_BY_DOMAIN; - let mut combined_tlds = HashSet::::new(); + let mut combined_tlds = HashSet::::new(); combined_tlds.extend(TLDS_ALPHA_BY_DOMAIN.iter().map(|s|s.to_string())); for d in TLDS_SPECIAL_BY_DOMAIN { if !d.contains(".") { combined_tlds.insert(d.to_owned()); } - } + } for tld in system_local_tld() { let tld = tld.trim_matches('.').to_uppercase(); @@ -497,11 +497,11 @@ pub type Domain = ValidatedString; #[test] fn internationalized_domains() { - assert_eq!(check_domain("ουτοπία.δπθ.gr").unwrap(), "ουτοπία.δπθ.gr"); + assert_eq!(check_domain("ουτοπία.δπθ.gr").unwrap(), "ουτοπία.δπθ.gr"); assert_eq!(check_domain("xn--kxae4bafwg.xn--pxaix.gr").unwrap(), "ουτοπία.δπθ.gr"); - assert_eq!(check_domain("site.XN--W4RS40L").unwrap(), "site.嘉里"); - assert!(check_domain("ουτοπία.δπθ.g").is_err()); - assert!(check_domain("ουτοπία..gr").is_err()); + assert_eq!(check_domain("site.XN--W4RS40L").unwrap(), "site.嘉里"); + assert!(check_domain("ουτοπία.δπθ.g").is_err()); + assert!(check_domain("ουτοπία..gr").is_err()); assert!(check_domain("xn--kxae4bafwg.xn--pxaix.g").is_err()); assert!(check_domain("xn--kxae4bafwg.xn--xaix.gr").is_err()); } @@ -596,8 +596,8 @@ pub type Uri = ValidatedString; // pub type UriPath = String; // MARK: IP +const IPV4_REGEX: &str = r"(?:(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])\.){3}(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])"; -const IPV4_REGEX: &str = r"(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)"; const IPV6_REGEX: &str = concat!( r"(?:(?:[0-9a-fA-F]{1,4}:){7,7}[0-9a-fA-F]{1,4}|(?:[0-9a-fA-F]{1,4}:){1,7}:|", r"(?:[0-9a-fA-F]{1,4}:){1,6}:[0-9a-fA-F]{1,4}|(?:[0-9a-fA-F]{1,4}:){1,5}(?::[0-9a-fA-F]{1,4}){1,2}|", diff --git a/assemblyline-server/src/constants.rs b/assemblyline-server/src/constants.rs index 1e30fa6..3ee82c4 100644 --- a/assemblyline-server/src/constants.rs +++ b/assemblyline-server/src/constants.rs @@ -19,6 +19,7 @@ pub(crate) const DISPATCH_TASK_HASH: &str = "dispatch-active-submissions"; pub(crate) const SCALER_TIMEOUT_QUEUE: &str = "scaler-timeout-queue"; pub(crate) const SERVICE_STATE_HASH: &str = "service-stasis-table"; pub(crate) const SERVICE_QUEUE_PREFIX: &str = "service-queue-"; +pub(crate) const SERVICE_API_KEY_HASH: &str = "dynamic-service-keys"; /// Take the name of a service, and provide the queue name to send tasks to that service. pub fn service_queue_name(service: &str) -> String { @@ -130,3 +131,12 @@ pub const SERVICE_STAGE_KEY: &str = "service-stage"; // 'critical': 500, // 'high': 100, // } + + +/// An entry stored in redis in the SERVICE_API_KEY_HASH hash +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ServiceApiKeyConfig { + pub key: String, + pub allow_registry_writing: bool, + pub expiry: chrono::DateTime +} \ No newline at end of file diff --git a/assemblyline-server/src/identify/default.magic b/assemblyline-server/src/identify/default.magic index 1600239..01dc906 100644 --- a/assemblyline-server/src/identify/default.magic +++ b/assemblyline-server/src/identify/default.magic @@ -24,6 +24,10 @@ # Open XML files with Microsoft Word 0 string >0 search/0x100 = custom: document/office/word +# MSBuild Project Files +0 string +>0 search/0x40 \>&0 search/0x40 http://schemas.microsoft.com/developer/msbuild custom: code/xml/msbuild # VBE files 0 string #@~^ >&0 regex/9 \^[^=]{6}== custom: code/vbe diff --git a/assemblyline-server/src/postprocessing/parsing.rs b/assemblyline-server/src/postprocessing/parsing.rs index 46f37f4..f9d4ded 100644 --- a/assemblyline-server/src/postprocessing/parsing.rs +++ b/assemblyline-server/src/postprocessing/parsing.rs @@ -6,12 +6,12 @@ use chrono::{DateTime, Utc, Duration, NaiveDate, NaiveDateTime, NaiveTime}; use nom::{IResult, Parser}; use nom::branch::alt; use nom::bytes::complete::{tag, take_while, escaped_transform, is_not, take_while1, tag_no_case, is_a}; -use nom::character::complete::{multispace0, alphanumeric1, one_of}; -use nom::combinator::{map, map_opt, map_res, opt, value}; +use nom::character::complete::{alphanumeric1, multispace0, multispace1, one_of}; +use nom::combinator::{eof, map, map_opt, map_res, opt, peek, value}; use nom::error::ParseError; -use nom::multi::{separated_list1, count, many1}; +use nom::multi::{count, many_till, many1, separated_list1}; use nom::number::complete::double; -use nom::sequence::{delimited, pair}; +use nom::sequence::{delimited, pair, terminated}; use super::search::{Query, PrefixOperator, StringQuery, FieldQuery, RangeBound, RangeTerm, RangeQuery, DateExpression, DateUnit, NumberQuery}; use super::ParsingError; @@ -86,9 +86,9 @@ fn not_operator(input: &str) -> IResult<&str, ()> { fn atom(input: &str) -> IResult<&str, Query> { // println!("atom: {input}"); alt(( - delimited(ws(tag("(")), expression, ws(tag(")"))), + delimited(ws(tag("(")), expression, ws(tag(")"))), exists, - field, + field, term )).parse(input) } @@ -133,7 +133,7 @@ fn prefix_operator(input: &str) -> IResult<&str, PrefixOperator> { fn is_special(value: char) -> bool { matches!(value, '_' | '-') } -fn simple_term(input: &str) -> IResult<&str, String> { +fn primitive_simple_term(input: &str) -> IResult<&str, String> { // println!("simple_term: {input}"); map_res(escaped_transform( alt((alphanumeric1, take_while1(is_special))), @@ -170,18 +170,33 @@ fn simple_term(input: &str) -> IResult<&str, String> { }).parse(input) } +fn simple_term(input: &str) -> IResult<&str, String> { + // println!("simple_term: {input}"); + terminated(primitive_simple_term, end_term).parse(input) +} + fn pattern_term(input: &str) -> IResult<&str, regex::Regex> { - // println!("pattern_term: {input}"); - map_res(many1(alt(( - map(tag("*"), |_|{String::from(".*")}), - map(tag("?"), |_|{String::from(".")}), - map(simple_term, |row|{regex::escape(&row)}), - ))), |parts|{ + map_res(many_till( + alt(( + map(tag("*"), |_|{String::from(".*")}), + map(tag("?"), |_|{String::from(".")}), + map(primitive_simple_term, |row|{regex::escape(&row)}), + )), + end_term + ), |(parts, _)|{ let pattern = parts.join(""); regex::Regex::new(&pattern) }).parse(input) } +fn end_term(input: &str) -> IResult<&str, ()> { + alt(( + map(peek(tag(")")), |_| ()), + map(eof, |_|()), + map(multispace1, |_| ()), + )).parse(input) +} + // phrase_term: ESCAPED_STRING fn phrase_term(input: &str) -> IResult<&str, String> { quoted_string(input) diff --git a/assemblyline-server/src/postprocessing/search.rs b/assemblyline-server/src/postprocessing/search.rs index 29cdb30..c24749c 100644 --- a/assemblyline-server/src/postprocessing/search.rs +++ b/assemblyline-server/src/postprocessing/search.rs @@ -52,7 +52,7 @@ fn get_full_text_field_names() -> &'static Vec> { struct_metadata::Kind::Mapping(_, _) => {}, _ => {} } - break + break } } @@ -673,7 +673,7 @@ fn check_field_type(_root: &str, tail: &[String], kind: &struct_metadata::Kind false, 1 => true, _ => check_field_type(&tail[1], &tail[2..], &inner.kind) - } + } }, // Recurse into the inner type for these @@ -683,7 +683,7 @@ fn check_field_type(_root: &str, tail: &[String], kind: &struct_metadata::Kind { tail.is_empty() diff --git a/assemblyline-server/src/postprocessing/tests.rs b/assemblyline-server/src/postprocessing/tests.rs index 7669051..d3454a8 100644 --- a/assemblyline-server/src/postprocessing/tests.rs +++ b/assemblyline-server/src/postprocessing/tests.rs @@ -50,6 +50,9 @@ fn test_simple_filters() { assert!(!fltr.test(&sub).unwrap()); // Try a prefix operator and wildcard matches + let _ = parse("max_score: >100 AND NOT results: virus").unwrap(); + let _ = parse("max_score: >100 AND NOT results: virus*").unwrap(); + let _ = parse("max_score: >100 AND NOT results: *virus").unwrap(); let fltr = parse("max_score: >100 AND NOT results: *virus*").unwrap(); assert_eq!(fltr.cache_safe(), CacheAvailabilityStatus::ErrorUsesForbiddenFields(vec!["results".to_owned()])); @@ -96,6 +99,10 @@ fn test_simple_filters() { assert_eq!(fltr.cache_safe(), CacheAvailabilityStatus::Ok); assert!(fltr.test(&sub).unwrap()); + let fltr = parse("metadata.stuff: big-*").unwrap(); + assert_eq!(fltr.cache_safe(), CacheAvailabilityStatus::Ok); + assert!(fltr.test(&sub).unwrap()); + let fltr = parse("metadata.stuff: big\\-bad").unwrap(); assert_eq!(fltr.cache_safe(), CacheAvailabilityStatus::Ok); assert!(fltr.test(&sub).unwrap()); @@ -299,7 +306,7 @@ async fn test_hook() { use assemblyline_models::datastore::submission::Submission; let (port, mut hits) = run_server().await; - + let action = PostprocessAction{ enabled: true, run_on_completed: true, @@ -307,7 +314,7 @@ async fn test_hook() { webhook: Some(Webhook { uri: format!("http://localhost:{port}"), headers: vec![NamedValue{ - name: "care-of".to_string(), + name: "care-of".to_string(), value: "assemblyline".to_string() }], password: None, @@ -338,14 +345,14 @@ async fn test_hook() { { let mut sub: Submission = rand::rng().random(); sub.metadata.insert("ok".to_string(), "bad".into()); - worker.process(&sub, Default::default(), false).await.unwrap(); + worker.process(&sub, Default::default(), false).await.unwrap(); } { let mut sub: Submission = rand::rng().random(); sub.metadata.insert("ok".to_string(), "good".into()); sub.metadata.insert("do_hello".to_string(), "yes".into()); - worker.process(&sub, Default::default(), false).await.unwrap(); + worker.process(&sub, Default::default(), false).await.unwrap(); } let (headers, body) = tokio::time::timeout(std::time::Duration::from_secs(3), hits.recv()).await.unwrap().unwrap(); @@ -360,11 +367,11 @@ async fn test_hook() { #[test] fn test_webhook_match() { let webhook_first = json!({ - "uri": "http://api.interface.website" + "uri": "http://api.interface.website" }); let webhook_second = json!({ - "uri": "http://api.interface.website", + "uri": "http://api.interface.website", "headers": [{ "name": "APIKEY", "value": "1111111111111111111111111111111111111111111111111111111111111111" diff --git a/assemblyline-server/src/service_api/helpers/auth.rs b/assemblyline-server/src/service_api/helpers/auth.rs index 27f72cc..6a8e7f9 100644 --- a/assemblyline-server/src/service_api/helpers/auth.rs +++ b/assemblyline-server/src/service_api/helpers/auth.rs @@ -1,26 +1,98 @@ +use std::collections::HashMap; use std::str::FromStr; use std::sync::Arc; +use std::time::{Duration, Instant}; use assemblyline_models::types::ServiceName; use itertools::Itertools; -use log::{debug, warn}; +use log::{debug, error, warn}; +use parking_lot::Mutex; use poem::http::HeaderName; use poem::IntoResponse; use poem::{Endpoint, Middleware, Request, Response, Result, http::StatusCode}; +use redis_objects::{Hashmap, RedisObjects}; use crate::Core; +use crate::constants::{SERVICE_API_KEY_HASH, ServiceApiKeyConfig}; use super::make_empty_api_error; +/// Don't re-check an api key within 30 seconds. +const KEY_CACHE_TIMEOUT: Duration = Duration::from_secs(30); +#[derive(Clone)] +pub struct ApiKeyLoader { + /// Redis map containing runtime configured service api keys + table: Hashmap, + /// An in-memory cache for api keys to reduce remote calls + cache: Arc>>, +} + +impl ApiKeyLoader { + pub fn new(redis_persistant: Arc) -> Self { + let table = redis_persistant.hashmap(SERVICE_API_KEY_HASH.to_owned(), None); + + // Each time a service server starts clear out the api key table + // this doesn't need to be done often + let cleanup_table = table.clone(); + tokio::spawn(async move { + if let Err(err) = Self::cleanup(cleanup_table).await { + error!("Crash in service api key cleanup: {err:?}"); + } + }); + + Self { + table, + cache: Arc::new(Mutex::new(Default::default())), + } + } + + async fn cleanup(table: Hashmap) -> anyhow::Result<()> { + let existing = table.items().await?; + for (key, config) in existing { + if config.expiry < chrono::Utc::now() { + table.pop(&key).await?; + } + } + Ok(()) + } + pub async fn check_key(&self, key: &str) -> anyhow::Result> { + { + let mut cache = self.cache.lock(); + cache.retain(|_, (created, config)| { + created.elapsed() < KEY_CACHE_TIMEOUT || config.expiry > chrono::Utc::now() + }); + if let Some((_, config)) = cache.get(key) { + return Ok(Some(config.clone())) + } + } + + if let Some(row) = self.table.get(key).await? { + if row.expiry > chrono::Utc::now() { + let mut cache = self.cache.lock(); + cache.insert(key.to_owned(), (Instant::now(), row.clone())); + return Ok(Some(row)) + } else { + self.table.pop(&row.key).await?; + } + } + Ok(None) + } +} + +#[derive(Clone)] pub struct ServiceAuth { + /// Reference to the system asset pool core: Arc, + /// A staticly defined api key that worker services use when contacting the api auth_key: String, + /// A caching wrapper around the redis store for dynamic api keys + key_loader: ApiKeyLoader, } impl ServiceAuth { - pub fn new(core: Arc) -> Self { + pub fn new(core: Arc, key_loader: ApiKeyLoader) -> Self { let auth_key = match std::env::var("SERVICE_API_KEY"){ Ok(key) => key, Err(_) => { @@ -31,6 +103,7 @@ impl ServiceAuth { Self { core, auth_key, + key_loader, } } } @@ -42,6 +115,7 @@ impl Middleware for ServiceAuth { ServiceAuthImpl{ core: self.core.clone(), auth_key: self.auth_key.clone(), + loader: self.key_loader.clone(), endpoint: ep } } @@ -50,6 +124,7 @@ impl Middleware for ServiceAuth { pub struct ServiceAuthImpl { core: Arc, auth_key: String, + loader: ApiKeyLoader, endpoint: E, } @@ -78,25 +153,29 @@ impl Endpoint for ServiceAuthImpl { None => return Err(make_empty_api_error(StatusCode::BAD_REQUEST, "missing required key X-APIKEY")), }; - if self.auth_key != apikey { + let key_info = if self.auth_key == apikey { + ServiceApiKeyConfig { key: self.auth_key.clone(), allow_registry_writing: false, expiry: chrono::Utc::now() + chrono::TimeDelta::hours(1) } + } else if let Some(key_info) = self.loader.check_key(apikey).await? { + key_info + } else { let client_id = req.header("CONTAINER-ID").unwrap_or("Unknown Client"); let header_dump = req.headers().iter().map(|(k, v)| format!("{k}={v:?}")).join("; "); warn!("Client [{client_id}] provided wrong api key [{apikey}] headers: {header_dump}"); return Err(make_empty_api_error(StatusCode::UNAUTHORIZED, "Unauthorized access denied")); - } + }; let client_info = match ClientInfo::new(&req) { Ok(info) => info, Err(key) => { let client_id = req.header("CONTAINER-ID").unwrap_or("Unknown Client"); let header_dump = req.headers().iter().map(|(k, v)| format!("{k}={v:?}")).join("; "); - debug!("Client [{client_id}] missing required header [{key}] headers: {header_dump}"); + debug!("Client [{client_id}] missing required header [{key}] headers: {header_dump}"); return Err(make_empty_api_error(StatusCode::BAD_REQUEST, &format!("missing required key {key}"))) }, }; req.extensions_mut().insert(client_info); req.extensions_mut().insert(self.core.clone()); - + req.extensions_mut().insert(key_info); // if config.core.metrics.apm_server.server_url is not None { // elasticapm.set_user_context(username=client_info['service_name']) diff --git a/assemblyline-server/src/service_api/helpers/tasking.rs b/assemblyline-server/src/service_api/helpers/tasking.rs index c10b2cc..9f883b8 100644 --- a/assemblyline-server/src/service_api/helpers/tasking.rs +++ b/assemblyline-server/src/service_api/helpers/tasking.rs @@ -204,10 +204,35 @@ impl TaskingClient { Ok(()) } - pub async fn register_service(&self, mut service_data: JsonMap, log_prefix: &str) -> Result { - debug!("Registring service: {:?}", service_data.get("name")); - let mut keep_alive = true; + pub async fn compare_service_info(&self, service_data: JsonMap, log_prefix: &str) -> Result { + // normalize the service_data the same as we would during registration + let (service, _heuristics) = self.normalize_service_info(service_data)?; + + // get the currently running configuration for the service + let service_config = match self.datastore.get_service_with_delta(&service.name, None).await? { + Some(config) => config, + None => { + error!("{log_prefix}Service attempted to register without write permissions: {}/{}", + service.name, service.version); + return Err(RegisterError::Permission) + } + }; + + // Validate that this call is for the right service version + if service_config.version != service.version { + error!("{log_prefix}Service [{}] attempted to register with the wrong version: {} (expected {})", + service.name, service.version, service_config.version); + return Err(RegisterError::Permission) + } + Ok(RegisterResponse{ + keep_alive: true, + new_heuristics: vec![], + service_config + }) + } + + pub fn normalize_service_info(&self, mut service_data: JsonMap) -> Result<(Service, Option), RegisterError> { // Initialize the classification strings if !service_data.contains_key("classification") { service_data.insert("classification".to_string(), json!(self.classification_engine.unrestricted())); @@ -257,6 +282,15 @@ impl TaskingClient { // Fix service version, we don't need to see the stable label service.version = service.version.replace("stable", ""); + Ok((service, heuristics)) + } + + pub async fn register_service(&self, service_data: JsonMap, log_prefix: &str) -> Result { + debug!("Registring service: {:?}", service_data.get("name")); + let mut keep_alive = true; + + let (service, heuristics) = self.normalize_service_info(service_data)?; + // Save service if it doesn't already exist let key = format!("{}_{}", service.name, service.version); debug!("Registering service: storing version manifest"); @@ -448,6 +482,8 @@ fn fix_docker_config(docker_config: &mut JsonMap, registry_type: &Value) -> serd pub enum RegisterError { #[error("Could not complete json coversion: {0}")] Formatting(String), + #[error("The operation requested required permissions this api key does not have")] + Permission, #[error("{0}")] BadHeuristic(String), #[error("Service was removed during registration.")] @@ -460,11 +496,16 @@ impl RegisterError { pub fn is_input_error(&self) -> bool { match self { RegisterError::Formatting(_) => true, + RegisterError::Permission => false, RegisterError::BadHeuristic(_) => true, RegisterError::ServiceRemoved => false, RegisterError::Other(_) => false, } } + + pub fn is_permission(&self) -> bool { + matches!(self, RegisterError::Permission) + } } impl From for RegisterError { diff --git a/assemblyline-server/src/service_api/mod.rs b/assemblyline-server/src/service_api/mod.rs index fe95256..3a87b77 100644 --- a/assemblyline-server/src/service_api/mod.rs +++ b/assemblyline-server/src/service_api/mod.rs @@ -7,6 +7,7 @@ use poem::{Endpoint, EndpointExt, Route, Server}; use crate::logging::LoggerMiddleware; use crate::Core; +use crate::service_api::helpers::auth::{ApiKeyLoader, ServiceAuth}; pub mod helpers; pub mod v1; @@ -15,14 +16,16 @@ pub (crate) mod tests; pub async fn api(core: Arc) -> Result { let tasking_client = Arc::new(TaskingClient::new(&core).await?); + let keys = ApiKeyLoader::new(core.redis_persistant.clone()); + let auth = ServiceAuth::new(core.clone(), keys); Ok( Route::new() - .nest("/api/v1/badlist", v1::badlist::api(core.clone())) - .nest("/api/v1/file", v1::file::api(core.clone())) - .nest("/api/v1/service", v1::service::api(core.clone())) - .nest("/api/v1/task", v1::task::api(core.clone())) - .nest("/api/v1/safelist", v1::safelist::api(core.clone())) + .nest("/api/v1/badlist", v1::badlist::api(core.clone(), auth.clone())) + .nest("/api/v1/file", v1::file::api(auth.clone())) + .nest("/api/v1/service", v1::service::api(auth.clone())) + .nest("/api/v1/task", v1::task::api(auth.clone())) + .nest("/api/v1/safelist", v1::safelist::api(core.clone(), auth)) .nest("/healthz", v1::health::api(core.clone())) .data(tasking_client) .with(LoggerMiddleware) diff --git a/assemblyline-server/src/service_api/tests/service.rs b/assemblyline-server/src/service_api/tests/service.rs index 8fb0b45..4a4345a 100644 --- a/assemblyline-server/src/service_api/tests/service.rs +++ b/assemblyline-server/src/service_api/tests/service.rs @@ -4,6 +4,7 @@ use assemblyline_models::datastore::heuristic::Heuristic; use assemblyline_models::types::ExpandingClassification; use reqwest::header::{HeaderMap, HeaderValue}; +use crate::constants::{SERVICE_API_KEY_HASH, ServiceApiKeyConfig}; use crate::service_api::helpers::APIResponse; use crate::service_api::tests::{build_service, empty_delta}; use crate::service_api::v1::service::RegisterResponse; @@ -13,7 +14,7 @@ use super::{setup, AUTH_KEY, random_hash}; fn headers() -> HeaderMap { [ ("Container-Id", random_hash(12)), - ("X-APIKey", AUTH_KEY.to_owned()), + ("X-Apikey", AUTH_KEY.to_owned()), ("Service-Tool-Version", random_hash(64)), ("X-Forwarded-For", "127.0.0.1".to_owned()), ].into_iter() @@ -42,16 +43,39 @@ async fn test_register_service_auth_fail() { let service = build_service(); let service_delta = empty_delta(&service); + + // try connecting using the pre configured (read only) api key against a service that doesn't exist + { + let mut headers = headers(); + headers.insert("Service-Name", HeaderValue::from_str(&service.name).unwrap()); + headers.insert("Service-Version", HeaderValue::from_str(&service.version).unwrap()); + + let result = client.post(format!("{address}/api/v1/service/register/")).headers(headers).json(&service).send().await.unwrap(); + assert_eq!(result.status().as_u16(), 401); + } + + // register the service via backend, and see the same query succeed core.datastore.service.save(&service.key(), &service, None, None).await.unwrap(); core.datastore.service_delta.save(&service.name, &service_delta, None, None).await.unwrap(); - - let mut headers = headers(); - headers.insert("X-APIKEY", HeaderValue::from_static("10")); - headers.insert("Service-Name", HeaderValue::from_str(&service.name).unwrap()); - headers.insert("Service-Version", HeaderValue::from_str(&service.version).unwrap()); + { + let mut headers = headers(); + headers.insert("Service-Name", HeaderValue::from_str(&service.name).unwrap()); + headers.insert("Service-Version", HeaderValue::from_str(&service.version).unwrap()); - let result = client.post(format!("{address}/api/v1/service/register/")).headers(headers).json(&service).send().await.unwrap(); - assert_eq!(result.status().as_u16(), 401); + let result = client.post(format!("{address}/api/v1/service/register/")).headers(headers).json(&service).send().await.unwrap(); + assert_eq!(result.status().as_u16(), 200); + } + + // try connecting using an obviously wrong API key to repeat that query that just succeeded + { + let mut headers = headers(); + headers.insert("X-Apikey", HeaderValue::from_static("10")); + headers.insert("Service-Name", HeaderValue::from_str(&service.name).unwrap()); + headers.insert("Service-Version", HeaderValue::from_str(&service.version).unwrap()); + + let result = client.post(format!("{address}/api/v1/service/register/")).headers(headers).json(&service).send().await.unwrap(); + assert_eq!(result.status().as_u16(), 401); + } } #[tokio::test] @@ -83,7 +107,7 @@ async fn test_register_bad_service() { let (client, _core, _guard, address) = setup(headers()).await; let mut service = build_service(); - + let mut headers = headers(); headers.insert("Service-Name", HeaderValue::from_str(&service.name).unwrap()); headers.insert("Service-Version", HeaderValue::from_str(&service.version).unwrap()); @@ -98,10 +122,18 @@ async fn test_register_bad_service() { async fn test_register_new_service() { let (client, core, _guard, address) = setup(headers()).await; + let key = rand::random::().to_string(); + core.redis_persistant.hashmap(SERVICE_API_KEY_HASH.to_owned(), None).add(&key, &ServiceApiKeyConfig { + key: key.clone(), + allow_registry_writing: true, + expiry: chrono::Utc::now() + chrono::TimeDelta::hours(1) + }).await.unwrap(); + let service = build_service(); let service_delta = empty_delta(&service); let mut headers = headers(); + headers.insert("X-Apikey", HeaderValue::from_str(&key).unwrap()); headers.insert("Service-Name", HeaderValue::from_str(&service.name).unwrap()); headers.insert("Service-Version", HeaderValue::from_str(&service.version).unwrap()); @@ -124,6 +156,13 @@ async fn test_register_new_service() { async fn test_register_new_service_version() { let (client, core, _guard, address) = setup(headers()).await; + let key = rand::random::().to_string(); + core.redis_persistant.hashmap(SERVICE_API_KEY_HASH.to_owned(), None).add(&key, &ServiceApiKeyConfig { + key: key.clone(), + allow_registry_writing: true, + expiry: chrono::Utc::now() + chrono::TimeDelta::hours(1) + }).await.unwrap(); + let mut service = build_service(); let service_delta = empty_delta(&service); @@ -134,6 +173,7 @@ async fn test_register_new_service_version() { service.version = "101".to_string(); let mut headers = headers(); + headers.insert("X-Apikey", HeaderValue::from_str(&key).unwrap()); headers.insert("Service-Name", HeaderValue::from_str(&service.name).unwrap()); headers.insert("Service-Version", HeaderValue::from_str(&service.version).unwrap()); @@ -157,6 +197,13 @@ async fn test_register_new_service_version() { async fn test_register_new_heuristics() { let (client, core, _guard, address) = setup(headers()).await; + let key = rand::random::().to_string(); + core.redis_persistant.hashmap(SERVICE_API_KEY_HASH.to_owned(), None).add(&key, &ServiceApiKeyConfig { + key: key.clone(), + allow_registry_writing: true, + expiry: chrono::Utc::now() + chrono::TimeDelta::hours(1) + }).await.unwrap(); + let service = build_service(); let service_delta = empty_delta(&service); @@ -171,6 +218,7 @@ async fn test_register_new_heuristics() { let heur_id = format!("{}.{}", service.name.to_uppercase(), new_heuristic.heur_id); let mut headers = headers(); + headers.insert("X-Apikey", HeaderValue::from_str(&key).unwrap()); headers.insert("Service-Name", HeaderValue::from_str(&service.name).unwrap()); headers.insert("Service-Version", HeaderValue::from_str(&service.version).unwrap()); @@ -193,6 +241,13 @@ async fn test_register_new_heuristics() { async fn test_register_existing_heuristics() { let (client, core, _guard, address) = setup(headers()).await; + let key = rand::random::().to_string(); + core.redis_persistant.hashmap(SERVICE_API_KEY_HASH.to_owned(), None).add(&key, &ServiceApiKeyConfig { + key: key.clone(), + allow_registry_writing: true, + expiry: chrono::Utc::now() + chrono::TimeDelta::hours(1) + }).await.unwrap(); + let service = build_service(); let service_delta = empty_delta(&service); @@ -213,6 +268,7 @@ async fn test_register_existing_heuristics() { new_heuristic.heur_id = base_id; let mut headers = headers(); + headers.insert("X-Apikey", HeaderValue::from_str(&key).unwrap()); headers.insert("Service-Name", HeaderValue::from_str(&service.name).unwrap()); headers.insert("Service-Version", HeaderValue::from_str(&service.version).unwrap()); @@ -234,6 +290,14 @@ async fn test_register_existing_heuristics() { async fn test_register_bad_heuristics() { let (client, core, _guard, address) = setup(headers()).await; + let key = rand::random::().to_string(); + core.redis_persistant.hashmap(SERVICE_API_KEY_HASH.to_owned(), None).add(&key, &ServiceApiKeyConfig { + key: key.clone(), + allow_registry_writing: true, + expiry: chrono::Utc::now() + chrono::TimeDelta::hours(1) + }).await.unwrap(); + + let service = build_service(); let service_delta = empty_delta(&service); @@ -248,6 +312,7 @@ async fn test_register_bad_heuristics() { service_request.get_mut("heuristics").unwrap().as_array_mut().unwrap().get_mut(0).unwrap().as_object_mut().unwrap().insert("description".to_string(), serde_json::Value::Null); let mut headers = headers(); + headers.insert("X-Apikey", HeaderValue::from_str(&key).unwrap()); headers.insert("Service-Name", HeaderValue::from_str(&service.name).unwrap()); headers.insert("Service-Version", HeaderValue::from_str(&service.version).unwrap()); diff --git a/assemblyline-server/src/service_api/v1/badlist.rs b/assemblyline-server/src/service_api/v1/badlist.rs index c24023f..d03a85f 100644 --- a/assemblyline-server/src/service_api/v1/badlist.rs +++ b/assemblyline-server/src/service_api/v1/badlist.rs @@ -17,14 +17,14 @@ const EMPTY: &[(); 0] = &[]; // SUB_API = 'badlist' // badlist_api = make_subapi_blueprint(SUB_API, api_version=1) // badlist_api._doc = "Query badlisted hashes" -pub fn api(core: Arc) -> impl Endpoint { +pub fn api(core: Arc, auth: ServiceAuth) -> impl Endpoint { Route::new() .at("/ssdeep", post(similar_ssdeep)) .at("/tlsh", post(similar_tlsh)) .at("/tags", post(tags_exists)) .at("/:qhash", get(exists)) .data(Arc::new(BadlistClient::new(core.datastore.clone(), core.config.clone(), core.classification_parser.clone()))) - .with(ServiceAuth::new(core)) + .with(auth) } @@ -55,21 +55,21 @@ async fn exists(Path(qhash): Path, client: Data<&Arc>) -> } /// Check if a file with a similar SSDeep exists. -/// +/// /// Variables: /// None -/// +/// /// Arguments: /// None -/// +/// /// Data Block: /// { /// ssdeep : value => Hash to check /// } -/// +/// /// API call example: /// GET /api/v1/badlist/ssdeep/ -/// +/// /// Result example: /// #[handler] @@ -90,21 +90,21 @@ pub struct SimilarSSDeepRequest { } /// Check if a file with a similar TLSH exists. -/// +/// /// Variables: /// None -/// +/// /// Arguments: /// None -/// +/// /// Data Block: /// { /// tlsh : value => Hash to check /// } -/// +/// /// API call example: /// GET /api/v1/badlist/tlsh/ -/// +/// /// Result example: /// #[handler] @@ -126,22 +126,22 @@ pub struct SimilarTlshRequest { /// Check if the provided tags exists in the badlist -/// +/// /// Variables: /// None -/// +/// /// Arguments: /// None -/// +/// /// Data Block: /// { # Dictionary of types -> values to check if exists /// "network.dynamic.domain": [...], /// "network.static.ip": [...] /// } -/// +/// /// API call example: /// GET /api/v1/badlist/tags/ -/// +/// /// Result example: /// [ # List of existing objecs /// , @@ -149,7 +149,7 @@ pub struct SimilarTlshRequest { /// ] #[handler] async fn tags_exists( - Json(data): Json>>, + Json(data): Json>>, client: Data<&Arc> ) -> Result { match client.exists_tags(data).await { diff --git a/assemblyline-server/src/service_api/v1/file.rs b/assemblyline-server/src/service_api/v1/file.rs index b8cae85..30dcdf2 100644 --- a/assemblyline-server/src/service_api/v1/file.rs +++ b/assemblyline-server/src/service_api/v1/file.rs @@ -31,11 +31,11 @@ use crate::Core; // SUB_API = 'file' // file_api = make_subapi_blueprint(SUB_API, api_version=1) // file_api._doc = "Perform operations on file" -pub fn api(core: Arc) -> impl Endpoint { +pub fn api(auth: ServiceAuth) -> impl Endpoint { Route::new() .at("/:sha256", get(download_file)) .at("/", put(upload_file)) - .with(ServiceAuth::new(core)) + .with(auth) } diff --git a/assemblyline-server/src/service_api/v1/safelist.rs b/assemblyline-server/src/service_api/v1/safelist.rs index d128386..8709a46 100644 --- a/assemblyline-server/src/service_api/v1/safelist.rs +++ b/assemblyline-server/src/service_api/v1/safelist.rs @@ -19,13 +19,13 @@ use crate::Core; // SUB_API = 'safelist' // safelist_api = make_subapi_blueprint(SUB_API, api_version=1) // safelist_api._doc = "Query safelisted hashes" -pub fn api(core: Arc) -> impl Endpoint { +pub fn api(core: Arc, auth: ServiceAuth) -> impl Endpoint { Route::new() .at("/signatures", get(get_safelisted_signatures)) .at("/:qhash", get(exists)) .at("/", get(get_safelisted_tags)) .data(Arc::new(SafelistClient::new(core.config.clone(), core.datastore.clone()))) - .with(ServiceAuth::new(core)) + .with(auth) } diff --git a/assemblyline-server/src/service_api/v1/service.rs b/assemblyline-server/src/service_api/v1/service.rs index c211dd6..117d9f8 100644 --- a/assemblyline-server/src/service_api/v1/service.rs +++ b/assemblyline-server/src/service_api/v1/service.rs @@ -14,6 +14,7 @@ use poem::web::{Data, Json}; use poem::{handler, put, Endpoint, EndpointExt, Result, Response, Route}; use serde::{Deserialize, Serialize}; +use crate::constants::ServiceApiKeyConfig; use crate::service_api::helpers::auth::{ClientInfo, ServiceAuth}; use crate::service_api::helpers::{make_api_response, make_empty_api_error}; use crate::service_api::helpers::tasking::TaskingClient; @@ -22,10 +23,10 @@ use crate::Core; // SUB_API = 'service' // service_api = make_subapi_blueprint(SUB_API, api_version=1) // service_api._doc = "Perform operations on service" -pub fn api(core: Arc) -> impl Endpoint { +pub fn api(auth: ServiceAuth) -> impl Endpoint { Route::new() .at("/register", put(register_service).post(register_service)) - .with(ServiceAuth::new(core)) + .with(auth) } @@ -39,9 +40,21 @@ pub fn api(core: Arc) -> impl Endpoint { /// 'service_config': < APPLIED SERVICE CONFIG > /// } #[handler] -async fn register_service(tasking: Data<&Arc>, Json(body): Json, client_info: Data<&ClientInfo>) -> Result { - match tasking.register_service(body, &format!("{} - ", client_info.client_id)).await { +async fn register_service( + tasking: Data<&Arc>, + Json(body): Json, + client_info: Data<&ClientInfo>, + Data(api_key): Data<&ServiceApiKeyConfig>, +) -> Result { + let outcome = if api_key.allow_registry_writing { + tasking.register_service(body, &format!("{} - ", client_info.client_id)).await + } else { + tasking.compare_service_info(body, &format!("{} - ", client_info.client_id)).await + }; + + match outcome { Ok(output) => Ok(make_api_response(output)), + Err(err) if err.is_permission() => Err(make_empty_api_error(StatusCode::UNAUTHORIZED, &err.to_string())), Err(err) if err.is_input_error() => Err(make_empty_api_error(StatusCode::BAD_REQUEST, &err.to_string())), Err(err) => Err(make_empty_api_error(StatusCode::INTERNAL_SERVER_ERROR, &err.to_string())) } @@ -49,7 +62,7 @@ async fn register_service(tasking: Data<&Arc>, Json(body): Json, - pub service_config: Service + pub keep_alive: bool, + pub new_heuristics: Vec, + pub service_config: Service } \ No newline at end of file diff --git a/assemblyline-server/src/service_api/v1/task.rs b/assemblyline-server/src/service_api/v1/task.rs index 41f7ff8..bccc42c 100644 --- a/assemblyline-server/src/service_api/v1/task.rs +++ b/assemblyline-server/src/service_api/v1/task.rs @@ -23,7 +23,6 @@ use serde_json::json; use crate::service_api::helpers::auth::{ClientInfo, ServiceAuth}; use crate::service_api::helpers::{make_api_error, make_api_response, make_empty_api_error}; use crate::service_api::helpers::tasking::{MalformedResult, ServiceMissing, TaskingClient, timestamp}; -use crate::Core; use super::require_header; @@ -33,10 +32,10 @@ const EXTRA_STATUS_TIME: Duration = Duration::from_secs(1); // SUB_API = 'task' // task_api = make_subapi_blueprint(SUB_API, api_version=1) // task_api._doc = "Perform operations on service tasks" -pub fn api(core: Arc) -> impl Endpoint { +pub fn api(auth: ServiceAuth) -> impl Endpoint { Route::new() .at("/", get(get_task).post(task_finished)) - .with(ServiceAuth::new(core)) + .with(auth) }