Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/infrastructure/http/extractors/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pub mod request_origin;
pub mod user_agent;
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,18 @@ where
type Rejection = std::convert::Infallible;

async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let headers = HeaderMap::from_request_parts(parts, state)
.await
.expect("HeaderMap extractor should never fail");
let ConnectInfo(addr) = ConnectInfo::<SocketAddr>::from_request_parts(parts, state)
let headers = HeaderMap::from_request_parts(parts, state).await?;

let addr = ConnectInfo::<SocketAddr>::from_request_parts(parts, state)
.await
.expect("ConnectInfo extractor should never fail");
.ok()
.map(|ConnectInfo(addr)| addr);

let ip = maybe_x_forwarded_for(&headers)
.or_else(|| maybe_x_real_ip(&headers))
.unwrap_or_else(|| addr.ip());
.or_else(|| addr.map(|a| a.ip()))
.unwrap_or(IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));

Ok(RequestOrigin(ip))
}
}
Expand Down
62 changes: 62 additions & 0 deletions src/infrastructure/http/extractors/user_agent.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
use axum::{
extract::FromRequestParts,
http::{HeaderMap, request::Parts},
};
/// Axum extractor for User-Agent header
///
/// Extracts the User-Agent header from HTTP requests.
/// Returns None if the header is missing or cannot be parsed as a valid string.
pub struct UserAgent(pub Option<String>);

impl<S> FromRequestParts<S> for UserAgent
where
S: Send + Sync,
{
type Rejection = std::convert::Infallible;

async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let headers = HeaderMap::from_request_parts(parts, state).await?;

let user_agent = headers
.get("user-agent")
.and_then(|hv| hv.to_str().ok())
.map(|s| s.to_string());

Ok(UserAgent(user_agent))
}
}

#[cfg(test)]
mod tests {
use super::*;
use axum::{extract::Request, http::HeaderMap};

// Helper function to create a mock request with headers
fn make_request(headers: HeaderMap) -> Request<()> {
let mut request = Request::builder().body(()).unwrap();
*request.headers_mut() = headers;
request
}

#[tokio::test]
async fn test_user_agent_present() {
let mut headers = HeaderMap::new();
headers.insert(
"user-agent",
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36"
.parse()
.unwrap(),
);

let request = make_request(headers);
let (mut parts, _) = request.into_parts();
let UserAgent(result) = UserAgent::from_request_parts(&mut parts, &())
.await
.unwrap();

assert_eq!(
result,
Some("Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36".to_string())
);
}
}
2 changes: 1 addition & 1 deletion src/infrastructure/http/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ pub mod extractors;
mod server;

pub use error::HttpServerError;
pub use extractors::RequestOrigin;
pub use extractors::{request_origin::RequestOrigin, user_agent::UserAgent};
pub use server::HttpServer;
5 changes: 3 additions & 2 deletions src/sms_verification/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::sms_verification::{
};
use crate::{
EnvConfig,
infrastructure::http::{HttpServerError, RequestOrigin},
infrastructure::http::{HttpServerError, RequestOrigin, UserAgent},
sms_verification::app_state::AppState,
};

Expand Down Expand Up @@ -42,11 +42,12 @@ pub async fn router_with_db(
async fn send_code_handler(
State(mut state): State<AppState>,
RequestOrigin(ip_address): RequestOrigin,
UserAgent(user_agent): UserAgent,
Json(request): Json<CreateVerificationRequest>,
) -> Result<StatusCode, SmsVerificationError> {
state
.sms_verification
.create_verification(&state.db, request, ip_address)
.create_verification(&state.db, request, ip_address, user_agent)
.await?;
Ok(StatusCode::OK)
}
Expand Down
18 changes: 13 additions & 5 deletions src/sms_verification/prelude_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,16 @@ struct Target {
#[derive(Serialize)]
struct Signals {
ip_address: String,
#[serde(skip_serializing_if = "Option::is_none")]
user_agent: Option<String>,
}

#[derive(Serialize)]
struct PreludeCreateVerificationRequest {
target: Target,
signals: Signals,
#[serde(skip_serializing_if = "Option::is_none")]
signals: Option<Signals>,
dispatch_id: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
Expand Down Expand Up @@ -215,16 +219,20 @@ impl PreludeAPI {
pub async fn create_verification(
&self,
phone_number: &PhoneNumber,
ip_address: Option<IpAddr>,
ip_address: IpAddr,
user_agent: Option<String>,
dispatch_id: Option<String>,
) -> Result<PreludeCreateVerificationResponse, PreludeError> {
let request_body = PreludeCreateVerificationRequest {
target: Target {
target_type: "phone_number".to_string(),
value: phone_number.to_string(),
},
signals: ip_address.map(|ip| Signals {
ip_address: ip.to_string(),
}),
signals: Signals {
ip_address: ip_address.to_string(),
user_agent,
},
dispatch_id,
};

let url = self
Expand Down
8 changes: 7 additions & 1 deletion src/sms_verification/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ impl SmsVerificationService {
db: &SqlDb,
request: CreateVerificationRequest,
ip_address: IpAddr,
user_agent: Option<String>,
) -> Result<(), SmsVerificationError> {
let phone_number_hash = self
.hasher_argon2id
Expand All @@ -94,7 +95,12 @@ impl SmsVerificationService {

let prelude_response = self
.prelude_api
.create_verification(&request.phone_number, Some(ip_address))
.create_verification(
&request.phone_number,
ip_address,
user_agent,
request.dispatch_id,
)
.await?;

let id = match &prelude_response {
Expand Down
Loading
Loading