Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/clients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,11 @@ pub async fn create_http_client(
.layer(http_trace_layer())
.layer(TimeoutLayer::new(request_timeout))
.service(client);
Ok(HttpClient::new(base_url, client))
Ok(HttpClient::new(
base_url,
service_config.api_token.clone(),
client,
))
}

pub async fn create_grpc_client<C: Debug + Clone>(
Expand Down
22 changes: 20 additions & 2 deletions src/clients/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,17 @@ pub trait HttpClientExt: Client {
pub struct HttpClient {
base_url: Url,
health_url: Url,
api_token: Option<String>,
inner: HttpClientInner,
}

impl HttpClient {
pub fn new(base_url: Url, inner: HttpClientInner) -> Self {
pub fn new(base_url: Url, api_token: Option<String>, inner: HttpClientInner) -> Self {
let health_url = base_url.join("health").unwrap();
Self {
base_url,
health_url,
api_token,
inner,
}
}
Expand All @@ -140,6 +142,19 @@ impl HttpClient {
self.base_url.join(path).unwrap()
}

/// Injects the API token as a Bearer token in the Authorization header if configured and present in the environment.
fn inject_api_token(&self, headers: &mut HeaderMap) -> Result<(), Error> {
if let Some(token) = &self.api_token {
headers.insert(
http::header::AUTHORIZATION,
HeaderValue::from_str(&format!("Bearer {}", token)).map_err(|e| Error::Http {
code: StatusCode::INTERNAL_SERVER_ERROR,
message: format!("invalid authorization header: {e}"),
})?,
);
}
Ok(())
}
pub async fn get(
&self,
url: Url,
Expand All @@ -162,11 +177,14 @@ impl HttpClient {
&self,
url: Url,
method: Method,
headers: HeaderMap,
mut headers: HeaderMap,
body: impl RequestBody,
) -> Result<Response, Error> {
let ctx = Span::current().context();

self.inject_api_token(&mut headers)?;
let headers = trace::with_traceparent_header(&ctx, headers);

let mut builder = hyper::http::request::Builder::new()
.method(method)
.uri(url.as_uri());
Expand Down
9 changes: 8 additions & 1 deletion src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ use std::{
use serde::Deserialize;
use tracing::{debug, error, info, warn};

use crate::clients::{chunker::DEFAULT_CHUNKER_ID, is_valid_hostname};
use crate::{
clients::{chunker::DEFAULT_CHUNKER_ID, is_valid_hostname},
utils::from_env,
};

/// Default allowed headers to passthrough to clients.
const DEFAULT_ALLOWED_HEADERS: &[&str] = &[];
Expand Down Expand Up @@ -86,6 +89,9 @@ pub struct ServiceConfig {
pub http2_keep_alive_interval: Option<u64>,
/// Keep-alive timeout in seconds for client calls [currently only for grpc generation]
pub keep_alive_timeout: Option<u64>,
/// Name of environment variable that contains the API key to use for this service [currently only for http generation]
#[serde(default, deserialize_with = "from_env")]
pub api_token: Option<String>,
}

impl ServiceConfig {
Expand All @@ -101,6 +107,7 @@ impl ServiceConfig {
max_retries: None,
http2_keep_alive_interval: None,
keep_alive_timeout: None,
api_token: None,
}
}
}
Expand Down
53 changes: 53 additions & 0 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,56 @@ where
OneOrMany::Many(values) => Ok(values),
}
}

/// Serde helper to deserialize value from environment variable.
pub fn from_env<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
where
D: Deserializer<'de>,
{
let env_name: Option<String> = Option::deserialize(deserializer)?;
if let Some(env_name) = env_name {
let value = std::env::var(&env_name)
.map_err(|_| serde::de::Error::custom(format!("env var `{env_name}` not found")))?;
Ok(Some(value))
} else {
Ok(None)
}
}

#[cfg(test)]
mod tests {
use serde::Deserialize;
use serde_json::json;

use super::from_env;

#[derive(Debug, Deserialize)]
pub struct Config {
#[serde(default, deserialize_with = "from_env")]
pub api_token: Option<String>,
}

#[test]
fn test_from_env() -> Result<(), Box<dyn std::error::Error>> {
// Test no value
let config: Config = serde_json::from_value(json!({}))?;
assert_eq!(config.api_token, None);

// Test invalid value
let config: Result<Config, serde_json::error::Error> = serde_json::from_value(json!({
"api_token": "DOES_NOT_EXIST"
}));
assert!(config.is_err_and(|err| err.to_string() == "env var `DOES_NOT_EXIST` not found"));

// Test valid value
unsafe {
std::env::set_var("CLIENT_API_TOKEN", "token");
}
let config: Config = serde_json::from_value(json!({
"api_token": "CLIENT_API_TOKEN"
}))?;
assert_eq!(config.api_token, Some("token".into()));

Ok(())
}
}