Skip to content

Commit c8209c7

Browse files
committed
perf: optimize rust validation sidecar matching
Signed-off-by: lucarlig <luca.carlig@ibm.com>
1 parent 188d93b commit c8209c7

File tree

1 file changed

+36
-24
lines changed
  • tools_rust/validation_middleware_sidecar/src

1 file changed

+36
-24
lines changed

tools_rust/validation_middleware_sidecar/src/lib.rs

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@ use std::sync::{Arc, Mutex};
77

88
struct CompiledValidator {
99
max_param_length: usize,
10-
dangerous_patterns: Vec<Regex>,
10+
dangerous_pattern: Option<Regex>,
11+
}
12+
13+
enum ValidationFailure {
14+
MaxLength,
15+
DangerousPattern,
1116
}
1217

1318
static VALIDATOR_CACHE: Lazy<Mutex<HashMap<String, Arc<CompiledValidator>>>> =
@@ -32,17 +37,23 @@ fn get_validator(
3237
return Ok(existing);
3338
}
3439

35-
let compiled_patterns = dangerous_patterns
36-
.iter()
37-
.map(|pattern| {
38-
Regex::new(pattern)
39-
.map_err(|error| PyErr::new::<pyo3::exceptions::PyValueError, _>(error.to_string()))
40-
})
41-
.collect::<PyResult<Vec<_>>>()?;
40+
let combined_pattern =
41+
if dangerous_patterns.is_empty() {
42+
None
43+
} else {
44+
let joined = dangerous_patterns
45+
.iter()
46+
.map(|pattern| format!("(?:{pattern})"))
47+
.collect::<Vec<_>>()
48+
.join("|");
49+
Some(Regex::new(&joined).map_err(|error| {
50+
PyErr::new::<pyo3::exceptions::PyValueError, _>(error.to_string())
51+
})?)
52+
};
4253

4354
let validator = Arc::new(CompiledValidator {
4455
max_param_length,
45-
dangerous_patterns: compiled_patterns,
56+
dangerous_pattern: combined_pattern,
4657
});
4758

4859
VALIDATOR_CACHE
@@ -52,19 +63,17 @@ fn get_validator(
5263
Ok(validator)
5364
}
5465

55-
fn validate_string(
56-
key: &str,
57-
value: &str,
58-
validator: &CompiledValidator,
59-
) -> Option<(String, String)> {
66+
fn validate_string(value: &str, validator: &CompiledValidator) -> Option<ValidationFailure> {
6067
if value.len() > validator.max_param_length {
61-
return Some((key.to_owned(), "max_length".to_owned()));
68+
return Some(ValidationFailure::MaxLength);
6269
}
6370

64-
for pattern in &validator.dangerous_patterns {
65-
if pattern.is_match(value) {
66-
return Some((key.to_owned(), "dangerous_pattern".to_owned()));
67-
}
71+
if validator
72+
.dangerous_pattern
73+
.as_ref()
74+
.is_some_and(|pattern| pattern.is_match(value))
75+
{
76+
return Some(ValidationFailure::DangerousPattern);
6877
}
6978

7079
None
@@ -76,11 +85,14 @@ fn walk_json_like(
7685
) -> PyResult<Option<(String, String)>> {
7786
if let Ok(dict) = data.cast::<PyDict>() {
7887
for (key, value) in dict.iter() {
79-
if value.is_instance_of::<PyString>() {
80-
let key_string = key.str()?.to_string_lossy().into_owned();
81-
let value_string = value.cast::<PyString>()?.to_str()?.to_owned();
82-
if let Some(result) = validate_string(&key_string, &value_string, validator) {
83-
return Ok(Some(result));
88+
if let Ok(value_string) = value.cast::<PyString>() {
89+
if let Some(result) = validate_string(value_string.to_str()?, validator) {
90+
let key_string = key.str()?.to_string_lossy().into_owned();
91+
let error_type = match result {
92+
ValidationFailure::MaxLength => "max_length",
93+
ValidationFailure::DangerousPattern => "dangerous_pattern",
94+
};
95+
return Ok(Some((key_string, error_type.to_owned())));
8496
}
8597
} else if value.is_instance_of::<PyDict>() || value.is_instance_of::<PyList>() {
8698
if let Some(result) = walk_json_like(&value, validator)? {

0 commit comments

Comments
 (0)