diff --git a/src/clients.rs b/src/clients.rs index 3c7c201b..8f626422 100644 --- a/src/clients.rs +++ b/src/clients.rs @@ -254,7 +254,11 @@ pub async fn create_http_client( .layer(http_trace_layer()) .layer(TimeoutLayer::new(request_timeout)) .service(client); - Ok(HttpClient::new(base_url, client)) + Ok(HttpClient::new( + base_url, + service_config.api_token.clone(), + client, + )) } pub async fn create_grpc_client( diff --git a/src/clients/http.rs b/src/clients/http.rs index 90a6c6ad..e962c578 100644 --- a/src/clients/http.rs +++ b/src/clients/http.rs @@ -119,15 +119,17 @@ pub trait HttpClientExt: Client { pub struct HttpClient { base_url: Url, health_url: Url, + api_token: Option, inner: HttpClientInner, } impl HttpClient { - pub fn new(base_url: Url, inner: HttpClientInner) -> Self { + pub fn new(base_url: Url, api_token: Option, inner: HttpClientInner) -> Self { let health_url = base_url.join("health").unwrap(); Self { base_url, health_url, + api_token, inner, } } @@ -140,6 +142,19 @@ impl HttpClient { self.base_url.join(path).unwrap() } + /// Injects the API token as a Bearer token in the Authorization header if configured and present in the environment. + fn inject_api_token(&self, headers: &mut HeaderMap) -> Result<(), Error> { + if let Some(token) = &self.api_token { + headers.insert( + http::header::AUTHORIZATION, + HeaderValue::from_str(&format!("Bearer {}", token)).map_err(|e| Error::Http { + code: StatusCode::INTERNAL_SERVER_ERROR, + message: format!("invalid authorization header: {e}"), + })?, + ); + } + Ok(()) + } pub async fn get( &self, url: Url, @@ -162,11 +177,14 @@ impl HttpClient { &self, url: Url, method: Method, - headers: HeaderMap, + mut headers: HeaderMap, body: impl RequestBody, ) -> Result { let ctx = Span::current().context(); + + self.inject_api_token(&mut headers)?; let headers = trace::with_traceparent_header(&ctx, headers); + let mut builder = hyper::http::request::Builder::new() .method(method) .uri(url.as_uri()); diff --git a/src/config.rs b/src/config.rs index bc8a20d5..956c5d8d 100644 --- a/src/config.rs +++ b/src/config.rs @@ -23,7 +23,10 @@ use std::{ use serde::Deserialize; use tracing::{debug, error, info, warn}; -use crate::clients::{chunker::DEFAULT_CHUNKER_ID, is_valid_hostname}; +use crate::{ + clients::{chunker::DEFAULT_CHUNKER_ID, is_valid_hostname}, + utils::from_env, +}; /// Default allowed headers to passthrough to clients. const DEFAULT_ALLOWED_HEADERS: &[&str] = &[]; @@ -86,6 +89,9 @@ pub struct ServiceConfig { pub http2_keep_alive_interval: Option, /// Keep-alive timeout in seconds for client calls [currently only for grpc generation] pub keep_alive_timeout: Option, + /// Name of environment variable that contains the API key to use for this service [currently only for http generation] + #[serde(default, deserialize_with = "from_env")] + pub api_token: Option, } impl ServiceConfig { @@ -101,6 +107,7 @@ impl ServiceConfig { max_retries: None, http2_keep_alive_interval: None, keep_alive_timeout: None, + api_token: None, } } } diff --git a/src/utils.rs b/src/utils.rs index 02d2b623..65796eb6 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -34,3 +34,56 @@ where OneOrMany::Many(values) => Ok(values), } } + +/// Serde helper to deserialize value from environment variable. +pub fn from_env<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + let env_name: Option = Option::deserialize(deserializer)?; + if let Some(env_name) = env_name { + let value = std::env::var(&env_name) + .map_err(|_| serde::de::Error::custom(format!("env var `{env_name}` not found")))?; + Ok(Some(value)) + } else { + Ok(None) + } +} + +#[cfg(test)] +mod tests { + use serde::Deserialize; + use serde_json::json; + + use super::from_env; + + #[derive(Debug, Deserialize)] + pub struct Config { + #[serde(default, deserialize_with = "from_env")] + pub api_token: Option, + } + + #[test] + fn test_from_env() -> Result<(), Box> { + // Test no value + let config: Config = serde_json::from_value(json!({}))?; + assert_eq!(config.api_token, None); + + // Test invalid value + let config: Result = serde_json::from_value(json!({ + "api_token": "DOES_NOT_EXIST" + })); + assert!(config.is_err_and(|err| err.to_string() == "env var `DOES_NOT_EXIST` not found")); + + // Test valid value + unsafe { + std::env::set_var("CLIENT_API_TOKEN", "token"); + } + let config: Config = serde_json::from_value(json!({ + "api_token": "CLIENT_API_TOKEN" + }))?; + assert_eq!(config.api_token, Some("token".into())); + + Ok(()) + } +}