Skip to content

Commit 85a2dd6

Browse files
committed
Add auth middleware
1 parent 5771c60 commit 85a2dd6

File tree

6 files changed

+91
-13
lines changed

6 files changed

+91
-13
lines changed

lambda/benches/mix_node.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ fn bench_requests(c: &mut Criterion) {
7878
let enc_key = EncryptionKey::from(&Scalar::random(&mut rng) * &GENERATOR_TABLE);
7979

8080
let rt = tokio::runtime::Runtime::new().unwrap();
81-
let test_app = rt.block_on(async { testing::create_app().await });
81+
let test_app = rt.block_on(async { testing::create_app(None).await });
8282
let client = Arc::new(reqwest::Client::new());
8383

8484
let payload = Arc::new(EncryptedCodes {

lambda/src/bin/mix_node.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use lambda::mix_node;
1+
use lambda::mix_node::{self, AppState};
22

33
use mimalloc::MiMalloc as GlobalAllocator;
44

@@ -17,5 +17,6 @@ async fn main() -> Result<(), lambda_http::Error> {
1717

1818
lambda_http::tracing::init_default_subscriber();
1919

20-
lambda_http::run(mix_node::app()).await
20+
let state = AppState::new(std::env::var("AUTH_TOKEN").ok());
21+
lambda_http::run(mix_node::app(state)).await
2122
}

lambda/src/bin/simple_server.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use lambda::mix_node;
1+
use lambda::mix_node::{self, AppState};
22

33
use mimalloc::MiMalloc as GlobalAllocator;
44

@@ -11,6 +11,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
1111
let port = listener.local_addr().unwrap().port();
1212

1313
println!("Listening on http://localhost:{port}...");
14-
axum::serve(listener, mix_node::app()).await.unwrap();
14+
let state = AppState::new(std::env::var("AUTH_TOKEN").ok());
15+
axum::serve(listener, mix_node::app(state)).await.unwrap();
1516
Ok(())
1617
}

lambda/src/lib.rs

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,38 @@ pub const N_BITS: usize = 25600;
66
pub mod mix_node {
77
use super::*;
88

9-
use axum::extract::DefaultBodyLimit;
9+
use axum::extract::{DefaultBodyLimit, Request, State};
10+
use axum::middleware::{self, Next};
11+
use axum::response::{IntoResponse, Response};
1012
use axum::{response::Json, routing::post, Router};
1113

12-
use axum::http::StatusCode;
14+
use axum::http::{HeaderMap, StatusCode};
1315
use rand::{rngs::StdRng, SeedableRng};
1416
use rust_elgamal::{Ciphertext, EncryptionKey, RistrettoPoint};
1517
use serde::{Deserialize, Deserializer, Serialize};
1618
use std::sync::OnceLock;
1719

18-
pub fn app() -> Router {
20+
#[derive(Debug, Clone)]
21+
pub struct AppState {
22+
auth_token: Option<String>,
23+
}
24+
25+
impl AppState {
26+
pub fn new(auth_token: Option<String>) -> Self {
27+
Self { auth_token }
28+
}
29+
}
30+
31+
pub fn app(state: AppState) -> Router {
1932
Router::new()
2033
.route("/remix", post(remix_handler))
34+
.layer(middleware::from_fn_with_state(
35+
state.clone(),
36+
auth_middleware,
37+
))
2138
// TODO: for security reasons set max instead of disabling (measured payload was ~11MB)
2239
.layer(DefaultBodyLimit::disable())
40+
.with_state(state)
2341
}
2442

2543
#[derive(Debug, Serialize, Deserialize)]
@@ -51,6 +69,25 @@ pub mod mix_node {
5169
Ok(Json(codes))
5270
}
5371

72+
async fn auth_middleware(
73+
State(AppState { auth_token }): State<AppState>,
74+
headers: HeaderMap,
75+
request: Request,
76+
next: Next,
77+
) -> Response {
78+
let next_run = async { next.run(request).await };
79+
80+
match (auth_token, headers.get("X-AUTH-TOKEN")) {
81+
// AUTH_TOKEN is set on the server and in the request header so we check
82+
(Some(auth_token), Some(header_auth_token)) if auth_token == *header_auth_token => {
83+
next_run.await
84+
}
85+
// AUTH_TOKEN is not set on the server so we disable auth
86+
(None, _) => next_run.await,
87+
_ => (StatusCode::UNAUTHORIZED, "Unauthorized").into_response(),
88+
}
89+
}
90+
5491
fn enc_key() -> &'static EncryptionKey {
5592
// TODO: remove hardcoded encryption key from a fixed seed
5693
static ENC_KEY: OnceLock<EncryptionKey> = OnceLock::new();

lambda/src/testing.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
use tokio::task::JoinHandle;
22

3-
use crate::mix_node;
3+
use crate::mix_node::{self, AppState};
44

55
pub struct TestApp {
66
pub port: u16,
77
pub join_handle: JoinHandle<()>,
88
}
99

10-
pub async fn create_app() -> TestApp {
10+
pub async fn create_app(auth_token: Option<String>) -> TestApp {
1111
let listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
1212
let port = listener.local_addr().unwrap().port();
1313

1414
let join_handle = tokio::spawn(async move {
15-
axum::serve(listener, mix_node::app()).await.unwrap();
15+
let state = AppState::new(auth_token);
16+
axum::serve(listener, mix_node::app(state)).await.unwrap();
1617
});
1718

1819
TestApp { port, join_handle }

lambda/tests/mix_node_integration.rs

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ fn set_up_payload() -> (EncryptedCodes, DecryptionKey) {
4343

4444
#[tokio::test]
4545
async fn test_mix_node() -> Result<(), Box<dyn Error>> {
46-
let TestApp { port, .. } = testing::create_app().await;
46+
let TestApp { port, .. } = testing::create_app(None).await;
4747

4848
let (codes, dec_key) = set_up_payload();
4949

@@ -78,7 +78,7 @@ async fn test_mix_node() -> Result<(), Box<dyn Error>> {
7878

7979
#[tokio::test]
8080
async fn test_mix_node_bad_request() -> Result<(), Box<dyn Error>> {
81-
let TestApp { port, .. } = testing::create_app().await;
81+
let TestApp { port, .. } = testing::create_app(None).await;
8282

8383
let (mut codes, _dec_key) = set_up_payload();
8484
// Remove elements to cause a size mismatch
@@ -98,6 +98,44 @@ async fn test_mix_node_bad_request() -> Result<(), Box<dyn Error>> {
9898
Ok(())
9999
}
100100

101+
#[tokio::test]
102+
async fn test_mix_node_unauthorized() -> Result<(), Box<dyn Error>> {
103+
let TestApp { port, .. } =
104+
testing::create_app(Some("test_mix_node_unauthorized".to_string())).await;
105+
106+
// Bad request + Serialization
107+
let client = reqwest::Client::new();
108+
let response = client
109+
.post(format!("http://localhost:{port}/remix"))
110+
.send()
111+
.await?;
112+
113+
// Assert
114+
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
115+
Ok(())
116+
}
117+
118+
#[tokio::test]
119+
async fn test_mix_node_authorized() -> Result<(), Box<dyn Error>> {
120+
let auth_token = "test_mix_node_authorized";
121+
let TestApp { port, .. } = testing::create_app(Some(auth_token.to_string())).await;
122+
123+
let (codes, _dec_key) = set_up_payload();
124+
125+
// Bad request + Serialization
126+
let client = reqwest::Client::new();
127+
let response = client
128+
.post(format!("http://localhost:{port}/remix"))
129+
.header("X-AUTH-TOKEN", auth_token)
130+
.json(&codes)
131+
.send()
132+
.await?;
133+
134+
// Assert
135+
assert_eq!(response.status(), StatusCode::OK);
136+
Ok(())
137+
}
138+
101139
#[test]
102140
fn test_encode_bits() {
103141
let bits = BitVec::<u8, Msb0>::from_slice(&[0b11100100]);

0 commit comments

Comments
 (0)