Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions xet_pkg/examples/example.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::time::Duration;

use anyhow::Result;
use clap::{Parser, Subcommand};
use xet::xet_session::{FileMetadata, TaskHandle, TaskStatus, XetFileInfo, XetSessionBuilder};
use xet::xet_session::{FileMetadata, TaskStatus, XetFileInfo, XetSessionBuilder};

#[derive(Parser)]
#[clap(name = "session-demo", about = "XetSession API demo")]
Expand Down Expand Up @@ -58,7 +58,7 @@ fn upload_files(files: Vec<PathBuf>, endpoint: Option<String>) -> Result<()> {

// Enqueue all uploads; each starts immediately in the background.
let n_files = files.len();
let handles: Vec<TaskHandle> = files
let handles: Vec<_> = files
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the sake of the examples, can we avoid omitting types? When they are omitted I have to go to the code to figure out all the return types I need to use, which makes the example less useful (still useful, but more useful with the types written out explicitly).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes for sure! Inserted types

.iter()
.map(|f| commit.upload_from_path(f.clone()))
.collect::<Result<_, _>>()?;
Expand All @@ -80,13 +80,17 @@ fn upload_files(files: Vec<PathBuf>, endpoint: Option<String>) -> Result<()> {
});

// Block until all uploads finish and metadata is finalized.
let metadata: Vec<_> = commit.commit()?.into_iter().filter_map(|m| m.ok()).collect();
let results = commit.commit()?;

