diff --git a/aw-query/src/datatype.rs b/aw-query/src/datatype.rs index c31355dd..dd1affd2 100644 --- a/aw-query/src/datatype.rs +++ b/aw-query/src/datatype.rs @@ -1,10 +1,11 @@ use std::collections::HashMap; use std::fmt; +use std::str::FromStr as _; use super::functions; use super::QueryError; use aw_models::Event; -use aw_transform::classify::{RegexRule, Rule}; +use aw_transform::classify::{LogicalOperator, LogicalRule, RegexRule, Rule}; use serde::{Serialize, Serializer}; use serde_json::value::Value; @@ -297,50 +298,90 @@ impl TryFrom<&DataType> for Rule { )) } }; - if rtype == "none" { - Ok(Self::None) - } else if rtype == "regex" { - let regex_val = match obj.get("regex") { - Some(regex_val) => regex_val, - None => { - return Err(QueryError::InvalidFunctionParameters( - "regex rule is missing the 'regex' field".to_string(), - )) - } - }; - let regex_str = match regex_val { - DataType::String(s) => s, - _ => { - return Err(QueryError::InvalidFunctionParameters( - "the regex field of the regex rule is not a string".to_string(), - )) - } - }; - let ignore_case_val = match obj.get("ignore_case") { - Some(case_val) => case_val, - None => &DataType::Bool(false), - }; - let ignore_case = match ignore_case_val { - DataType::Bool(b) => b, - _ => { - return Err(QueryError::InvalidFunctionParameters( - "the ignore_case field of the regex rule is not a bool".to_string(), - )) - } - }; - let regex_rule = match RegexRule::new(regex_str, *ignore_case) { - Ok(regex_rule) => regex_rule, - Err(err) => { - return Err(QueryError::RegexCompileError(format!( - "Failed to compile regex string '{regex_str}': '{err:?}" - ))) - } - }; - Ok(Self::Regex(regex_rule)) - } else { - Err(QueryError::InvalidFunctionParameters(format!( + + match rtype.as_str() { + "none" => Ok(Self::None), + "or" | "and" => parse_logical_rule(obj, rtype), + "regex" => parse_regex_rule(obj), + _ => Err(QueryError::InvalidFunctionParameters(format!( "Unknown rule type '{rtype}'" - ))) + ))), } } } + +fn parse_logical_rule(obj: &HashMap, rtype: &String) -> Result { + let Some(rules) = obj.get("rules") else { + return Err(QueryError::InvalidFunctionParameters(format!( + "{} rule is missing the 'rules' field", + rtype + ))); + }; + + let rules = match rules { + DataType::List(rules) => rules + .iter() + .map(Rule::try_from) + .collect::, _>>()?, + _ => { + return Err(QueryError::InvalidFunctionParameters(format!( + "the rules field of the {} rule is not a list", + rtype + ))) + } + }; + + let operator = + LogicalOperator::from_str(rtype).map_err(QueryError::InvalidFunctionParameters)?; + + Ok(Rule::Logical(LogicalRule::new(rules, operator))) +} + +fn parse_regex_rule(obj: &HashMap) -> Result { + let regex_val = match obj.get("regex") { + Some(regex_val) => regex_val, + None => { + return Err(QueryError::InvalidFunctionParameters( + "regex rule is missing the 'regex' field".to_string(), + )) + } + }; + let regex_str = match regex_val { + DataType::String(s) => s, + _ => { + return Err(QueryError::InvalidFunctionParameters( + "the regex field of the regex rule is not a string".to_string(), + )) + } + }; + let ignore_case_val = match obj.get("ignore_case") { + Some(case_val) => case_val, + None => &DataType::Bool(false), + }; + let ignore_case = match ignore_case_val { + DataType::Bool(b) => b, + _ => { + return Err(QueryError::InvalidFunctionParameters( + "the ignore_case field of the regex rule is not a bool".to_string(), + )) + } + }; + let match_field = match obj.get("field") { + Some(DataType::String(v)) => Some(v.to_owned()), + None => None, + _ => { + return Err(QueryError::InvalidFunctionParameters( + "the `field` field of the regex rule is not a string".to_string(), + )) + } + }; + let regex_rule = match RegexRule::new(regex_str, *ignore_case, match_field) { + Ok(regex_rule) => regex_rule, + Err(err) => { + return Err(QueryError::RegexCompileError(format!( + "Failed to compile regex string '{regex_str}': '{err:?}" + ))) + } + }; + Ok(Rule::Regex(regex_rule)) +} diff --git a/aw-transform/src/classify.rs b/aw-transform/src/classify.rs index 114650d6..46fdcbcf 100644 --- a/aw-transform/src/classify.rs +++ b/aw-transform/src/classify.rs @@ -1,3 +1,5 @@ +use std::str::FromStr; + /// Transforms for classifying (tagging and categorizing) events. /// /// Based on code in aw_research: https://github.com/ActivityWatch/aw-research/blob/master/aw_research/classify.py @@ -6,6 +8,7 @@ use fancy_regex::Regex; pub enum Rule { None, + Logical(LogicalRule), Regex(RegexRule), } @@ -13,6 +16,7 @@ impl RuleTrait for Rule { fn matches(&self, event: &Event) -> bool { match self { Rule::None => false, + Rule::Logical(rule) => rule.matches(event), Rule::Regex(rule) => rule.matches(event), } } @@ -24,10 +28,15 @@ trait RuleTrait { pub struct RegexRule { regex: Regex, + field: Option, } impl RegexRule { - pub fn new(regex_str: &str, ignore_case: bool) -> Result { + pub fn new( + regex_str: &str, + ignore_case: bool, + field: Option, + ) -> Result { // can't use `RegexBuilder::case_insensitive` because it's not supported by fancy_regex, // so we need to prefix with `(?i)` to make it case insensitive. let regex = if ignore_case { @@ -37,7 +46,7 @@ impl RegexRule { Regex::new(regex_str)? }; - Ok(RegexRule { regex }) + Ok(RegexRule { regex, field }) } } @@ -50,15 +59,59 @@ impl RuleTrait for RegexRule { fn matches(&self, event: &Event) -> bool { event .data - .values() - .filter(|val| val.is_string()) - .any(|val| self.regex.is_match(val.as_str().unwrap()).unwrap()) + .iter() + .filter(|(field, val)| { + self.field.as_ref().map(|v| &v == field).unwrap_or(true) && val.is_string() + }) + .any(|(_, val)| self.regex.is_match(val.as_str().unwrap()).unwrap()) } } impl From for Rule { fn from(re: Regex) -> Self { - Rule::Regex(RegexRule { regex: re }) + Rule::Regex(RegexRule { + regex: re, + field: None, + }) + } +} + +pub enum LogicalOperator { + Or, + And, +} + +impl FromStr for LogicalOperator { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + "or" => Ok(Self::Or), + "and" => Ok(Self::And), + _ => Err(format!("Invalid logical operator: {}", s)), + } + } +} + +pub struct LogicalRule { + rules: Vec, + operator: LogicalOperator, +} + +impl LogicalRule { + pub fn new(rules: Vec, operator: LogicalOperator) -> Self { + Self { rules, operator } + } +} + +impl RuleTrait for LogicalRule { + fn matches(&self, event: &Event) -> bool { + use LogicalOperator::{And, Or}; + + match self.operator { + Or => self.rules.iter().any(|rule| rule.matches(event)), + And => self.rules.iter().all(|rule| rule.matches(event)), + } } } @@ -135,7 +188,7 @@ fn test_rule() { .insert("nonono".into(), serde_json::json!("no match!")); let rule_from_regex = Rule::from(Regex::new("test").unwrap()); - let rule_from_new = Rule::Regex(RegexRule::new("test", false).unwrap()); + let rule_from_new = Rule::Regex(RegexRule::new("test", false, None).unwrap()); let rule_none = Rule::None; assert!(rule_from_regex.matches(&e_match)); assert!(rule_from_new.matches(&e_match));