Skip to content

Commit ec4a6b4

Browse files
authored
Merge pull request #21 from dfinity/igor/fix-acme
Quick fix re-export
2 parents be42842 + 9c5ae7e commit ec4a6b4

File tree

7 files changed

+161
-30
lines changed

7 files changed

+161
-30
lines changed

.github/workflows/test.yml

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,17 @@ jobs:
1818

1919
- uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6
2020

21-
- name: Run tests
22-
run: cargo test
21+
- name: Run tests (all features)
22+
run: cargo test --all-features
23+
24+
- name: Run tests (acme_dns)
25+
run: cargo test --features acme_dns
26+
27+
- name: Run tests (acme_alpn)
28+
run: cargo test --features acme_alpn
29+
30+
- name: Run tests (vector)
31+
run: cargo test --features vector
32+
33+
- name: Run tests (clients-hyper)
34+
run: cargo test --features clients-hyper

Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,9 @@ tower = { version = "0.5.1", features = ["util"] }
103103
tower-service = "0.3.3"
104104
tracing = "0.1.40"
105105
url = "2.5.3"
106-
uuid = { version = "1.16.0", features = ["v7"] }
106+
# DO NOT upgrade, this breaks monorepo compatibility
107+
# Read https://github.com/uuid-rs/uuid/releases/tag/1.13.0
108+
uuid = { version = "=1.12.1", features = ["v7"] }
107109
vrl = { version = "0.23.0", default-features = false, features = [
108110
"value",
109111
], optional = true }

src/http/cache.rs

Lines changed: 131 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -31,33 +31,52 @@ use strum_macros::{Display, IntoStaticStr};
3131
use tokio::{select, sync::Mutex, time::sleep};
3232
use tokio_util::sync::CancellationToken;
3333

34+
use super::{Error as HttpError, body::buffer_body, calc_headers_size, extract_authority};
3435
use crate::{http::headers::X_CACHE_TTL, tasks::Run};
3536

36-
use super::{Error as HttpError, body::buffer_body, calc_headers_size, extract_authority};
37+
pub trait CustomBypassReason:
38+
Debug + Clone + std::fmt::Display + Into<&'static str> + PartialEq + Eq + Send + Sync + 'static
39+
{
40+
}
41+
42+
#[derive(Debug, Clone, Display, PartialEq, Eq, IntoStaticStr)]
43+
pub enum CustomBypassReasonDummy {}
44+
impl CustomBypassReason for CustomBypassReasonDummy {}
3745

3846
#[derive(Debug, Clone, Display, PartialEq, Eq, IntoStaticStr)]
3947
#[strum(serialize_all = "snake_case")]
40-
pub enum CacheBypassReason {
48+
pub enum CacheBypassReason<R: CustomBypassReason> {
4149
MethodNotCacheable,
4250
SizeUnknown,
4351
BodyTooBig,
4452
HTTPError,
4553
UnableToExtractKey,
54+
UnableToRunBypasser,
4655
CacheControl,
56+
Custom(R),
57+
}
58+
59+
impl<R: CustomBypassReason> CacheBypassReason<R> {
60+
pub fn into_str(self) -> &'static str {
61+
match self {
62+
Self::Custom(v) => v.into(),
63+
_ => self.into(),
64+
}
65+
}
4766
}
4867