for m in &metadata {
for m in results.values().filter_map(|m| m.as_ref().as_ref().ok()) {
println!(" {} -> {} ({} bytes)", m.tracking_name.as_deref().unwrap_or("?"), m.hash, m.file_size);
}

// Persist metadata so it can be passed to the `download` subcommand.
let metadata: Vec<_> = results
.into_values()
.filter_map(|m| m.as_ref().as_ref().ok().cloned())
.collect();
std::fs::write("upload_metadata.json", serde_json::to_string_pretty(&metadata)?)?;

Ok(())
Expand All @@ -105,7 +109,7 @@ fn download_files(metadata_file: PathBuf, output_dir: PathBuf, endpoint: Option<

// Enqueue all downloads; each starts immediately in the background.
let n_files = metadata.len();
let handles: Vec<TaskHandle> = metadata
let handles: Vec<_> = metadata
.iter()
.map(|m| {
let dest = output_dir.join(m.tracking_name.as_deref().unwrap_or("file"));
Expand Down Expand Up @@ -136,10 +140,12 @@ fn download_files(metadata_file: PathBuf, output_dir: PathBuf, endpoint: Option<
});

// Block until all downloads finish.
let results: Vec<_> = group.finish()?.into_iter().filter_map(|m| m.ok()).collect();
let results = group.finish()?;

for r in &results {
println!(" {} ({} bytes)", r.dest_path.display(), r.file_info.file_size);
for (_task_id, result) in &results {
if let Ok(r) = result.as_ref() {
println!(" {} ({} bytes)", r.dest_path.display(), r.file_info.file_size);
}
}

Ok(())
Expand Down
172 changes: 117 additions & 55 deletions xet_pkg/src/xet_session/download_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::{Arc, Mutex, MutexGuard, RwLock};
use std::sync::{Arc, Mutex, MutexGuard, OnceLock, RwLock};

use tokio::task::JoinHandle;
use ulid::Ulid;
Expand All @@ -11,7 +11,7 @@ use xet_runtime::core::XetRuntime;

use super::common::{GroupState, create_translator_config};
use super::errors::SessionError;
use super::progress::{GroupProgress, ProgressSnapshot, TaskHandle, TaskStatus};
use super::progress::{DownloadTaskHandle, GroupProgress, ProgressSnapshot, TaskHandle, TaskStatus};
use super::session::XetSession;

/// Groups related file downloads into a single unit of work.
Expand Down Expand Up @@ -89,7 +89,7 @@ impl DownloadGroup {
/// * `dest_path` – Local path where the downloaded file will be written. Parent directories are created
/// automatically.
///
/// Returns a [`TaskHandle`] that can be used to poll status and per-file
/// Returns a [`DownloadTaskHandle`] that can be used to poll status and per-file
/// progress without taking the GIL.
///
/// # Errors
Expand All @@ -101,7 +101,7 @@ impl DownloadGroup {
&self,
file_info: XetFileInfo,
dest_path: PathBuf,
) -> Result<TaskHandle, SessionError> {
) -> Result<DownloadTaskHandle, SessionError> {
self.session.check_alive()?;
self.inner.start_download_file_to_path(file_info, dest_path)
}
Expand All @@ -122,25 +122,41 @@ impl DownloadGroup {

/// Wait for all downloads to complete and return their results.
///
/// Blocks until every queued download finishes (or fails). Returns one
/// [`DownloadResult`] entry per download.
/// Blocks until every queued download finishes (or fails). Returns a
/// `HashMap` keyed by task ID (the [`Ulid`] returned by
/// [`download_file_to_path`](Self::download_file_to_path)), where each
/// value is [`DownloadResult`] (= `Arc<Result<`[`DownloadedFile`]`,
/// `[`SessionError`](crate::SessionError)`>>`). A single failed download
/// does not prevent the others from being collected.
///
/// Per-task results can also be read directly from the
/// [`DownloadTaskHandle`] returned by `download_file_to_path` via
/// [`result`](DownloadTaskHandle::result) after this method returns.
///
/// Consumes `self` — subsequent calls on any clone will return
/// [`SessionError::AlreadyFinished`] (or a channel-closed error if the
/// background worker has already exited).
pub fn finish(self) -> Result<Vec<Result<DownloadResult, SessionError>>, SessionError> {
pub fn finish(self) -> Result<HashMap<Ulid, DownloadResult>, SessionError> {
let inner = self.inner.clone();
self.session
.runtime
.external_run_async_task(async move { inner.handle_finish().await })?
}
}

/// Per-file result type returned by [`DownloadGroup::finish`].
///
/// The `Arc` lets the same value be stored in both the `finish()` return map
/// and the per-task [`DownloadTaskHandle`] without requiring the inner
/// `Result` to be `Clone`.
pub type DownloadResult = Arc<Result<DownloadedFile, SessionError>>;

/// Handle for a single download task tracked internally by DownloadGroup.
///
/// One instance is stored per registered download, keyed by the task's ID, and is
/// consumed when the group finishes and collects the per-task results.
pub(crate) struct InnerDownloadTaskHandle {
/// Task status cell; a clone of this `Arc` is also handed to the caller-facing task handle.
status: Arc<Mutex<TaskStatus>>,
/// Destination path; moved into the `DownloadedFile` reported for a successful download.
dest_path: PathBuf,
/// Join handle of the background download task, awaited when the group finishes.
join_handle: JoinHandle<Result<XetFileInfo, SessionError>>,
/// Write-once result slot shared with the caller-facing `DownloadTaskHandle`; it is
/// filled in when the group finishes.
result: Arc<OnceLock<DownloadResult>>,
}

/// All shared state owned by a single DownloadGroup instance.
Expand Down Expand Up @@ -214,7 +230,7 @@ impl DownloadGroupInner {
self: &Arc<Self>,
file_info: XetFileInfo,
dest_path: PathBuf,
) -> Result<TaskHandle, SessionError> {
) -> Result<DownloadTaskHandle, SessionError> {
// Hold the state lock guard for the duration of this function so that finish() cannot run
// while a download task is being registered.
let state = self.state.lock()?;
Expand All @@ -223,10 +239,14 @@ impl DownloadGroupInner {
let tracking_id = Ulid::new();
let status = Arc::new(Mutex::new(TaskStatus::Queued));

let task_handle = TaskHandle {
status: Some(status.clone()),
group_progress: self.progress.clone(),
tracking_id,
let result: Arc<OnceLock<DownloadResult>> = Arc::new(OnceLock::new());
let task_handle = DownloadTaskHandle {
inner: TaskHandle {
status: Some(status.clone()),
group_progress: self.progress.clone(),
task_id: tracking_id,
},
result: result.clone(),
};

let Some(download_session) = self.download_session.lock()?.clone() else {
Expand All @@ -245,6 +265,7 @@ impl DownloadGroupInner {
status,
dest_path,
join_handle,
result,
};

self.active_tasks.write()?.insert(tracking_id, handle);
Expand All @@ -253,7 +274,7 @@ impl DownloadGroupInner {
}

/// Handle a `Finish` command from the public API.
async fn handle_finish(self: &Arc<Self>) -> Result<Vec<Result<DownloadResult, SessionError>>, SessionError> {
async fn handle_finish(self: &Arc<Self>) -> Result<HashMap<Ulid, DownloadResult>, SessionError> {
// Mark as not accepting new tasks
{
let mut state_guard = self.state.lock()?;
Expand All @@ -266,19 +287,27 @@ impl DownloadGroupInner {
// Wait for all downloads to complete
let active_tasks = std::mem::take(&mut *self.active_tasks.write()?);

let mut results = Vec::new();
let mut results = HashMap::new();
let mut join_err = None;
// Join all tasks first and then propagate errors.
for (_task_id, handle) in active_tasks {
for (task_id, handle) in active_tasks {
match handle.join_handle.await.map_err(SessionError::from) {
Ok(Ok(file_info)) => {
results.push(Ok(DownloadResult {
let result = Arc::new(Ok(DownloadedFile {
dest_path: handle.dest_path,
file_info,
}));
results.insert(task_id, result.clone());
// Publish the result to the external task handle. This is the only place that sets
// it, so `set` cannot fail and its return value is deliberately ignored.
let _ = handle.result.set(result);
},
Ok(Err(task_err)) => {
results.push(Err(task_err));
let result: Arc<Result<DownloadedFile, SessionError>> = Arc::new(Err(task_err));
results.insert(task_id, result.clone());
// Publish the result to the external task handle. This is the only place that sets
// it, so `set` cannot fail and its return value is deliberately ignored.
let _ = handle.result.set(result);
},
Err(e) => {
if join_err.is_none() {
Expand Down Expand Up @@ -316,30 +345,9 @@ impl DownloadGroupInner {
}
}

/// A progress snapshot for a single queued download.
///
/// Returned by [`DownloadGroup::get_progress`].
#[derive(Clone, Debug)]
pub struct DownloadProgress {
/// Unique identifier for this download task.
pub task_id: Ulid,
/// Local path where the file will be written.
pub dest_path: PathBuf,
/// Content-addressed hash of the file being downloaded.
pub file_hash: String,
/// Number of bytes downloaded so far.
pub bytes_completed: u64,
/// Total file size in bytes (0 if not yet known).
pub bytes_total: u64,
/// Current lifecycle state of the task.
pub status: TaskStatus,
/// Instantaneous download throughput in bytes per second.
pub speed_bps: f64,
}

/// Per-file result returned by [`DownloadGroup::finish`].
#[derive(Clone, Debug)]
pub struct DownloadResult {
pub struct DownloadedFile {
/// Local path where the file was written.
pub dest_path: PathBuf,
/// Xet file hash and size of the downloaded file.
Expand All @@ -353,6 +361,7 @@ mod tests {
use tempfile::{TempDir, tempdir};

use super::*;
use crate::xet_session::progress::UploadTaskHandle;
use crate::xet_session::session::XetSession;

fn local_session(temp: &TempDir) -> Result<XetSession, Box<dyn std::error::Error>> {
Expand All @@ -362,12 +371,12 @@ mod tests {

fn upload_bytes(session: &XetSession, data: &[u8], name: &str) -> Result<XetFileInfo, Box<dyn std::error::Error>> {
let commit = session.new_upload_commit()?;
commit.upload_bytes(data.to_vec(), Some(name.into()))?;
let handle = commit.upload_bytes(data.to_vec(), Some(name.into()))?;
let results = commit.commit()?;
let m = &results[0];
let meta = results.get(&handle.task_id).unwrap().as_ref().as_ref().unwrap();
Ok(XetFileInfo {
hash: m.as_ref().unwrap().hash.clone(),
file_size: m.as_ref().unwrap().file_size,
hash: meta.hash.clone(),
file_size: meta.file_size,
})
}

Expand Down Expand Up @@ -544,28 +553,25 @@ mod tests {
let data_a = b"First file content";
let data_b = b"Second file content - different";

// Upload both files in one commit; use tracking_name to locate each result.
// Upload both files; capture handles so results can be retrieved by task_id.
let commit = session.new_upload_commit()?;
commit.upload_bytes(data_a.to_vec(), Some("a.bin".into()))?;
commit.upload_bytes(data_b.to_vec(), Some("b.bin".into()))?;
let handle_a = commit.upload_bytes(data_a.to_vec(), Some("a.bin".into()))?;
let handle_b = commit.upload_bytes(data_b.to_vec(), Some("b.bin".into()))?;
let results = commit.commit()?;

let find_info = |name: &str| -> XetFileInfo {
let m = results
.iter()
.find(|r| r.as_ref().unwrap().tracking_name.as_deref() == Some(name))
.unwrap();
let to_file_info = |handle: &UploadTaskHandle| -> XetFileInfo {
let meta = results.get(&handle.task_id).unwrap().as_ref().as_ref().unwrap();
XetFileInfo {
hash: m.as_ref().unwrap().hash.clone(),
file_size: m.as_ref().unwrap().file_size,
hash: meta.hash.clone(),
file_size: meta.file_size,
}
};

let dest_a = temp.path().join("a_out.bin");
let dest_b = temp.path().join("b_out.bin");
let group = session.new_download_group()?;
group.download_file_to_path(find_info("a.bin"), dest_a.clone())?;
group.download_file_to_path(find_info("b.bin"), dest_b.clone())?;
group.download_file_to_path(to_file_info(&handle_a), dest_a.clone())?;
group.download_file_to_path(to_file_info(&handle_b), dest_b.clone())?;
group.finish()?;

assert_eq!(std::fs::read(&dest_a)?, data_a);
Expand Down Expand Up @@ -600,6 +606,62 @@ mod tests {
Ok(())
}

// ── Per-task result access patterns ──────────────────────────────────────
//
// After finish() completes there are two equivalent ways to retrieve a
// per-task DownloadResult:
//
// 1. HashMap lookup: `finish_results.get(&handle.task_id)`
// 2. Direct handle: `handle.result()` (on DownloadTaskHandle)

#[test]
// Pattern 1: the map returned by finish() exposes each task's result under its task_id.
fn test_download_result_accessible_via_task_id_in_finish_map() -> Result<(), Box<dyn std::error::Error>> {
    let workdir = tempdir()?;
    let session = local_session(&workdir)?;

    // Upload a small payload so there is something to download back.
    let payload = b"result via task_id in finish map";
    let uploaded = upload_bytes(&session, payload, "file.bin")?;

    let target = workdir.path().join("out.bin");
    let group = session.new_download_group()?;
    let task = group.download_file_to_path(uploaded, target)?;

    // Look the task up in the finish() map by the ID carried on its handle.
    let finished = group.finish()?;
    let entry = finished.get(&task.task_id).expect("task_id must be present in results");
    let downloaded = entry.as_ref().as_ref().unwrap();
    assert_eq!(downloaded.file_info.file_size, payload.len() as u64);
    Ok(())
}

#[test]
// Until finish() has been called, the per-task handle has no result to report.
fn test_download_result_none_before_finish() -> Result<(), Box<dyn std::error::Error>> {
    let workdir = tempdir()?;
    let session = local_session(&workdir)?;
    let uploaded = upload_bytes(&session, b"some data", "file.bin")?;

    let target = workdir.path().join("out.bin");
    let group = session.new_download_group()?;
    let task = group.download_file_to_path(uploaded, target)?;

    // The result slot is only populated by finish(), so it must still be empty here.
    assert!(task.result().is_none(), "result must be None before finish()");

    group.finish()?;
    Ok(())
}

#[test]
// Once finish() has returned, the per-task handle exposes the completed download.
fn test_download_result_some_after_finish() -> Result<(), Box<dyn std::error::Error>> {
    let workdir = tempdir()?;
    let session = local_session(&workdir)?;

    let payload = b"download result test data";
    let uploaded = upload_bytes(&session, payload, "file.bin")?;

    let target = workdir.path().join("out.bin");
    let group = session.new_download_group()?;
    let task = group.download_file_to_path(uploaded.clone(), target)?;
    group.finish()?;

    // Pattern 2: read the result straight from the handle instead of the finish() map.
    let outcome = task.result().expect("result must be set after finish()");
    let downloaded = outcome.as_ref().as_ref().unwrap();
    assert_eq!(downloaded.file_info.file_size, payload.len() as u64);
    assert_eq!(downloaded.file_info.hash, uploaded.hash);
    Ok(())
}

// ── Mutex guard / concurrency test ───────────────────────────────────────
//
// `download_file_to_path` holds `self.state` for its entire execution so
Expand Down
Loading
Loading