Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,6 @@ tracing = "0.1.41"
tracing-subscriber = "0.3.20"

[dev-dependencies]
http-body-util = "0.1.3"
httpmock = "0.8.2"
tower = "0.5.2"
10 changes: 7 additions & 3 deletions src/clients.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::borrow::Cow;
use std::fmt;

use axum::http::HeaderMap;
#[cfg(test)]
use httpmock::MockServer;
use reqwest::header::HeaderMap;
use reqwest::{Client, Url};
use serde::de::DeserializeOwned;
use tracing::{debug, info, instrument};
Expand Down Expand Up @@ -58,17 +58,19 @@ impl TiledClient {
pub async fn search(
&self,
path: &str,
headers: Option<HeaderMap>,
query: &[(&str, Cow<'_, str>)],
) -> ClientResult<node::Root> {
self.request(&format!("api/v1/search/{}", path), None, Some(query))
self.request(&format!("api/v1/search/{}", path), headers, Some(query))
.await
}
pub async fn table_full(
&self,
path: &str,
columns: Option<Vec<String>>,
headers: Option<HeaderMap>,
) -> ClientResult<table::Table> {
let mut headers = HeaderMap::new();
let mut headers = headers.unwrap_or_default();
headers.insert("accept", "application/json".parse().unwrap());
let query = columns.map(|columns| {
columns
Expand All @@ -91,6 +93,7 @@ impl TiledClient {
stream: String,
det: String,
id: u32,
headers: Option<HeaderMap>,
) -> reqwest::Result<reqwest::Response> {
let mut url = self
.address
Expand All @@ -105,6 +108,7 @@ impl TiledClient {
debug!("Downloading id={id} from {url}");
self.client
.get(url)
.headers(headers.unwrap_or_default())
.query(&[("id", &id.to_string())])
.send()
.await
Expand Down
106 changes: 100 additions & 6 deletions src/handlers.rs
Original file line number Diff line number Diff line change
@@ -1,33 +1,127 @@
use async_graphql::http::GraphiQLSource;
use async_graphql::*;
use async_graphql::{EmptyMutation, EmptySubscription, Schema};
use async_graphql_axum::{GraphQLRequest, GraphQLResponse};
use axum::Extension;
use axum::body::Body;
use axum::extract::{Path, State};
use axum::http::{HeaderMap, StatusCode};
use axum::extract::{OptionalFromRequestParts, Path, State};
use axum::http::{HeaderMap, HeaderValue, StatusCode};
use axum::response::{Html, IntoResponse};
use reqwest::header::AUTHORIZATION;
use tracing::info;

use crate::clients::TiledClient;
use crate::model::TiledQuery;

pub async fn graphql_handler(
auth_token: Option<AuthHeader>,
schema: Extension<Schema<TiledQuery, EmptyMutation, EmptySubscription>>,
req: GraphQLRequest,
) -> GraphQLResponse {
let query = req.into_inner().query;
schema.execute(query).await.into()
schema
.execute(req.into_inner().data(auth_token))
.await
.into()
}

pub async fn graphiql_handler() -> impl IntoResponse {
Html(GraphiQLSource::build().endpoint("/graphql").finish())
}

pub async fn download_handler(
auth: Option<AuthHeader>,
State(client): State<TiledClient>,
Path((run, stream, det, id)): Path<(String, String, String, u32)>,
) -> (StatusCode, HeaderMap, Body) {
info!("Downloading {run}/{stream}/{det}/{id}");
let req = client.download(run, stream, det, id).await;
let headers = auth.as_ref().map(AuthHeader::as_header_map);
let req = client.download(run, stream, det, id, headers).await;
crate::download::forward_download_response(req).await
}

/// Extractor to accept an un-typed Authorization header (can be ApiKey/Bearer/Basic etc), and
/// make it accessible as a HeaderValue to be forwarded rather than extracted into something to use
/// locally (as the TypedHeader equivalent does).
pub struct AuthHeader(HeaderValue);

impl AuthHeader {
pub fn as_header_map(&self) -> HeaderMap {
[(AUTHORIZATION, self.0.clone())].into_iter().collect()
}
}

#[cfg(test)]
impl From<HeaderValue> for AuthHeader {
fn from(value: HeaderValue) -> Self {
Self(value)
}
}

impl<S> OptionalFromRequestParts<S> for AuthHeader
where
S: Send + Sync,
{
type Rejection = ();

async fn from_request_parts(
parts: &mut axum::http::request::Parts,
_state: &S,
) -> Result<Option<Self>, Self::Rejection> {
Ok(parts
.headers
.get("Authorization")
.map(|value| Self(value.clone())))
}
}

#[cfg(test)]
mod tests {
use axum::Router;
use axum::body::Body;
use axum::http::Request;
use axum::response::IntoResponse;
use axum::routing::get;
use http_body_util::BodyExt as _;
use tower::ServiceExt;

use super::AuthHeader;

async fn auth_echo(auth: Option<AuthHeader>) -> impl IntoResponse {
match auth {
Some(header) => header.0.to_str().unwrap().to_owned(),
None => "No auth".to_owned(),
}
}
fn app() -> Router {
Router::new().route("/", get(auth_echo))
}
#[tokio::test]
async fn auth_extract() {
let app = app();
let response = app
.oneshot(
Request::builder()
.uri("/")
.header("Authorization", "auth_value")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(
response.into_body().collect().await.unwrap().to_bytes(),
"auth_value"
);
}
#[tokio::test]
async fn no_auth_extract() {
let app = app();
let response = app
.oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
.await
.unwrap();
assert_eq!(
response.into_body().collect().await.unwrap().to_bytes(),
"No auth"
);
}
}
52 changes: 50 additions & 2 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use serde_json::Value;
use tracing::{info, instrument};

use crate::clients::TiledClient;
use crate::handlers::AuthHeader;
use crate::model::node::NodeAttributes;

pub(crate) struct TiledQuery;
Expand All @@ -40,10 +41,13 @@ impl InstrumentSession {
&self.name
}
async fn runs(&self, ctx: &Context<'_>) -> Result<Vec<Run>> {
let auth = ctx.data::<Option<AuthHeader>>()?;
let headers = auth.as_ref().map(AuthHeader::as_header_map);
let root = ctx
.data::<TiledClient>()?
.search(
"",
headers,
&[
(
"filter[eq][condition][key]",
Expand Down Expand Up @@ -132,6 +136,8 @@ impl TableData {
ctx: &Context<'_>,
columns: Option<Vec<String>>,
) -> Result<HashMap<String, Vec<Value>>> {
let auth = ctx.data::<Option<AuthHeader>>()?;
let headers = auth.as_ref().map(AuthHeader::as_header_map);
let client = ctx.data::<TiledClient>()?;
let p = self
.attrs
Expand All @@ -143,7 +149,7 @@ impl TableData {
.join("/");
info!("path: {:?}", p);

let table_data = client.table_full(&p, columns).await?;
let table_data = client.table_full(&p, columns, headers).await?;
Ok(table_data)
}
}
Expand All @@ -165,15 +171,22 @@ impl Run {
&self.data.id
}
async fn data(&self, ctx: &Context<'_>) -> Result<Vec<RunData<'_>>> {
let auth = ctx.data::<Option<AuthHeader>>()?;
let headers = auth.as_ref().map(AuthHeader::as_header_map);
let client = ctx.data::<TiledClient>()?;
let run_data = client
.search(&self.data.id, &[("include_data_sources", "true".into())])
.search(
&self.data.id,
headers.clone(),
&[("include_data_sources", "true".into())],
)
.await?;
let mut sources = Vec::new();
for stream in run_data.data {
let stream_data = client
.search(
&format!("{}/{}", self.data.id, stream.id),
headers.clone(),
&[("include_data_sources", "true".into())],
)
.await?;
Expand All @@ -200,10 +213,13 @@ impl Run {
#[cfg(test)]
mod tests {
use async_graphql::{EmptyMutation, EmptySubscription, Schema, value};
use axum::http::HeaderValue;
use httpmock::MockServer;
use serde_json::json;

use crate::TiledQuery;
use crate::clients::TiledClient;
use crate::handlers::AuthHeader;

fn build_schema(url: &str) -> Schema<TiledQuery, EmptyMutation, EmptySubscription> {
Schema::build(TiledQuery, EmptyMutation, EmptySubscription)
Expand All @@ -228,4 +244,36 @@ mod tests {
assert_eq!(response.errors, &[]);
mock.assert();
}

#[tokio::test]
async fn auth_forwarding() {
let server = MockServer::start();
let mock_instrument_session = server
.mock_async(|when, then| {
when.method("GET")
.path("/api/v1/search/")
.query_param("filter[eq][condition][key]", "start.instrument_session")
.query_param("filter[eq][condition][value]", r#""cm12345-6""#)
.header("Authorization", "auth_value");
then.status(200).json_body(json!({
"data": [],
"error": null,
"links": {"self":""},
"meta": {}
}));
})
.await;
let schema = Schema::build(TiledQuery, EmptyMutation, EmptySubscription)
.data(TiledClient::new(server.base_url().parse().unwrap()))
.data(Some(AuthHeader::from(HeaderValue::from_static(
"auth_value",
))))
.finish();
let response = schema
.execute(r#"{ instrumentSession(name: "cm12345-6"){ runs { id }}}"#)
.await;
assert_eq!(response.errors, &[]);
assert_eq!(response.data, value!({"instrumentSession": {"runs": []}}));
mock_instrument_session.assert();
}
}
Loading