-
Notifications
You must be signed in to change notification settings - Fork 177
Expand file tree
/
Copy pathmirror_middleware.rs
More file actions
351 lines (294 loc) · 11.9 KB
/
mirror_middleware.rs
File metadata and controls
351 lines (294 loc) · 11.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
//! Middleware to handle mirrors
use rattler_conda_types::utils::url_with_trailing_slash::UrlWithTrailingSlash;
use std::{
collections::HashMap,
sync::atomic::{self, AtomicUsize},
};
use http::Extensions;
use itertools::Itertools;
use reqwest::{Request, Response};
use reqwest_middleware::{Middleware, Next, Result};
use url::Url;
/// Settings for a single mirror, such as which compressed repodata formats it
/// is able to serve.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Mirror {
    /// Base URL of the mirror. The part of a request URL that follows the
    /// matched upstream key is resolved relative to this URL.
    pub url: Url,
    /// When `true`, requests for `*.json.zst` files are answered with a local
    /// 404 instead of being forwarded to this mirror.
    pub no_zstd: bool,
    /// When `true`, requests for `*.json.bz2` files are answered with a local
    /// 404 instead of being forwarded to this mirror.
    pub no_bz2: bool,
    /// Number of recorded failures at which the mirror stops being selected.
    /// `None` means the mirror is never excluded because of failures.
    pub max_failures: Option<usize>,
}
/// Runtime bookkeeping for one mirror: its static settings plus the number of
/// failed requests seen so far.
///
/// Both fields are read by the selection and request-handling code, so no
/// `dead_code` allowance is needed.
struct MirrorState {
    /// Count of requests to this mirror that errored or returned a 5xx status.
    failures: AtomicUsize,
    /// The configured settings of this mirror.
    mirror: Mirror,
}

impl MirrorState {
    /// Record one more failed request against this mirror.
    ///
    /// `Relaxed` ordering suffices: the counter is only a selection heuristic
    /// and is not used to publish or guard any other data.
    pub fn add_failure(&self) {
        self.failures.fetch_add(1, atomic::Ordering::Relaxed);
    }
}
/// Middleware that redirects requests aimed at a configured upstream URL to one
/// of that upstream's mirrors.
pub struct MirrorMiddleware {
    /// Upstream URL mapped to the runtime state of each of its mirrors.
    mirror_map: HashMap<Url, Vec<MirrorState>>,
    /// `(upstream.to_string(), upstream)` pairs, longest upstream path first.
    sorted_keys: Vec<(String, Url)>,
}

impl MirrorMiddleware {
    /// Create a new `MirrorMiddleware` from a map of upstream URL to its mirrors.
    ///
    /// Every mirror starts out with zero recorded failures.
    pub fn from_map(mirror_map: HashMap<Url, Vec<Mirror>>) -> Self {
        let mut states: HashMap<Url, Vec<MirrorState>> = HashMap::new();
        for (upstream, mirrors) in mirror_map {
            let mirror_states: Vec<MirrorState> = mirrors
                .into_iter()
                .map(|mirror| MirrorState {
                    failures: AtomicUsize::new(0),
                    mirror,
                })
                .collect();
            states.insert(upstream, mirror_states);
        }

        // Longest path first (stable sort), so a more specific upstream is
        // tried before a shorter one that is also a prefix of the request.
        let sorted_keys: Vec<(String, Url)> = states
            .keys()
            .sorted_by_key(|upstream| std::cmp::Reverse(upstream.path().len()))
            .map(|upstream| (upstream.to_string(), upstream.clone()))
            .collect();

        Self {
            mirror_map: states,
            sorted_keys,
        }
    }

