@@ -7,7 +7,12 @@ use std::sync::{Arc, Mutex};
77
88struct CompiledValidator {
99 max_param_length : usize ,
10- dangerous_patterns : Vec < Regex > ,
10+ dangerous_pattern : Option < Regex > ,
11+ }
12+
13+ enum ValidationFailure {
14+ MaxLength ,
15+ DangerousPattern ,
1116}
1217
1318static VALIDATOR_CACHE : Lazy < Mutex < HashMap < String , Arc < CompiledValidator > > > > =
@@ -32,17 +37,23 @@ fn get_validator(
3237 return Ok ( existing) ;
3338 }
3439
35- let compiled_patterns = dangerous_patterns
36- . iter ( )
37- . map ( |pattern| {
38- Regex :: new ( pattern)
39- . map_err ( |error| PyErr :: new :: < pyo3:: exceptions:: PyValueError , _ > ( error. to_string ( ) ) )
40- } )
41- . collect :: < PyResult < Vec < _ > > > ( ) ?;
40+ let combined_pattern =
41+ if dangerous_patterns. is_empty ( ) {
42+ None
43+ } else {
44+ let joined = dangerous_patterns
45+ . iter ( )
46+ . map ( |pattern| format ! ( "(?:{pattern})" ) )
47+ . collect :: < Vec < _ > > ( )
48+ . join ( "|" ) ;
49+ Some ( Regex :: new ( & joined) . map_err ( |error| {
50+ PyErr :: new :: < pyo3:: exceptions:: PyValueError , _ > ( error. to_string ( ) )
51+ } ) ?)
52+ } ;
4253
4354 let validator = Arc :: new ( CompiledValidator {
4455 max_param_length,
45- dangerous_patterns : compiled_patterns ,
56+ dangerous_pattern : combined_pattern ,
4657 } ) ;
4758
4859 VALIDATOR_CACHE
@@ -52,19 +63,17 @@ fn get_validator(
5263 Ok ( validator)
5364}
5465
55- fn validate_string (
56- key : & str ,
57- value : & str ,
58- validator : & CompiledValidator ,
59- ) -> Option < ( String , String ) > {
66+ fn validate_string ( value : & str , validator : & CompiledValidator ) -> Option < ValidationFailure > {
6067 if value. len ( ) > validator. max_param_length {
61- return Some ( ( key . to_owned ( ) , "max_length" . to_owned ( ) ) ) ;
68+ return Some ( ValidationFailure :: MaxLength ) ;
6269 }
6370
64- for pattern in & validator. dangerous_patterns {
65- if pattern. is_match ( value) {
66- return Some ( ( key. to_owned ( ) , "dangerous_pattern" . to_owned ( ) ) ) ;
67- }
71+ if validator
72+ . dangerous_pattern
73+ . as_ref ( )
74+ . is_some_and ( |pattern| pattern. is_match ( value) )
75+ {
76+ return Some ( ValidationFailure :: DangerousPattern ) ;
6877 }
6978
7079 None
@@ -76,11 +85,14 @@ fn walk_json_like(
7685) -> PyResult < Option < ( String , String ) > > {
7786 if let Ok ( dict) = data. cast :: < PyDict > ( ) {
7887 for ( key, value) in dict. iter ( ) {
79- if value. is_instance_of :: < PyString > ( ) {
80- let key_string = key. str ( ) ?. to_string_lossy ( ) . into_owned ( ) ;
81- let value_string = value. cast :: < PyString > ( ) ?. to_str ( ) ?. to_owned ( ) ;
82- if let Some ( result) = validate_string ( & key_string, & value_string, validator) {
83- return Ok ( Some ( result) ) ;
88+ if let Ok ( value_string) = value. cast :: < PyString > ( ) {
89+ if let Some ( result) = validate_string ( value_string. to_str ( ) ?, validator) {
90+ let key_string = key. str ( ) ?. to_string_lossy ( ) . into_owned ( ) ;
91+ let error_type = match result {
92+ ValidationFailure :: MaxLength => "max_length" ,
93+ ValidationFailure :: DangerousPattern => "dangerous_pattern" ,
94+ } ;
95+ return Ok ( Some ( ( key_string, error_type. to_owned ( ) ) ) ) ;
8496 }
8597 } else if value. is_instance_of :: < PyDict > ( ) || value. is_instance_of :: < PyList > ( ) {
8698 if let Some ( result) = walk_json_like ( & value, validator) ? {
0 commit comments