4 changes: 4 additions & 0 deletions client/src/error.rs
@@ -51,6 +51,10 @@ pub enum S3Error
InvalidContentDisposition,
#[error("Invalid Content-Disposition")]
MissingBody,
#[error("Missing header {0}")]
MissingHeader(&'static str),
#[error("Invalid header {0}")]
InvalidHeader(&'static str),
}

#[derive(Error, Debug)]
76 changes: 76 additions & 0 deletions client/src/s3.rs
@@ -78,6 +78,61 @@ impl TryFrom<Response> for S3GetResponse
}
}

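/// Object metadata parsed from the headers of an S3 HEAD response.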
pub struct S3HeadResponse {
pub location: String,
pub last_modified: String,
pub size: u64,
pub e_tag: Option<String>,
pub version: Option<String>,
}

impl TryFrom<Response> for S3HeadResponse {
type Error = S3Error;

fn try_from(response: Response) -> Result<Self, Self::Error> {
let headers = response.headers();

let last_modified = headers
.get(http::header::LAST_MODIFIED)
.ok_or(S3Error::MissingHeader("Last-Modified"))?
.to_str()
.map_err(|_| S3Error::InvalidHeader("Last-Modified"))?
.to_string();

let size = headers
.get(http::header::CONTENT_LENGTH)
.ok_or(S3Error::MissingHeader("Content-Length"))?
.to_str()
.map_err(|_| S3Error::InvalidHeader("Content-Length"))?
.parse::<u64>()
.map_err(|_| S3Error::InvalidHeader("Content-Length"))?;

let e_tag = headers
.get(http::header::ETAG)
.and_then(|v| v.to_str().ok())
.map(|v| v.trim_matches('"').to_string());

let version = headers
.get("x-amz-version-id")
.and_then(|v| v.to_str().ok())
.map(|v| v.to_string());

let location = response
.extensions()
.get::<http::Uri>()
.map(|uri| uri.path().to_string())
.unwrap_or_default();

Ok(Self {
location,
last_modified,
size,
e_tag,
version,
})
}
}

impl Client {
pub fn s3_layer<L, E>(&mut self, layer: L) -> &mut Self
where
@@ -131,6 +186,27 @@ impl Client
check_status(&res.status())?;
Ok(res.try_into()?)
}

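/// Fetch an object's metadata (size, Last-Modified, ETag, version) without downloading the body.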
pub async fn s3_head(&self, url: Url) -> Result<S3HeadResponse> {
self.validate_host(&url)?;

// A HEAD request carries no payload; send an empty body.
let body = Body::from(String::new());

let mut request = http::Request::builder()
.method(http::Method::HEAD)
.uri(url.to_string());

if let Some(content_length) = body.content_length() {
if let Some(headers) = request.headers_mut() {
headers.insert(reqwest::header::CONTENT_LENGTH, content_length.into());
}
}

let request = request.body(body)?;
let res = self.s3_service().oneshot(request).await?;
check_status(&res.status())?;
Ok(res.try_into()?)
}
}

pub struct S3Range {
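Reviewer note: a minimal usage sketch of the new `s3_head`, assuming a configured `aqora_client::Client` and a presigned object `Url` (the boxed error type and the standalone helper are assumptions for illustration, not part of this diff):

```rust
use aqora_client::Client;
use url::Url;

// Sketch only: read object metadata without downloading the body.
async fn print_object_info(client: &Client, url: Url) -> Result<(), Box<dyn std::error::Error>> {
    let meta = client.s3_head(url).await?;
    println!(
        "{} bytes, modified {}, etag {:?}, version {:?}",
        meta.size, meta.last_modified, meta.e_tag, meta.version
    );
    Ok(())
}
```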
3 changes: 3 additions & 0 deletions src/commands/dataset/common.rs
@@ -32,6 +32,7 @@ pub struct GetDatasetBySlug;
pub struct GetDatasetSlugResponse {
pub id: String,
pub viewer_can_create_version: bool,
pub viewer_can_read_dataset_version_file: bool,
}

pub async fn get_dataset_by_slug(
@@ -60,11 +61,13 @@ pub async fn get_dataset_by_slug(
return Ok(GetDatasetSlugResponse {
id: new_dataset.id,
viewer_can_create_version: new_dataset.viewer_can_create_version,
viewer_can_read_dataset_version_file: new_dataset.viewer_can_read_dataset_version_file,
});
};

Ok(GetDatasetSlugResponse {
id: dataset.id,
viewer_can_create_version: dataset.viewer_can_create_version,
viewer_can_read_dataset_version_file: dataset.viewer_can_read_dataset_version_file,
})
}
221 changes: 221 additions & 0 deletions src/commands/dataset/download.rs
@@ -0,0 +1,221 @@
use clap::Args;
use futures::{stream, StreamExt, TryStreamExt};
use graphql_client::GraphQLQuery;
use indicatif::{HumanBytes, MultiProgress, ProgressBar};
use serde::Serialize;
use std::path::PathBuf;
use url::Url;

use crate::{
commands::{
dataset::{
common::{get_dataset_by_slug, DatasetCommonArgs},
download::get_dataset_version_files::GetDatasetVersionFilesNodeOnDatasetVersionFilesNodes,
version::common::get_dataset_version,
},
GlobalArgs,
},
download::{multipart_download, MultipartOptions},
error::{self, Result},
};

#[derive(Args, Debug, Serialize)]
pub struct Download {
#[command(flatten)]
common: DatasetCommonArgs,
#[arg(short, long)]
version: semver::Version,
#[arg(short, long)]
destination: PathBuf,
#[command(flatten)]
options: MultipartOptions,
#[clap(long, default_value_t = 10)]
concurrency: usize,
}

#[derive(GraphQLQuery)]
#[graphql(
query_path = "src/graphql/get_dataset_version_files.graphql",
schema_path = "schema.graphql",
response_derives = "Debug"
)]
pub struct GetDatasetVersionFiles;

#[derive(GraphQLQuery)]
#[graphql(
query_path = "src/graphql/get_dataset_version_file_by_partition.graphql",
schema_path = "schema.graphql",
response_derives = "Debug"
)]
pub struct GetDatasetVersionFileByPartition;

pub async fn download(args: Download, global: GlobalArgs) -> Result<()> {
let m = MultiProgress::new();

let client = global.graphql_client().await?;

let (owner, local_slug) = args.common.slug_pair()?;
let multipart_options = args.options;

let dataset = get_dataset_by_slug(&global, owner, local_slug).await?;
if !dataset.viewer_can_read_dataset_version_file {
return Err(error::user(
"Permission denied",
"Cannot read dataset files",
));
}

let dataset_version = get_dataset_version(
&client,
dataset.id,
args.version.major as _,
args.version.minor as _,
args.version.patch as _,
)
.await?
.ok_or_else(|| error::user("Not found", "Dataset version not found"))?;

let response = client
.send::<GetDatasetVersionFiles>(get_dataset_version_files::Variables {
dataset_version_id: dataset_version.id,
})
.await?;

let dataset_version_files = match response.node {
get_dataset_version_files::GetDatasetVersionFilesNode::DatasetVersion(v) => v,
_ => {
return Err(error::system(
"Invalid node type",
"Unexpected GraphQL response",
))
}
};

let nodes = dataset_version_files.files.nodes;
let dataset_name = dataset_version_files.dataset.name;

let dataset_dir = args.destination.join(&dataset_name);
tokio::fs::create_dir_all(&dataset_dir).await?;

let total_size = dataset_version.size as u64;
let total_files = nodes.len();

let overall_progress = m.add(global.spinner().with_message(format!(
"Downloading '{}' ({} files, {})",
dataset_name,
total_files,
HumanBytes(total_size)
)));

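// Download partitions concurrently, bounded by --concurrency (default 10).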
stream::iter(nodes)
.map(|node| {
let client = client.to_owned();
let m = m.to_owned();
let multipart_options = multipart_options.to_owned();
let dataset_dir = dataset_dir.to_owned();
let dataset_name = dataset_name.to_owned();

async move {
download_partition_file(
&m,
&client,
&multipart_options,
&dataset_dir,
&dataset_name,
node,
)
.await
}
})
.buffer_unordered(args.concurrency)
.try_collect::<()>()
.await?;

overall_progress.finish_with_message("Done");

Ok(())
}

async fn download_partition_file(
m: &MultiProgress,
client: &aqora_client::Client,
multipart_options: &MultipartOptions,
output_dir: &std::path::Path,
dataset_name: &str,
file_node: GetDatasetVersionFilesNodeOnDatasetVersionFilesNodes,
) -> Result<()> {
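// Attach a CRC32 checksum layer so downloaded bytes are verified.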
let mut client = client.clone();
client.s3_layer(aqora_client::checksum::S3ChecksumLayer::new(
aqora_client::checksum::crc32fast::Crc32::new(),
));

let (metadata, url) = match client.s3_head(file_node.url.clone()).await {
Ok(metadata) => (metadata, file_node.url.clone()),
// Retry with a fresh URL: the presigned URL may have expired during a long download
Err(e) => {
tracing::warn!(error = %e, "Retrying: failed to fetch object metadata");
let response = client
.send::<GetDatasetVersionFileByPartition>(
get_dataset_version_file_by_partition::Variables {
dataset_version_id: file_node.dataset_version.id,
partition_num: file_node.partition_num,
},
)
.await?;

let dataset_version_file = match response.node {
get_dataset_version_file_by_partition::GetDatasetVersionFileByPartitionNode::DatasetVersion(v) => v,
_ => {
return Err(error::system(
"Invalid node type",
"Unexpected GraphQL response",
));
}
};
let file_by_partition_num = match dataset_version_file.file_by_partition_num {
Some(file_by_partition_num) => file_by_partition_num,
None => {
return Err(error::system(
"Invalid partition number",
"The partition does not exist",
))
}
};

let file_url = file_by_partition_num.url;
(client.s3_head(file_url.clone()).await?, file_url)
}
};

let filename = format!("{}-{}.parquet", dataset_name, file_node.partition_num);
let output_path = output_dir.join(&filename);

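// Skip partitions that already exist locally with the expected size.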
if let Ok(existing) = tokio::fs::metadata(&output_path).await {
if existing.len() == metadata.size {
return Ok(());
}
}

tokio::fs::create_dir_all(output_dir).await?;

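// Download into a temp file beside the target, then rename into place so a
// partially-written file never appears at the final path.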
let temp = tempfile::NamedTempFile::new_in(output_dir)?;
let temp_path = temp.path().to_owned();

let pb = m.add(ProgressBar::new_spinner());
pb.set_message(filename);

multipart_download(
&client,
metadata.size,
url,
multipart_options,
&temp_path,
&pb,
)
.await?;

pb.finish_and_clear();
tokio::fs::rename(&temp_path, &output_path).await?;

Ok(())
}
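Reviewer note: the expired-URL fallback above is a try-once-then-refresh pattern. A distilled sketch of that shape, with the `refresh` closure standing in for the GraphQL re-query (hypothetical helper; the `aqora_client::s3::S3HeadResponse` path is an assumption about how the type is exported):

```rust
use url::Url;

// Sketch: HEAD the presigned URL once; if it fails (e.g. the URL expired
// during a long run), fetch a fresh URL and try exactly once more.
async fn head_with_refresh<F, Fut>(
    client: &aqora_client::Client,
    url: Url,
    refresh: F,
) -> crate::error::Result<(aqora_client::s3::S3HeadResponse, Url)>
where
    F: FnOnce() -> Fut,
    Fut: std::future::Future<Output = crate::error::Result<Url>>,
{
    match client.s3_head(url.clone()).await {
        Ok(meta) => Ok((meta, url)),
        Err(e) => {
            tracing::warn!(error = %e, "HEAD failed; refreshing presigned URL");
            let fresh = refresh().await?;
            Ok((client.s3_head(fresh.clone()).await?, fresh))
        }
    }
}
```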
4 changes: 4 additions & 0 deletions src/commands/dataset/mod.rs
@@ -1,5 +1,6 @@
mod common;
mod convert;
mod download;
mod infer;
mod new;
mod upload;
@@ -13,6 +14,7 @@ use crate::commands::GlobalArgs;
use crate::error::Result;

use convert::{convert, Convert};
use download::{download, Download};
use infer::{infer, Infer};
use new::{new, New};
use upload::{upload, Upload};
@@ -26,6 +28,7 @@ pub enum Dataset {
Convert(Convert),
New(New),
Upload(Upload),
Download(Download),
Version {
#[command(subcommand)]
args: Version,
@@ -38,6 +41,7 @@ pub async fn dataset(args: Dataset, global: GlobalArgs) -> Result<()> {
Dataset::Convert(args) => convert(args, global).await,
Dataset::New(args) => new(args, global).await,
Dataset::Upload(args) => upload(args, global).await,
Dataset::Download(args) => download(args, global).await,
Dataset::Version { args } => version(args, global).await,
}
}
2 changes: 1 addition & 1 deletion src/commands/global_args.rs
@@ -16,7 +16,7 @@ use std::path::PathBuf;
use url::Url;

lazy_static::lazy_static! {
static ref DEFAULT_PARALLELISM: usize = std::thread::available_parallelism()
pub static ref DEFAULT_PARALLELISM: usize = std::thread::available_parallelism()
.map(usize::from)
.unwrap_or(1);
}