@@ -6,20 +6,38 @@ pub const N_BITS: usize = 25600;
66pub mod mix_node {
77 use super :: * ;
88
9- use axum:: extract:: DefaultBodyLimit ;
9+ use axum:: extract:: { DefaultBodyLimit , Request , State } ;
10+ use axum:: middleware:: { self , Next } ;
11+ use axum:: response:: { IntoResponse , Response } ;
1012 use axum:: { response:: Json , routing:: post, Router } ;
1113
12- use axum:: http:: StatusCode ;
14+ use axum:: http:: { HeaderMap , StatusCode } ;
1315 use rand:: { rngs:: StdRng , SeedableRng } ;
1416 use rust_elgamal:: { Ciphertext , EncryptionKey , RistrettoPoint } ;
1517 use serde:: { Deserialize , Deserializer , Serialize } ;
1618 use std:: sync:: OnceLock ;
1719
18- pub fn app ( ) -> Router {
20+ #[ derive( Debug , Clone ) ]
21+ pub struct AppState {
22+ auth_token : Option < String > ,
23+ }
24+
25+ impl AppState {
26+ pub fn new ( auth_token : Option < String > ) -> Self {
27+ Self { auth_token }
28+ }
29+ }
30+
31+ pub fn app ( state : AppState ) -> Router {
1932 Router :: new ( )
2033 . route ( "/remix" , post ( remix_handler) )
34+ . layer ( middleware:: from_fn_with_state (
35+ state. clone ( ) ,
36+ auth_middleware,
37+ ) )
2138 // TODO: for security reasons set max instead of disabling (measured payload was ~11MB)
2239 . layer ( DefaultBodyLimit :: disable ( ) )
40+ . with_state ( state)
2341 }
2442
2543 #[ derive( Debug , Serialize , Deserialize ) ]
@@ -51,6 +69,25 @@ pub mod mix_node {
5169 Ok ( Json ( codes) )
5270 }
5371
72+ async fn auth_middleware (
73+ State ( AppState { auth_token } ) : State < AppState > ,
74+ headers : HeaderMap ,
75+ request : Request ,
76+ next : Next ,
77+ ) -> Response {
78+ let next_run = async { next. run ( request) . await } ;
79+
80+ match ( auth_token, headers. get ( "X-AUTH-TOKEN" ) ) {
81+ // AUTH_TOKEN is set on the server and in the request header so we check
82+ ( Some ( auth_token) , Some ( header_auth_token) ) if auth_token == * header_auth_token => {
83+ next_run. await
84+ }
85+ // AUTH_TOKEN is not set on the server so we disable auth
86+ ( None , _) => next_run. await ,
87+ _ => ( StatusCode :: UNAUTHORIZED , "Unauthorized" ) . into_response ( ) ,
88+ }
89+ }
90+
5491 fn enc_key ( ) -> & ' static EncryptionKey {
5592 // TODO: remove hardcoded encryption key from a fixed seed
5693 static ENC_KEY : OnceLock < EncryptionKey > = OnceLock :: new ( ) ;
0 commit comments