    /// The upstream keys as `(string form, url)` pairs, ordered by descending
    /// path length so the longest path comes first.
    pub fn keys(&self) -> &[(String, Url)] {
        self.sorted_keys.as_slice()
    }
}
/// Pick the mirror to use for the next request.
///
/// Mirrors whose failure count has reached their `max_failures` limit are
/// skipped. Among the remaining ones, the mirror with the fewest recorded
/// failures wins; ties go to the mirror listed first.
///
/// Returns `None` when `mirrors` is empty or every mirror has hit its limit.
fn select_mirror(mirrors: &[MirrorState]) -> Option<&MirrorState> {
    mirrors
        .iter()
        .filter_map(|state| {
            // Load the counter once so the limit check and the ranking agree
            // even if another task records a failure concurrently.
            let failures = state.failures.load(atomic::Ordering::Relaxed);
            let usable = state
                .mirror
                .max_failures
                .is_none_or(|max| failures < max);
            usable.then_some((failures, state))
        })
        // `min_by_key` returns the first of several equal minima.
        .min_by_key(|&(failures, _)| failures)
        .map(|(_, state)| state)
}
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
impl Middleware for MirrorMiddleware {
async fn handle(
&self,
mut req: Request,
extensions: &mut Extensions,
next: Next<'_>,
) -> Result<Response> {
let url_str = req.url().to_string();
for (key, url) in self.keys() {
if let Some(url_rest) = url_str.strip_prefix(key) {
let url_rest = url_rest.trim_start_matches('/');
// replace the key with the mirror
let mirrors = self.mirror_map.get(url).unwrap();
let selected_mirror = select_mirror(mirrors);
let Some(selected_mirror) = selected_mirror else {
return Ok(create_404_response(req.url(), "All mirrors are dead"));
};
let mirror = &selected_mirror.mirror;
let base = UrlWithTrailingSlash::from(mirror.url.clone());
let selected_url = base.as_ref().join(url_rest).unwrap();
// Short-circuit if the mirror does not support the file type
if url_rest.ends_with(".json.zst") && mirror.no_zstd {
return Ok(create_404_response(
&selected_url,
"Mirror does not support zstd",
));
}
if url_rest.ends_with(".json.bz2") && mirror.no_bz2 {
return Ok(create_404_response(
&selected_url,
"Mirror does not support bz2",
));
}
*req.url_mut() = selected_url;
let res = next.run(req, extensions).await;
// record a failure if the request failed so we can avoid the mirror in the future
match res.as_ref() {
Ok(res) if res.status().is_server_error() => selected_mirror.add_failure(),
Err(_) => selected_mirror.add_failure(),
_ => {}
}
return res;
}
}
// if we don't have a mirror, we don't need to do anything
next.run(req, extensions).await
}
}
/// Build a synthetic `404 Not Found` response whose URL is `url` and whose
/// text body is `body`, without contacting any server.
#[cfg(not(target_arch = "wasm32"))]
pub(crate) fn create_404_response(url: &Url, body: &str) -> Response {
    use reqwest::ResponseBuilderExt;

    let http_response = http::Response::builder()
        .status(http::StatusCode::NOT_FOUND)
        .url(url.clone())
        .body(body.to_owned())
        .unwrap();
    Response::from(http_response)
}

/// wasm32 placeholder: constructing a `reqwest::Response` from parts is not
/// available here, so calling this panics via `todo!`.
#[cfg(target_arch = "wasm32")]
pub(crate) fn create_404_response(_url: &Url, _body: &str) -> Response {
    todo!("This is not implemented in reqwest, we need to contribute that.")
}
#[cfg(test)]
mod test {
    use std::{future::IntoFuture, net::SocketAddr};

    use axum::{extract::State, http::StatusCode, routing::get, Router};
    use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
    use url::Url;

    use super::Mirror;
    use crate::MirrorMiddleware;

    /// Handler that answers with the server name stored in the router state.
    async fn count(State(name): State<String>) -> String {
        format!("Hi from counter: {name}")
    }

    /// Handler that always fails with `500 Internal Server Error`.
    async fn broken_return() -> StatusCode {
        StatusCode::INTERNAL_SERVER_ERROR
    }

    /// Spawn a local HTTP server on an ephemeral port and return its base URL.
    ///
    /// The server serves `/count`. When `broken` is `true`, every request
    /// returns 500. Otherwise it answers `"Hi from counter: {name}"`.
    async fn test_server(name: &str, broken: bool) -> Url {
        let state = String::from(name);
        let router = if broken {
            Router::new().route("/count", get(broken_return))
        } else {
            Router::new().route("/count", get(count)).with_state(state)
        };

        let addr = SocketAddr::new([127, 0, 0, 1].into(), 0);
        let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
        let addr = listener.local_addr().unwrap();
        let service = router.into_make_service();
        tokio::spawn(axum::serve(listener, service).into_future());

        format!("http://{}:{}", addr.ip(), addr.port())
            .parse()
            .unwrap()
    }

    /// Mirror settings with all formats enabled and a failure limit of 3.
    fn mirror_setting(url: Url) -> Mirror {
        Mirror {
            url,
            no_zstd: false,
            no_bz2: false,
            max_failures: Some(3),
        }
    }

