@@ -31,33 +31,52 @@ use strum_macros::{Display, IntoStaticStr};
3131use tokio:: { select, sync:: Mutex , time:: sleep} ;
3232use tokio_util:: sync:: CancellationToken ;
3333
34+ use super :: { Error as HttpError , body:: buffer_body, calc_headers_size, extract_authority} ;
3435use crate :: { http:: headers:: X_CACHE_TTL , tasks:: Run } ;
3536
36- use super :: { Error as HttpError , body:: buffer_body, calc_headers_size, extract_authority} ;
37+ pub trait CustomBypassReason :
38+ Debug + Clone + std:: fmt:: Display + Into < & ' static str > + PartialEq + Eq + Send + Sync + ' static
39+ {
40+ }
41+
42+ #[ derive( Debug , Clone , Display , PartialEq , Eq , IntoStaticStr ) ]
43+ pub enum CustomBypassReasonDummy { }
44+ impl CustomBypassReason for CustomBypassReasonDummy { }
3745
3846#[ derive( Debug , Clone , Display , PartialEq , Eq , IntoStaticStr ) ]
3947#[ strum( serialize_all = "snake_case" ) ]
40- pub enum CacheBypassReason {
48+ pub enum CacheBypassReason < R : CustomBypassReason > {
4149 MethodNotCacheable ,
4250 SizeUnknown ,
4351 BodyTooBig ,
4452 HTTPError ,
4553 UnableToExtractKey ,
54+ UnableToRunBypasser ,
4655 CacheControl ,
56+ Custom ( R ) ,
57+ }
58+
59+ impl < R : CustomBypassReason > CacheBypassReason < R > {
60+ pub fn into_str ( self ) -> & ' static str {
61+ match self {
62+ Self :: Custom ( v) => v. into ( ) ,
63+ _ => self . into ( ) ,
64+ }
65+ }
4766}
4867
4968#[ derive( Debug , Clone , Display , PartialEq , Eq , Default , IntoStaticStr ) ]
5069#[ strum( serialize_all = "SCREAMING_SNAKE_CASE" ) ]
51- pub enum CacheStatus {
70+ pub enum CacheStatus < R : CustomBypassReason = CustomBypassReasonDummy > {
5271 #[ default]
5372 Disabled ,
54- Bypass ( CacheBypassReason ) ,
73+ Bypass ( CacheBypassReason < R > ) ,
5574 Hit ,
5675 Miss ,
5776}
5877
5978// Injects itself into a given response to be accessible by middleware
60- impl CacheStatus {
79+ impl < B : CustomBypassReason > CacheStatus < B > {
6180 pub fn with_response < T > ( self , mut resp : Response < T > ) -> Response < T > {
6281 resp. extensions_mut ( ) . insert ( self ) ;
6382 resp
@@ -68,6 +87,8 @@ impl CacheStatus {
6887pub enum Error {
6988 #[ error( "unable to extract key from request: {0}" ) ]
7089 ExtractKey ( String ) ,
90+ #[ error( "unable to execute bypasser: {0}" ) ]
91+ ExecuteBypasser ( String ) ,
7192 #[ error( "timed out while fetching body" ) ]
7293 FetchBodyTimeout ,
7394 #[ error( "body is too big" ) ]
@@ -80,9 +101,9 @@ pub enum Error {
80101 Other ( String ) ,
81102}
82103
83- enum ResponseType {
104+ enum ResponseType < R : CustomBypassReason > {
84105 Fetched ( Response < Bytes > , Duration ) ,
85- Streamed ( Response , CacheBypassReason ) ,
106+ Streamed ( Response , CacheBypassReason < R > ) ,
86107}
87108
88109#[ derive( Clone ) ]
@@ -121,6 +142,26 @@ pub trait KeyExtractor: Clone + Send + Sync + Debug + 'static {
121142 fn extract < T > ( & self , req : & Request < T > ) -> Result < Self :: Key , Error > ;
122143}
123144
145+ /// Trait to decide if we need to bypass caching of the given request
146+ pub trait Bypasser : Clone + Send + Sync + Debug + ' static {
147+ /// Custom bypass reason
148+ type BypassReason : CustomBypassReason ;
149+
150+ /// Checks if we should bypass the given request
151+ fn bypass < T > ( & self , req : & Request < T > ) -> Result < Option < Self :: BypassReason > , Error > ;
152+ }
153+
154+ #[ derive( Debug , Clone ) ]
155+ pub struct NoopBypasser ;
156+
157+ impl Bypasser for NoopBypasser {
158+ type BypassReason = CustomBypassReasonDummy ;
159+
160+ fn bypass < T > ( & self , _req : & Request < T > ) -> Result < Option < Self :: BypassReason > , Error > {
161+ Ok ( None )
162+ }
163+ }
164+
124165#[ derive( Clone ) ]
125166pub struct Metrics {
126167 lock_await : HistogramVec ,
@@ -245,8 +286,12 @@ fn infer_ttl<T>(req: &Response<T>) -> Option<CacheControl> {
245286 if [ "no-cache" , "no-store" ] . contains ( & k) {
246287 Some ( CacheControl :: NoCache )
247288 } else if k == "max-age" {
248- v. and_then ( |x| x. parse :: < u64 > ( ) . ok ( ) )
249- . map ( |x| CacheControl :: MaxAge ( Duration :: from_secs ( x) ) )
289+ let v = v. and_then ( |x| x. parse :: < u64 > ( ) . ok ( ) ) ;
290+ if v == Some ( 0 ) {
291+ Some ( CacheControl :: NoCache )
292+ } else {
293+ v. map ( |x| CacheControl :: MaxAge ( Duration :: from_secs ( x) ) )
294+ }
250295 } else {
251296 None
252297 }
@@ -268,16 +313,29 @@ impl<K: KeyExtractor> Expiry<K::Key, Arc<Entry>> for Expirer<K> {
268313}
269314
270315/// Builds a cache using some overridable defaults
271- pub struct CacheBuilder < K : KeyExtractor > {
316+ pub struct CacheBuilder < K : KeyExtractor , B : Bypasser > {
272317 key_extractor : K ,
318+ bypasser : Option < B > ,
273319 opts : Opts ,
274320 registry : Registry ,
275321}
276322
277- impl < K : KeyExtractor > CacheBuilder < K > {
323+ impl < K : KeyExtractor > CacheBuilder < K , NoopBypasser > {
278324 pub fn new ( key_extractor : K ) -> Self {
279325 Self {
280326 key_extractor,
327+ bypasser : None ,
328+ opts : Opts :: default ( ) ,
329+ registry : Registry :: new ( ) ,
330+ }
331+ }
332+ }
333+
334+ impl < K : KeyExtractor , B : Bypasser > CacheBuilder < K , B > {
335+ pub fn new_with_bypasser ( key_extractor : K , bypasser : B ) -> Self {
336+ Self {
337+ key_extractor,
338+ bypasser : Some ( bypasser) ,
281339 opts : Opts :: default ( ) ,
282340 registry : Registry :: new ( ) ,
283341 }
@@ -344,15 +402,16 @@ impl<K: KeyExtractor> CacheBuilder<K> {
344402 }
345403
346404 /// Try to build the cache from this builder
347- pub fn build ( self ) -> Result < Cache < K > , Error > {
348- Cache :: new ( self . opts , self . key_extractor , & self . registry )
405+ pub fn build ( self ) -> Result < Cache < K , B > , Error > {
406+ Cache :: new ( self . opts , self . key_extractor , self . bypasser , & self . registry )
349407 }
350408}
351409
352- pub struct Cache < K : KeyExtractor > {
410+ pub struct Cache < K : KeyExtractor , B : Bypasser = NoopBypasser > {
353411 store : MokaCache < K :: Key , Arc < Entry > , RandomState > ,
354412 locks : MokaCache < K :: Key , Arc < Mutex < ( ) > > , RandomState > ,
355413 key_extractor : K ,
414+ bypasser : Option < B > ,
356415 metrics : Metrics ,
357416 opts : Opts ,
358417}
@@ -366,8 +425,13 @@ fn weigh_entry<K: KeyExtractor>(_k: &K::Key, v: &Arc<Entry>) -> u32 {
366425 size as u32
367426}
368427
369- impl < K : KeyExtractor + ' static > Cache < K > {
370- pub fn new ( opts : Opts , key_extractor : K , registry : & Registry ) -> Result < Self , Error > {
428+ impl < K : KeyExtractor + ' static , B : Bypasser + ' static > Cache < K , B > {
429+ pub fn new (
430+ opts : Opts ,
431+ key_extractor : K ,
432+ bypasser : Option < B > ,
433+ registry : & Registry ,
434+ ) -> Result < Self , Error > {
371435 if opts. max_item_size as u64 >= opts. cache_size {
372436 return Err ( Error :: Other (
373437 "Cache item size should be less than whole cache size" . into ( ) ,
@@ -390,6 +454,7 @@ impl<K: KeyExtractor + 'static> Cache<K> {
390454 . build_with_hasher ( RandomState :: default ( ) ) ,
391455
392456 key_extractor,
457+ bypasser,
393458 metrics : Metrics :: new ( registry) ,
394459
395460 opts,
@@ -434,11 +499,11 @@ impl<K: KeyExtractor + 'static> Cache<K> {
434499 let ( cache_status, response) = self . process_inner ( now, request, next) . await ?;
435500
436501 // Record metrics
437- let cache_bypass_reason_str: & ' static str = match & cache_status {
438- CacheStatus :: Bypass ( v) => v. into ( ) ,
502+ let cache_status_str: & ' static str = ( & cache_status) . into ( ) ;
503+ let cache_bypass_reason_str: & ' static str = match cache_status. clone ( ) {
504+ CacheStatus :: Bypass ( v) => v. into_str ( ) ,
439505 _ => "none" ,
440506 } ;
441- let cache_status_str: & ' static str = ( & cache_status) . into ( ) ;
442507
443508 let labels = & [ cache_status_str, cache_bypass_reason_str] ;
444509
@@ -456,7 +521,26 @@ impl<K: KeyExtractor + 'static> Cache<K> {
456521 now : Instant ,
457522 request : Request ,
458523 next : Next ,
459- ) -> Result < ( CacheStatus , Response ) , Error > {
524+ ) -> Result < ( CacheStatus < B :: BypassReason > , Response ) , Error > {
525+ // Check if we have bypasser configured
526+ if let Some ( b) = & self . bypasser {
527+ // Run it
528+ if let Ok ( v) = b. bypass ( & request) {
529+ // If it decided to bypass - return the custom reason
530+ if let Some ( r) = v {
531+ return Ok ( (
532+ CacheStatus :: Bypass ( CacheBypassReason :: Custom ( r) ) ,
533+ next. run ( request) . await ,
534+ ) ) ;
535+ }
536+ } else {
537+ return Ok ( (
538+ CacheStatus :: Bypass ( CacheBypassReason :: UnableToRunBypasser ) ,
539+ next. run ( request) . await ,
540+ ) ) ;
541+ }
542+ }
543+
460544 // Check the method
461545 if !self . opts . methods . contains ( request. method ( ) ) {
462546 return Ok ( (
@@ -526,7 +610,11 @@ impl<K: KeyExtractor + 'static> Cache<K> {
526610 }
527611
528612 // Passes the request down the line and conditionally fetches the response body
529- async fn pass_request ( & self , request : Request , next : Next ) -> Result < ResponseType , Error > {
613+ async fn pass_request (
614+ & self ,
615+ request : Request ,
616+ next : Next ,
617+ ) -> Result < ResponseType < B :: BypassReason > , Error > {
530618 // Execute the response & get the headers
531619 let response = next. run ( request) . await ;
532620
@@ -621,7 +709,7 @@ impl<K: KeyExtractor + 'static> Cache<K> {
621709}
622710
623711#[ async_trait]
624- impl < K : KeyExtractor > Run for Cache < K > {
712+ impl < K : KeyExtractor , B : Bypasser > Run for Cache < K , B > {
625713 async fn run ( & self , _: CancellationToken ) -> Result < ( ) , anyhow:: Error > {
626714 self . store . run_pending_tasks ( ) ;
627715 self . metrics . memory . set ( self . store . weighted_size ( ) as i64 ) ;
@@ -748,6 +836,25 @@ mod tests {
748836 . unwrap_or ( StatusCode :: INTERNAL_SERVER_ERROR . into_response ( ) )
749837 }
750838
839+ #[ test]
840+ fn test_bypass_reason_serialize ( ) {
841+ #[ derive( Debug , Clone , Display , PartialEq , Eq , IntoStaticStr ) ]
842+ #[ strum( serialize_all = "snake_case" ) ]
843+ enum CustomReasonTest {
844+ Bar ,
845+ }
846+ impl CustomBypassReason for CustomReasonTest { }
847+
848+ let a: CacheBypassReason < CustomReasonTest > =
849+ CacheBypassReason :: Custom ( CustomReasonTest :: Bar ) ;
850+ let txt = a. into_str ( ) ;
851+ assert_eq ! ( txt, "bar" ) ;
852+
853+ let a: CacheBypassReason < CustomReasonTest > = CacheBypassReason :: BodyTooBig ;
854+ let txt = a. into_str ( ) ;
855+ assert_eq ! ( txt, "body_too_big" ) ;
856+ }
857+
751858 #[ test]
752859 fn test_key_extractor_uri_range ( ) {
753860 let x = KeyExtractorUriRange ;
@@ -823,6 +930,8 @@ mod tests {
823930 infer_ttl( & req) ,
824931 Some ( CacheControl :: MaxAge ( Duration :: from_secs( 86400 ) ) )
825932 ) ;
933+ req. headers_mut ( ) . insert ( CACHE_CONTROL , hval ! ( "max-age=0" ) ) ;
934+ assert_eq ! ( infer_ttl( & req) , Some ( CacheControl :: NoCache ) ) ;
826935
827936 req. headers_mut ( )
828937 . insert ( CACHE_CONTROL , hval ! ( "max-age=foo" ) ) ;
0 commit comments