Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 104 additions & 13 deletions crates/utils/src/net.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use crate::fs::{self, FsError};
use async_trait::async_trait;
use reqwest::{Client, Response};
use reqwest::{
Client, Response,
header::{HeaderMap, HeaderName, HeaderValue},
};
use std::collections::HashMap;
use std::cmp;
use std::fmt::Debug;
use std::io::Write;
Expand All @@ -26,19 +30,59 @@ pub type BoxedDownloader = Box<dyn Downloader>;
#[derive(Default)]
pub struct DefaultDownloader {
client: reqwest::Client,
headers: Option<HeaderMap>,
}

fn headers_from_map(map: &HashMap<String, String>) -> Result<HeaderMap, NetError> {
let mut headers = HeaderMap::new();
for (k, v) in map {
let name = HeaderName::try_from(k.as_str()).map_err(|e| NetError::HttpUnknown {
url: String::new(),
error: format!("invalid header name {:?}: {}", k, e),
})?;
let value = HeaderValue::try_from(v.as_str()).map_err(|e| NetError::HttpUnknown {
url: String::new(),
error: format!("invalid header value {:?}: {}", v, e),
})?;
headers.insert(name, value);
}
Ok(headers)
}

impl DefaultDownloader {
/// Create a downloader that applies the given headers to each request.
pub fn new_with_headers(headers: HashMap<String, String>) -> Result<Self, NetError> {
Ok(Self {
client: Client::new(),
headers: Some(headers_from_map(&headers)?),
})
}

/// Create a downloader with a custom client and headers.
pub fn new_with_client_and_headers(
client: Client,
headers: HashMap<String, String>,
) -> Result<Self, NetError> {
Ok(Self {
client,
headers: Some(headers_from_map(&headers)?),
})
}
}

#[async_trait]
impl Downloader for DefaultDownloader {
async fn download(&self, url: Url) -> Result<Response, NetError> {
self.client
.get(url.clone())
.send()
.await
.map_err(|error| NetError::Http {
error: Box::new(error),
url: url.to_string(),
})
let mut request = self.client.get(url.clone());

if let Some(headers) = &self.headers {
request = request.headers(headers.clone());
}

request.send().await.map_err(|error| NetError::Http {
error: Box::new(error),
url: url.to_string(),
})
}
}

Expand All @@ -62,9 +106,12 @@ pub async fn download_from_url_with_options<S: AsRef<str> + Debug, D: AsRef<Path
) -> Result<(), NetError> {
let source_url = source_url.as_ref();
let dest_file = dest_file.as_ref();
let downloader = options
.downloader
.unwrap_or_else(|| Box::new(DefaultDownloader::default()));
let DownloadOptions {
downloader,
on_chunk,
} = options;

let downloader = downloader.unwrap_or_else(|| Box::new(DefaultDownloader::default()));

let handle_fs_error = |error: std::io::Error| FsError::Write {
path: dest_file.to_path_buf(),
Expand Down Expand Up @@ -110,7 +157,7 @@ pub async fn download_from_url_with_options<S: AsRef<str> + Debug, D: AsRef<Path
let mut file = fs::create_file(dest_file)?;

// Write the bytes in chunks
match options.on_chunk {
match on_chunk {
Some(on_chunk) => {
let total_size = response.content_length().unwrap_or(0);
let mut current_size: u64 = 0;
Expand Down Expand Up @@ -158,13 +205,39 @@ pub async fn download_from_url_with_client<S: AsRef<str> + Debug, D: AsRef<Path>
DownloadOptions {
downloader: Some(Box::new(DefaultDownloader {
client: client.to_owned(),
headers: None,
})),
on_chunk: None,
},
)
.await
}

/// Download a file from the provided source URL, to the destination file path,
/// using a custom `reqwest` [`Client`] and HTTP headers.
pub async fn download_from_url_with_client_and_headers<
S: AsRef<str> + Debug,
D: AsRef<Path> + Debug,
>(
source_url: S,
dest_file: D,
client: &Client,
headers: HashMap<String, String>,
) -> Result<(), NetError> {
download_from_url_with_options(
source_url,
dest_file,
DownloadOptions {
downloader: Some(Box::new(DefaultDownloader::new_with_client_and_headers(
client.to_owned(),
headers,
)?)),
on_chunk: None,
},
)
.await
}

/// Download a file from the provided source URL, to the destination file path.
pub async fn download_from_url<S: AsRef<str> + Debug, D: AsRef<Path> + Debug>(
source_url: S,
Expand All @@ -173,6 +246,24 @@ pub async fn download_from_url<S: AsRef<str> + Debug, D: AsRef<Path> + Debug>(
download_from_url_with_options(source_url, dest_file, DownloadOptions::default()).await
}

/// Download a file from the provided source URL, to the destination file path,
/// with HTTP headers.
pub async fn download_from_url_with_headers<S: AsRef<str> + Debug, D: AsRef<Path> + Debug>(
source_url: S,
dest_file: D,
headers: HashMap<String, String>,
) -> Result<(), NetError> {
download_from_url_with_options(
source_url,
dest_file,
DownloadOptions {
downloader: Some(Box::new(DefaultDownloader::new_with_headers(headers)?)),
..Default::default()
},
)
.await
}

mod offline {
use super::*;

Expand Down
77 changes: 77 additions & 0 deletions crates/utils/tests/net_test.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
use starbase_sandbox::create_empty_sandbox;
use starbase_utils::net;
use std::collections::HashMap;

fn headers_with_custom_header() -> HashMap<String, String> {
let mut headers = HashMap::new();
headers.insert(
"X-Proto-Test-Header".to_string(),
"proto-starbase-headers-test".to_string(),
);
headers
}

mod download {
use super::*;
Expand Down Expand Up @@ -52,4 +62,71 @@ mod download {
assert!(dest_file.exists());
assert_ne!(dest_file.metadata().unwrap().len(), 0);
}

#[tokio::test]
async fn downloads_with_headers() {
let sandbox = create_empty_sandbox();
let dest_file = sandbox.path().join("headers.json");

net::download_from_url_with_headers(
"https://httpbin.org/headers",
&dest_file,
headers_with_custom_header(),
)
.await
.unwrap();

let body = std::fs::read_to_string(&dest_file).unwrap();
assert!(
body.contains("proto-starbase-headers-test"),
"expected response to contain custom header value, got: {body}"
);
}

#[tokio::test]
async fn downloads_with_client_and_headers() {
let sandbox = create_empty_sandbox();
let dest_file = sandbox.path().join("headers.json");
let client = reqwest::Client::new();

net::download_from_url_with_client_and_headers(
"https://httpbin.org/headers",
&dest_file,
&client,
headers_with_custom_header(),
)
.await
.unwrap();

let body = std::fs::read_to_string(&dest_file).unwrap();
assert!(
body.contains("proto-starbase-headers-test"),
"expected response to contain custom header value, got: {body}"
);
}

#[tokio::test]
async fn downloads_with_options_headers() {
let sandbox = create_empty_sandbox();
let dest_file = sandbox.path().join("headers.json");

net::download_from_url_with_options(
"https://httpbin.org/headers",
&dest_file,
net::DownloadOptions {
downloader: Some(Box::new(
net::DefaultDownloader::new_with_headers(headers_with_custom_header()).unwrap(),
)),
..Default::default()
},
)
.await
.unwrap();

let body = std::fs::read_to_string(&dest_file).unwrap();
assert!(
body.contains("proto-starbase-headers-test"),
"expected response to contain custom header value, got: {body}"
);
}
}
Loading