    #[tokio::test]
    async fn test_mirror_middleware() {
        let addr_1 = test_server("server 1", false).await;
        let addr_2 = test_server("server 2", false).await;

        let mut mirror_map = std::collections::HashMap::new();
        mirror_map.insert(
            "http://bla.com".parse().unwrap(),
            vec![mirror_setting(addr_1), mirror_setting(addr_2)],
        );

        let middleware = crate::MirrorMiddleware::from_map(mirror_map);
        let client = reqwest_middleware::ClientBuilder::new(reqwest::Client::new())
            .with(middleware)
            .build();

        let res = client.get("http://bla.com/count").send().await.unwrap();
        assert!(res.status().is_success(), "status: {}", res.status());
        // Both mirrors are healthy, so the first one in the list is selected.
        assert_eq!(res.text().await.unwrap(), "Hi from counter: server 1");
    }

    #[tokio::test]
    async fn test_mirror_middleware_broken() {
        let addr_1 = test_server("server 1", true).await;
        let addr_2 = test_server("server 2", false).await;

        let mut mirror_map = std::collections::HashMap::new();
        mirror_map.insert(
            "http://bla.com".parse().unwrap(),
            vec![mirror_setting(addr_1), mirror_setting(addr_2)],
        );

        let middleware = MirrorMiddleware::from_map(mirror_map.clone());
        let client = reqwest_middleware::ClientBuilder::new(reqwest::Client::new())
            .with(middleware)
            .build();

        // The first (broken) mirror is tried and its 500 is passed through.
        let res = client.get("http://bla.com/count").send().await.unwrap();
        assert!(res.status().is_server_error(), "status: {}", res.status());

        // That failure was recorded, so now only the second server is used.
        let res = client.get("http://bla.com/count").send().await.unwrap();
        assert!(res.status().is_success(), "status: {}", res.status());
        assert_eq!(res.text().await.unwrap(), "Hi from counter: server 2");

        // With a retry handler in front, the retried request lands on the
        // healthy mirror.
        let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
        let middleware = MirrorMiddleware::from_map(mirror_map);
        let client = reqwest_middleware::ClientBuilder::new(reqwest::Client::new())
            // retry middleware has to come before the mirror middleware
            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
            .with(middleware)
            .build();

        let res = client.get("http://bla.com/count").send().await.unwrap();
        assert!(res.status().is_success(), "status: {}", res.status());
        assert_eq!(res.text().await.unwrap(), "Hi from counter: server 2");
    }

    #[test]
    fn test_mirror_sort() {
        let keys: Vec<Url> = vec![
            "http://bla.com/abc/def".parse().unwrap(),
            "http://bla.com/abc".parse().unwrap(),
            "http://bla.com/abc/def/ghi".parse().unwrap(),
        ];
        let mirror_middleware =
            MirrorMiddleware::from_map(keys.into_iter().map(|k| (k, vec![])).collect());

        let lengths: Vec<usize> = mirror_middleware
            .keys()
            .iter()
            .map(|(key, _)| key.len())
            .collect();
        assert_eq!(lengths.len(), 3);
        assert!(
            lengths.windows(2).all(|pair| pair[0] >= pair[1]),
            "keys are not sorted longest first: {lengths:?}"
        );
    }

    #[tokio::test]
    async fn test_mirror_middleware_path_rewrite() {
        // Start a server that serves at /channel/count
        let state = String::from("mirror server");
        let router = Router::new()
            .route("/channel/count", get(count))
            .with_state(state);
        let addr = SocketAddr::new([127, 0, 0, 1].into(), 0);
        let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(axum::serve(listener, router.into_make_service()).into_future());

        let mirror_url: Url = format!("http://{}:{}/channel", addr.ip(), addr.port())
            .parse()
            .unwrap();

        let mut mirror_map = std::collections::HashMap::new();
        // The upstream key has a path segment (e.g. conda-forge), and so does
        // the mirror URL (e.g. channel). The mirror path must fully replace the
        // upstream path.
        mirror_map.insert(
            "https://prefix.dev/conda-forge".parse().unwrap(),
            vec![mirror_setting(mirror_url)],
        );

        let middleware = MirrorMiddleware::from_map(mirror_map);
        let client = reqwest_middleware::ClientBuilder::new(reqwest::Client::new())
            .with(middleware)
            .build();

        // Request to upstream: https://prefix.dev/conda-forge/count
        // Should be rewritten to: http://127.0.0.1:PORT/channel/count
        let res = client
            .get("https://prefix.dev/conda-forge/count")
            .send()
            .await
            .unwrap();
        assert!(res.status().is_success(), "status: {}", res.status());
        assert_eq!(res.text().await.unwrap(), "Hi from counter: mirror server");
    }
}