4968
#[derive(Debug, Clone, Display, PartialEq, Eq, Default, IntoStaticStr)]
5069
#[strum(serialize_all = "SCREAMING_SNAKE_CASE")]
51-
pub enum CacheStatus {
70+
pub enum CacheStatus<R: CustomBypassReason = CustomBypassReasonDummy> {
5271
#[default]
5372
Disabled,
54-
Bypass(CacheBypassReason),
73+
Bypass(CacheBypassReason<R>),
5574
Hit,
5675
Miss,
5776
}
5877

5978
// Injects itself into a given response to be accessible by middleware
60-
impl CacheStatus {
79+
impl<B: CustomBypassReason> CacheStatus<B> {
6180
pub fn with_response<T>(self, mut resp: Response<T>) -> Response<T> {
6281
resp.extensions_mut().insert(self);
6382
resp
@@ -68,6 +87,8 @@ impl CacheStatus {
6887
pub enum Error {
6988
#[error("unable to extract key from request: {0}")]
7089
ExtractKey(String),
90+
#[error("unable to execute bypasser: {0}")]
91+
ExecuteBypasser(String),
7192
#[error("timed out while fetching body")]
7293
FetchBodyTimeout,
7394
#[error("body is too big")]
@@ -80,9 +101,9 @@ pub enum Error {
80101
Other(String),
81102
}
82103

83-
enum ResponseType {
104+
enum ResponseType<R: CustomBypassReason> {
84105
Fetched(Response<Bytes>, Duration),
85-
Streamed(Response, CacheBypassReason),
106+
Streamed(Response, CacheBypassReason<R>),
86107
}
87108

88109
#[derive(Clone)]
@@ -121,6 +142,26 @@ pub trait KeyExtractor: Clone + Send + Sync + Debug + 'static {
121142
fn extract<T>(&self, req: &Request<T>) -> Result<Self::Key, Error>;
122143
}
123144

145+
/// Trait to decide if we need to bypass caching of the given request
146+
pub trait Bypasser: Clone + Send + Sync + Debug + 'static {
147+
/// Custom bypass reason
148+
type BypassReason: CustomBypassReason;
149+
150+
/// Checks if we should bypass the given request
151+
fn bypass<T>(&self, req: &Request<T>) -> Result<Option<Self::BypassReason>, Error>;
152+
}
153+
154+
#[derive(Debug, Clone)]
155+
pub struct NoopBypasser;
156+
157+
impl Bypasser for NoopBypasser {
158+
type BypassReason = CustomBypassReasonDummy;
159+
160+
fn bypass<T>(&self, _req: &Request<T>) -> Result<Option<Self::BypassReason>, Error> {
161+
Ok(None)
162+
}
163+
}
164+
124165
#[derive(Clone)]
125166
pub struct Metrics {
126167
lock_await: HistogramVec,
@@ -245,8 +286,12 @@ fn infer_ttl<T>(req: &Response<T>) -> Option<CacheControl> {
245286
if ["no-cache", "no-store"].contains(&k) {
246287
Some(CacheControl::NoCache)
247288
} else if k == "max-age" {
248-
v.and_then(|x| x.parse::<u64>().ok())
249-
.map(|x| CacheControl::MaxAge(Duration::from_secs(x)))
289+
let v = v.and_then(|x| x.parse::<u64>().ok());
290+
if v == Some(0) {
291+
Some(CacheControl::NoCache)
292+
} else {
293+
v.map(|x| CacheControl::MaxAge(Duration::from_secs(x)))
294+
}
250295
} else {
251296
None
252297
}
@@ -268,16 +313,29 @@ impl<K: KeyExtractor> Expiry<K::Key, Arc<Entry>> for Expirer<K> {
268313
}
269314

270315
/// Builds a cache using some overridable defaults
271-
pub struct CacheBuilder<K: KeyExtractor> {
316+
pub struct CacheBuilder<K: KeyExtractor, B: Bypasser> {
272317
key_extractor: K,
318+
bypasser: Option<B>,
273319
opts: Opts,
274320
registry: Registry,
275321
}
276322

277-
impl<K: KeyExtractor> CacheBuilder<K> {
323+
impl<K: KeyExtractor> CacheBuilder<K, NoopBypasser> {
278324
pub fn new(key_extractor: K) -> Self {
279325
Self {
280326
key_extractor,
327+
bypasser: None,
328+
opts: Opts::default(),
329+
registry: Registry::new(),
330+
}
331+
}
332+
}
333+
334+
impl<K: KeyExtractor, B: Bypasser> CacheBuilder<K, B> {
335+
pub fn new_with_bypasser(key_extractor: K, bypasser: B) -> Self {
336+
Self {
337+
key_extractor,
338+
bypasser: Some(bypasser),
281339
opts: Opts::default(),
282340
registry: Registry::new(),
283341
}
@@ -344,15 +402,16 @@ impl<K: KeyExtractor> CacheBuilder<K> {
344402
}
345403

346404
/// Try to build the cache from this builder
347-
pub fn build(self) -> Result<Cache<K>, Error> {
348-
Cache::new(self.opts, self.key_extractor, &self.registry)
405+
pub fn build(self) -> Result<Cache<K, B>, Error> {
406+
Cache::new(self.opts, self.key_extractor, self.bypasser, &self.registry)
349407
}
350408
}
351409

352-
pub struct Cache<K: KeyExtractor> {
410+
pub struct Cache<K: KeyExtractor, B: Bypasser = NoopBypasser> {
353411
store: MokaCache<K::Key, Arc<Entry>, RandomState>,
354412
locks: MokaCache<K::Key, Arc<Mutex<()>>, RandomState>,
355413
key_extractor: K,
414+
bypasser: Option<B>,
356415
metrics: Metrics,
357416
opts: Opts,
358417
}
@@ -366,8 +425,13 @@ fn weigh_entry<K: KeyExtractor>(_k: &K::Key, v: &Arc<Entry>) -> u32 {
366425
size as u32
367426
}
368427

369-
impl<K: KeyExtractor + 'static> Cache<K> {
370-
pub fn new(opts: Opts, key_extractor: K, registry: &Registry) -> Result<Self, Error> {
428+
impl<K: KeyExtractor + 'static, B: Bypasser + 'static> Cache<K, B> {
429+
pub fn new(
430+
opts: Opts,
431+
key_extractor: K,
432+
bypasser: Option<B>,
433+
registry: &Registry,
434+
) -> Result<Self, Error> {
371435
if opts.max_item_size as u64 >= opts.cache_size {
372436
return Err(Error::Other(
373437
"Cache item size should be less than whole cache size".into(),
@@ -390,6 +454,7 @@ impl<K: KeyExtractor + 'static> Cache<K> {
390454
.build_with_hasher(RandomState::default()),
391455

392456
key_extractor,
457+
bypasser,
393458
metrics: Metrics::new(registry),
394459

395460
opts,
@@ -434,11 +499,11 @@ impl<K: KeyExtractor + 'static> Cache<K> {
434499
let (cache_status, response) = self.process_inner(now, request, next).await?;
435500

436501
// Record metrics
437-
let cache_bypass_reason_str: &'static str = match &cache_status {
438-
CacheStatus::Bypass(v) => v.into(),
502+
let cache_status_str: &'static str = (&cache_status).into();
503+
let cache_bypass_reason_str: &'static str = match cache_status.clone() {
504+
CacheStatus::Bypass(v) => v.into_str(),
439505
_ => "none",
440506
};
441-
let cache_status_str: &'static str = (&cache_status).into();
442507

443508
let labels = &[cache_status_str, cache_bypass_reason_str];
444509

@@ -456,7 +521,26 @@ impl<K: KeyExtractor + 'static> Cache<K> {
456521
now: Instant,
457522
request: Request,
458523
next: Next,
459-
) -> Result<(CacheStatus, Response), Error> {
524+
) -> Result<(CacheStatus<B::BypassReason>, Response), Error> {
525+
// Check if we have bypasser configured
526+
if let Some(b) = &self.bypasser {
527+
// Run it
528+
if let Ok(v) = b.bypass(&request) {
529+
// If it decided to bypass - return the custom reason
530+
if let Some(r) = v {
531+
return Ok((
532+
CacheStatus::Bypass(CacheBypassReason::Custom(r)),
533+
next.run(request).await,
534+
));
535+
}
536+
} else {
537+
return Ok((
538+
CacheStatus::Bypass(CacheBypassReason::UnableToRunBypasser),
539+
next.run(request).await,
540+
));
541+
}
542+
}
543+
460544
// Check the method
461545
if !self.opts.methods.contains(request.method()) {
462546
return Ok((
@@ -526,7 +610,11 @@ impl<K: KeyExtractor + 'static> Cache<K> {
526610
}
527611

528612
// Passes the request down the line and conditionally fetches the response body
529-
async fn pass_request(&self, request: Request, next: Next) -> Result<ResponseType, Error> {
613+
async fn pass_request(
614+
&self,
615+
request: Request,
616+
next: Next,
617+
) -> Result<ResponseType<B::BypassReason>, Error> {
530618
// Execute the response & get the headers
531619
let response = next.run(request).await;
532620

@@ -621,7 +709,7 @@ impl<K: KeyExtractor + 'static> Cache<K> {
621709
}
622710

623711
#[async_trait]
624-
impl<K: KeyExtractor> Run for Cache<K> {
712+
impl<K: KeyExtractor, B: Bypasser> Run for Cache<K, B> {
625713
async fn run(&self, _: CancellationToken) -> Result<(), anyhow::Error> {
626714
self.store.run_pending_tasks();
627715
self.metrics.memory.set(self.store.weighted_size() as i64);
@@ -748,6 +836,25 @@ mod tests {
748836
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR.into_response())
749837
}
750838

839+
#[test]
840+
fn test_bypass_reason_serialize() {
841+
#[derive(Debug, Clone, Display, PartialEq, Eq, IntoStaticStr)]
842+
#[strum(serialize_all = "snake_case")]
843+
enum CustomReasonTest {
844+
Bar,
845+
}
846+
impl CustomBypassReason for CustomReasonTest {}
847+
848+
let a: CacheBypassReason<CustomReasonTest> =
849+
CacheBypassReason::Custom(CustomReasonTest::Bar);
850+
let txt = a.into_str();
851+
assert_eq!(txt, "bar");
852+
853+
let a: CacheBypassReason<CustomReasonTest> = CacheBypassReason::BodyTooBig;
854+
let txt = a.into_str();
855+
assert_eq!(txt, "body_too_big");
856+
}
857+
751858
#[test]
752859
fn test_key_extractor_uri_range() {
753860
let x = KeyExtractorUriRange;
@@ -823,6 +930,8 @@ mod tests {
823930
infer_ttl(&req),
824931
Some(CacheControl::MaxAge(Duration::from_secs(86400)))
825932
);
933+
req.headers_mut().insert(CACHE_CONTROL, hval!("max-age=0"));
934+
assert_eq!(infer_ttl(&req), Some(CacheControl::NoCache));
826935

827936
req.headers_mut()
828937
.insert(CACHE_CONTROL, hval!("max-age=foo"));

src/http/client/mod.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,9 @@ pub trait ClientWithStats: Client + Stats {
5050
fn to_client(self: Arc<Self>) -> Arc<dyn Client>;
5151
}
5252

53-
pub trait CloneableDnsResolver:
54-
Resolve + Service<HyperName> + Clone + fmt::Debug + 'static
55-
{
56-
}
53+
pub trait CloneableDnsResolver: Resolve + Clone + fmt::Debug + 'static {}
54+
55+
pub trait CloneableHyperDnsResolver: Service<HyperName> + Clone + fmt::Debug + 'static {}
5756

5857
#[derive(Clone, Debug)]
5958
struct Metrics {

src/http/dns.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@ use reqwest::dns::{Addrs, Name, Resolve, Resolving};
2121
use strum_macros::EnumString;
2222
use tower::Service;
2323

24-
use super::{Error, client::CloneableDnsResolver};
24+
use super::{
25+
Error,
26+
client::{CloneableDnsResolver, CloneableHyperDnsResolver},
27+
};
2528

2629
#[derive(Clone, Copy, Debug, EnumString)]
2730
#[strum(serialize_all = "snake_case")]
@@ -58,6 +61,7 @@ impl Default for Options {
5861
#[derive(Debug, Clone)]
5962
pub struct Resolver(Arc<TokioResolver>);
6063
impl CloneableDnsResolver for Resolver {}
64+
impl CloneableHyperDnsResolver for Resolver {}
6165

6266
impl Resolver {
6367
/// Creates a new resolver with given options.

src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
// Needed for certain macros
2+
#![recursion_limit = "256"]
13
#![warn(clippy::nursery)]
24
#![warn(tail_expr_drop_order)]
35

@@ -8,6 +10,8 @@ pub mod types;
810
#[cfg(feature = "vector")]
911
pub mod vector;
1012

13+
pub use prometheus;
14+
1115
/// Generic error
1216
#[derive(thiserror::Error, Debug)]
1317
pub enum Error {

src/tls/acme/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use async_trait::async_trait;
1212
use derive_new::new;
1313
use strum_macros::{Display, EnumString};
1414

15+
#[cfg(feature = "acme_dns")]
1516
pub use instant_acme;
1617

1718
#[derive(Clone, Display, EnumString, PartialEq, Eq)]

0 commit comments

Comments
 (0)