diff --git a/Cargo.lock b/Cargo.lock index 65036d4..3b00d47 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -355,6 +355,19 @@ dependencies = [ "unicode-width", ] +[[package]] +name = "console" +version = "0.15.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8" +dependencies = [ + "encode_unicode", + "libc", + "once_cell", + "unicode-width", + "windows-sys 0.59.0", +] + [[package]] name = "const-str" version = "1.1.0" @@ -618,6 +631,12 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +[[package]] +name = "encode_unicode" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" + [[package]] name = "encoding_rs" version = "0.8.35" @@ -997,6 +1016,7 @@ dependencies = [ "futures", "globset", "hf-xet", + "indicatif", "libc", "owo-colors", "pathdiff", @@ -1232,6 +1252,19 @@ dependencies = [ "serde_core", ] +[[package]] +name = "indicatif" +version = "0.17.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "183b3088984b400f4cfac3620d5e076c84da5364016b4f49473de574b2586235" +dependencies = [ + "console", + "number_prefix", + "portable-atomic", + "unicode-width", + "web-time", +] + [[package]] name = "ipnet" version = "2.12.0" @@ -1484,6 +1517,12 @@ dependencies = [ "autocfg", ] +[[package]] +name = "number_prefix" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" + [[package]] name = "objc2-core-foundation" version = "0.3.2" @@ -1624,6 +1663,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] 
+name = "portable-atomic" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" + [[package]] name = "potential_utf" version = "0.1.4" diff --git a/docs/superpowers/plans/2026-04-10-per-file-upload-progress.md b/docs/superpowers/plans/2026-04-10-per-file-upload-progress.md new file mode 100644 index 0000000..d07599d --- /dev/null +++ b/docs/superpowers/plans/2026-04-10-per-file-upload-progress.md @@ -0,0 +1,615 @@ +# Per-File Upload Progress Bars Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Show per-file progress bars during xet/LFS uploads instead of a single aggregate bar, mirroring the existing per-file download progress pattern. + +**Architecture:** Add a `Vec<FileProgress>` field to `UploadEvent::Progress` for per-file data. In `xet.rs`, poll each `XetFileUpload` handle's `.progress()` in the 100ms loop, mapping `ItemProgressReport.item_name` back to `path_in_repo` via a `HashMap<String, String>`. The CLI renderer creates/removes per-file indicatif bars with a cap of 10 visible, plus a summary line. Regular (non-LFS) files do not get progress bars. 
+ +**Tech Stack:** Rust, indicatif, hf-xet 1.5.1 (`ItemProgressReport`, `XetFileUpload`) + +--- + +## File Structure + +| File | Action | Responsibility | +|------|--------|---------------| +| `huggingface_hub/src/types/progress.rs` | Modify | Add `files: Vec<FileProgress>` to `UploadEvent::Progress` | +| `huggingface_hub/src/xet.rs` | Modify | Build item_name-to-repo-path map, poll per-file handles, emit per-file events | +| `huggingface_hub/src/api/files.rs` | Modify | Update all `UploadEvent::Progress` emit sites to include empty `files: vec![]` | +| `huggingface_hub/src/bin/hfrs/progress.rs` | Modify | Render per-file upload bars (max 10 visible, remove on complete, summary line) | + +--- + +### Task 1: Extend `UploadEvent::Progress` with per-file data + +**Files:** +- Modify: `huggingface_hub/src/types/progress.rs:27-42` + +- [ ] **Step 1: Write the failing test** + +Add a test in the `#[cfg(test)]` module that constructs an `UploadEvent::Progress` with a `files` field: + +```rust +#[test] +fn upload_progress_with_per_file_data() { + let handler = Arc::new(RecordingHandler::new()); + let progress: Progress = Some(handler.clone()); + + emit( + &progress, + ProgressEvent::Upload(UploadEvent::Progress { + phase: UploadPhase::Uploading, + bytes_completed: 500, + total_bytes: 1000, + bytes_per_sec: Some(100.0), + files: vec![ + FileProgress { + filename: "model/weights.bin".to_string(), + bytes_completed: 300, + total_bytes: 600, + status: FileStatus::InProgress, + }, + FileProgress { + filename: "config.json".to_string(), + bytes_completed: 200, + total_bytes: 400, + status: FileStatus::InProgress, + }, + ], + }), + ); + + let events = handler.events(); + assert_eq!(events.len(), 1); + if let ProgressEvent::Upload(UploadEvent::Progress { files, .. 
}) = &events[0] { + assert_eq!(files.len(), 2); + assert_eq!(files[0].filename, "model/weights.bin"); + assert_eq!(files[1].filename, "config.json"); + } else { + panic!("expected Upload(Progress)"); + } +} +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: `cargo test -p huggingface-hub upload_progress_with_per_file_data` +Expected: FAIL — `UploadEvent::Progress` does not have a `files` field. + +- [ ] **Step 3: Add `files` field to `UploadEvent::Progress`** + +In `huggingface_hub/src/types/progress.rs`, change the `Progress` variant from: + +```rust +/// Aggregate byte-level progress during xet/LFS upload. +Progress { + phase: UploadPhase, + bytes_completed: u64, + total_bytes: u64, + bytes_per_sec: Option, +}, +``` + +to: + +```rust +/// Byte-level progress during xet/LFS upload. +/// `files` contains per-file progress for xet uploads (may be empty +/// for phases without per-file granularity). +Progress { + phase: UploadPhase, + bytes_completed: u64, + total_bytes: u64, + bytes_per_sec: Option, + files: Vec, +}, +``` + +- [ ] **Step 4: Fix all existing emit sites that construct `UploadEvent::Progress`** + +Every existing construction of `UploadEvent::Progress` will now fail to compile because it's missing the `files` field. Add `files: vec![]` to each site: + +In `huggingface_hub/src/api/files.rs` — there are 3 emit sites (Preparing at ~line 1182, CheckingUploadMode at ~line 1195, Committing at ~line 1260). Add `files: vec![]` to each. + +In `huggingface_hub/src/xet.rs` — there are 2 emit sites (the poll loop at ~line 425, the final emit at ~line 444). Add `files: vec![]` to each for now (Task 2 will populate them). + +- [ ] **Step 5: Fix existing tests that pattern-match on `UploadEvent::Progress`** + +In `huggingface_hub/src/types/progress.rs`, the `upload_phase_progression` test constructs `UploadEvent::Progress` without `files`. 
Add `files: vec![]` to line ~160: + +```rust +ProgressEvent::Upload(UploadEvent::Progress { + phase: phase.clone(), + bytes_completed: 0, + total_bytes: 100, + bytes_per_sec: None, + files: vec![], +}), +``` + +Also fix the `emit_records_events` test around line ~158: + +```rust +ProgressEvent::Upload(UploadEvent::Progress { + phase: UploadPhase::Uploading, + bytes_completed: 512, + total_bytes: 1024, + bytes_per_sec: Some(100.0), + files: vec![], +}), +``` + +- [ ] **Step 6: Run all tests to verify everything compiles and passes** + +Run: `cargo test -p huggingface-hub` +Expected: PASS (all existing tests + new test). + +- [ ] **Step 7: Run fmt and clippy** + +```bash +cargo +nightly fmt +cargo clippy -p huggingface-hub --all-features -- -D warnings +``` + +- [ ] **Step 8: Commit** + +```bash +git add huggingface_hub/src/types/progress.rs huggingface_hub/src/api/files.rs huggingface_hub/src/xet.rs +git commit -m "feat: add per-file progress data to UploadEvent::Progress" +``` + +--- + +### Task 2: Emit per-file progress from xet upload polling loop + +**Files:** +- Modify: `huggingface_hub/src/xet.rs:400-450` + +- [ ] **Step 1: Write the failing test** + +This is an integration-level change in async code that requires a real xet session, so we can't easily unit test it in isolation. Instead, we'll verify correctness by: +1. Ensuring the code compiles with `--all-features` +2. The existing integration tests still pass +3. Manual verification with a real upload (documented at end of plan) + +Skip writing a new test for this task — the type system enforces correctness of the `FileProgress` construction, and Task 1's test already validates the event shape. + +- [ ] **Step 2: Build the `item_name` to `path_in_repo` mapping** + +In `xet_upload()` in `huggingface_hub/src/xet.rs`, replace the `task_ids_in_order` Vec with a `HashMap` that maps xet-core's `item_name` (the value `ItemProgressReport.item_name` will contain) to `path_in_repo`. 
+ +For `AddSource::File(path)`: xet-core sets `item_name` to `std::path::absolute(path).to_str()`. Mimic this. +For `AddSource::Bytes(_)`: xet-core uses the `tracking_name` we pass. Currently we pass `None`. Pass `Some(path_in_repo.clone())` instead so `item_name` equals the repo path directly. + +Replace the loop at lines ~400-415: + +```rust +// Map from xet-core's item_name to path_in_repo. +// For File uploads, xet-core sets item_name to std::path::absolute(path). +// We mimic that logic here to build the reverse mapping. +// For Bytes uploads, we pass path_in_repo as tracking_name so item_name == path_in_repo. +let mut item_name_to_repo_path: HashMap = HashMap::with_capacity(files.len()); +let mut task_ids_in_order = Vec::with_capacity(files.len()); + +for (path_in_repo, source) in files { + tracing::info!(path = path_in_repo.as_str(), "queuing xet upload"); + let handle = match source { + AddSource::File(path) => { + // Mimic xet-core's item_name derivation: std::path::absolute(path).to_str() + // See xet-data upload_commit.rs XetUploadCommitInner::upload_from_path + if let Ok(abs) = std::path::absolute(path) { + if let Some(s) = abs.to_str() { + item_name_to_repo_path.insert(s.to_owned(), path_in_repo.clone()); + } + } + commit + .upload_from_path(path.clone(), Sha256Policy::Compute) + .await + .map_err(|e| HFError::Other(format!("Xet upload failed: {e}")))? + }, + AddSource::Bytes(bytes) => { + item_name_to_repo_path.insert(path_in_repo.clone(), path_in_repo.clone()); + commit + .upload_bytes(bytes.clone(), Sha256Policy::Compute, Some(path_in_repo.clone())) + .await + .map_err(|e| HFError::Other(format!("Xet upload failed: {e}")))? + }, + }; + task_ids_in_order.push(handle.task_id()); +} +``` + +Add `use std::collections::HashMap;` to the top of `xet.rs` if not already present. + +- [ ] **Step 3: Update the polling loop to emit per-file progress** + +Access the per-file upload handles via `commit`'s internal `file_handles` field. 
The `XetUploadCommit` stores file handles in a `Mutex>` — we need to read them in the poll loop. + +Check if `XetUploadCommit` exposes the file handles publicly. If not, we'll need to store our own `Vec` from the return values. + +Actually, `upload_from_path` and `upload_bytes` both return `XetFileUpload`. We already have the handles — we just don't keep them. Change the loop to collect them: + +```rust +let mut upload_handles: Vec<(String, XetFileUpload)> = Vec::with_capacity(files.len()); + +for (path_in_repo, source) in files { + tracing::info!(path = path_in_repo.as_str(), "queuing xet upload"); + let handle = match source { + AddSource::File(path) => { + if let Ok(abs) = std::path::absolute(path) { + if let Some(s) = abs.to_str() { + item_name_to_repo_path.insert(s.to_owned(), path_in_repo.clone()); + } + } + commit + .upload_from_path(path.clone(), Sha256Policy::Compute) + .await + .map_err(|e| HFError::Other(format!("Xet upload failed: {e}")))? + }, + AddSource::Bytes(bytes) => { + item_name_to_repo_path.insert(path_in_repo.clone(), path_in_repo.clone()); + commit + .upload_bytes(bytes.clone(), Sha256Policy::Compute, Some(path_in_repo.clone())) + .await + .map_err(|e| HFError::Other(format!("Xet upload failed: {e}")))? + }, + }; + task_ids_in_order.push(handle.task_id()); + upload_handles.push((path_in_repo.clone(), handle)); +} +``` + +Wait — `XetFileUpload` may not be `Clone` or `Send` in a way that lets us share it with the polling task. Let me check. Looking at the xet-core source, `XetFileUpload` wraps an `Arc` and a `TaskRuntime`, both of which are `Clone + Send + Sync`. And `XetFileUpload` itself has `.progress() -> Option` which is lock-free (atomic reads). So we can clone the handles and share them. 
+ +Replace the poll loop at lines ~418-433: + +```rust +tracing::info!(file_count = files.len(), "committing xet uploads"); +let poll_handle = progress.as_ref().map(|handler| { + let handler = handler.clone(); + let commit = commit.clone(); + let handles = upload_handles.clone(); + let name_map = item_name_to_repo_path.clone(); + tokio::spawn(async move { + loop { + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + let report = commit.progress(); + + let mut file_progress: Vec<FileProgress> = Vec::new(); + for (_repo_path, handle) in &handles { + if let Some(item_report) = handle.progress() { + let repo_path = name_map + .get(&item_report.item_name) + .cloned() + .unwrap_or(item_report.item_name.clone()); + let status = if item_report.bytes_completed == 0 { + FileStatus::Started + } else if item_report.bytes_completed >= item_report.total_bytes + && item_report.total_bytes > 0 + { + FileStatus::Complete + } else { + FileStatus::InProgress + }; + file_progress.push(FileProgress { + filename: repo_path, + bytes_completed: item_report.bytes_completed, + total_bytes: item_report.total_bytes, + status, + }); + } + } + + handler.on_progress(&ProgressEvent::Upload(UploadEvent::Progress { + phase: UploadPhase::Uploading, + bytes_completed: report.total_bytes_completed, + total_bytes: report.total_bytes, + bytes_per_sec: report.total_bytes_completion_rate, + files: file_progress, + })); + } + }) +}); +``` + +- [ ] **Step 4: Update the final progress emit after commit completes** + +The emit at ~line 442 also needs per-file data. 
At this point all files are complete, so emit all files as `FileStatus::Complete`: + +```rust +let final_files: Vec = upload_handles + .iter() + .map(|(repo_path, _)| FileProgress { + filename: repo_path.clone(), + bytes_completed: 0, // exact value doesn't matter, status is Complete + total_bytes: 0, + status: FileStatus::Complete, + }) + .collect(); + +progress::emit( + progress, + ProgressEvent::Upload(UploadEvent::Progress { + phase: UploadPhase::Uploading, + bytes_completed: results.progress.total_bytes_completed, + total_bytes: results.progress.total_bytes, + bytes_per_sec: results.progress.total_bytes_completion_rate, + files: final_files, + }), +); +``` + +- [ ] **Step 5: Verify compilation** + +Run: `cargo build -p huggingface-hub --all-features` +Expected: compiles successfully. + +- [ ] **Step 6: Run fmt and clippy** + +```bash +cargo +nightly fmt +cargo clippy -p huggingface-hub --all-features -- -D warnings +``` + +- [ ] **Step 7: Commit** + +```bash +git add huggingface_hub/src/xet.rs +git commit -m "feat: emit per-file progress from xet upload polling loop" +``` + +--- + +### Task 3: Render per-file upload progress bars in the CLI + +**Files:** +- Modify: `huggingface_hub/src/bin/hfrs/progress.rs` + +- [ ] **Step 1: Add upload file bar tracking to `ProgressState`** + +Add a field to track the upload file bars and a constant for the max visible bars. In `progress.rs`: + +```rust +const MAX_VISIBLE_UPLOAD_BARS: usize = 10; +``` + +Add to `ProgressState`: + +```rust +struct ProgressState { + files_bar: Option, + bytes_bar: Option, + file_bars: HashMap, + upload_file_bars: HashMap, + last_upload_phase: Option, + spinner: Option, + total_files: usize, +} +``` + +Initialize `upload_file_bars: HashMap::new()` in `CliProgressHandler::new`. 
+ +- [ ] **Step 2: Update `handle_upload` to process per-file progress** + +In the `UploadEvent::Progress` match arm, after the existing aggregate bar update (lines ~224-229), add per-file bar management: + +```rust +if *phase == UploadPhase::Uploading { + // Update aggregate bar + if let Some(ref bar) = state.bytes_bar { + bar.set_length(*total_bytes); + bar.set_position(*bytes_completed); + } + + // Per-file bars + for fp in files { + match fp.status { + FileStatus::Started => { + if !state.upload_file_bars.contains_key(&fp.filename) + && state.upload_file_bars.len() < MAX_VISIBLE_UPLOAD_BARS + { + let bar = self.multi.add(ProgressBar::new(fp.total_bytes)); + bar.set_style(bytes_style()); + bar.set_message(truncate_filename(&fp.filename, 40)); + state.upload_file_bars.insert(fp.filename.clone(), bar); + } + }, + FileStatus::InProgress => { + // Create bar if we haven't yet and there's room + if !state.upload_file_bars.contains_key(&fp.filename) + && state.upload_file_bars.len() < MAX_VISIBLE_UPLOAD_BARS + { + let bar = self.multi.add(ProgressBar::new(fp.total_bytes)); + bar.set_style(bytes_style()); + bar.set_message(truncate_filename(&fp.filename, 40)); + state.upload_file_bars.insert(fp.filename.clone(), bar); + } + if let Some(bar) = state.upload_file_bars.get(&fp.filename) { + bar.set_position(fp.bytes_completed); + } + }, + FileStatus::Complete => { + if let Some(bar) = state.upload_file_bars.remove(&fp.filename) { + bar.finish_and_clear(); + self.multi.remove(&bar); + } + if let Some(ref bar) = state.files_bar { + bar.inc(1); + } + }, + } + } +} +``` + +- [ ] **Step 3: Clean up upload file bars on `UploadEvent::Complete`** + +In the `Complete` match arm, add cleanup for upload file bars: + +```rust +UploadEvent::Complete => { + if let Some(spinner) = state.spinner.take() { + spinner.finish_and_clear(); + self.multi.remove(&spinner); + } + if let Some(bar) = state.files_bar.take() { + bar.finish_and_clear(); + self.multi.remove(&bar); + } + if let Some(bar) 
= state.bytes_bar.take() { + bar.finish_and_clear(); + self.multi.remove(&bar); + } + for (_, bar) in state.upload_file_bars.drain() { + bar.finish_and_clear(); + self.multi.remove(&bar); + } +}, +``` + +- [ ] **Step 4: Remove per-file bars when transitioning away from Uploading phase** + +When the phase changes from `Uploading` to `Committing`, clear all per-file upload bars. In the phase transition block (~line 179), add to the `Committing` arm before creating the spinner: + +```rust +UploadPhase::Committing => { + // Clear per-file upload bars + for (_, bar) in state.upload_file_bars.drain() { + bar.finish_and_clear(); + self.multi.remove(&bar); + } + if let Some(ref bar) = state.bytes_bar { + bar.set_position(bar.length().unwrap_or(0)); + bar.finish_and_clear(); + self.multi.remove(bar); + } + state.bytes_bar = None; + let bar = self.multi.add(ProgressBar::new_spinner()); + bar.set_style(spinner_style()); + bar.set_message("Creating commit..."); + bar.enable_steady_tick(std::time::Duration::from_millis(100)); + state.spinner = Some(bar); +}, +``` + +- [ ] **Step 5: Verify compilation** + +Run: `cargo build -p huggingface-hub --all-features` +Expected: compiles successfully. + +- [ ] **Step 6: Run fmt and clippy** + +```bash +cargo +nightly fmt +cargo clippy -p huggingface-hub --all-features -- -D warnings +``` + +- [ ] **Step 7: Commit** + +```bash +git add huggingface_hub/src/bin/hfrs/progress.rs +git commit -m "feat: render per-file upload progress bars in CLI (max 10 visible)" +``` + +--- + +### Task 4: Remove aggregate bytes bar for multi-file uploads + +Now that we have per-file bars, showing both per-file bars and an aggregate bytes bar is redundant for multi-file uploads. Keep the aggregate bar only for single-file uploads (where we don't show per-file bars anyway). 
+ +**Files:** +- Modify: `huggingface_hub/src/bin/hfrs/progress.rs` + +- [ ] **Step 1: Conditionally create the aggregate bytes bar** + +In `UploadEvent::Start`, only create the bytes bar when there's a single file: + +```rust +UploadEvent::Start { + total_files, + total_bytes, +} => { + state.total_files = *total_files; + if *total_files > 1 { + let bar = self.multi.add(ProgressBar::new(*total_files as u64)); + bar.set_style(files_style()); + bar.set_message(format!("Upload {} files", total_files)); + state.files_bar = Some(bar); + } + if *total_bytes > 0 && *total_files <= 1 { + let bar = self.multi.add(ProgressBar::new(*total_bytes)); + bar.set_style(bytes_style()); + bar.set_message("Uploading"); + state.bytes_bar = Some(bar); + } +}, +``` + +- [ ] **Step 2: Verify compilation and run tests** + +```bash +cargo build -p huggingface-hub --all-features +cargo test -p huggingface-hub +``` + +- [ ] **Step 3: Run fmt and clippy** + +```bash +cargo +nightly fmt +cargo clippy -p huggingface-hub --all-features -- -D warnings +``` + +- [ ] **Step 4: Commit** + +```bash +git add huggingface_hub/src/bin/hfrs/progress.rs +git commit -m "feat: show aggregate bytes bar only for single-file uploads" +``` + +--- + +### Task 5: Manual verification + +- [ ] **Step 1: Build the binary in release mode** + +```bash +cargo build -p huggingface-hub --release --all-features +``` + +- [ ] **Step 2: Test single file upload** + +Upload a single file to a test repo. Verify: +- Spinner shows for Preparing and CheckingUploadMode phases +- Single aggregate bytes bar shows during Uploading (no per-file bar) +- Spinner shows for Committing phase + +- [ ] **Step 3: Test multi-file upload** + +Upload 3+ LFS files. 
Verify: +- Files bar shows "Upload N files" count +- Per-file bars appear during Uploading phase with repo-relative filenames +- Completed file bars disappear, making room for pending files +- No aggregate bytes bar shown +- Phase transitions (Committing spinner) clean up all bars + +- [ ] **Step 4: Test with >10 files** + +Upload 12+ LFS files. Verify: +- Max 10 per-file bars visible at once +- As files complete and bars are removed, new files get bars +- Files bar count increments correctly + +--- + +## Design Decisions + +1. **`files` field is a `Vec`, not `Option`**: Empty vec for phases without per-file data (Preparing, CheckingUploadMode, Committing). Avoids Option nesting. + +2. **Name mapping via `HashMap`**: Maps xet-core's `item_name` (absolute local path for file uploads, `path_in_repo` for bytes uploads) to `path_in_repo`. Comment documents that we mimic xet-core's `std::path::absolute()` logic. + +3. **Regular files skip progress bars**: They're base64-inlined in the NDJSON body — there's no streaming to track. + +4. **Max 10 visible bars**: Completed bars are removed from display immediately, freeing slots for pending files. The files bar serves as the summary line showing overall file count progress. + +5. **Aggregate bytes bar only for single-file uploads**: When per-file bars are shown, the aggregate is redundant. diff --git a/docs/superpowers/plans/2026-04-11-upload-progress-bars.md b/docs/superpowers/plans/2026-04-11-upload-progress-bars.md new file mode 100644 index 0000000..aa7e854 --- /dev/null +++ b/docs/superpowers/plans/2026-04-11-upload-progress-bars.md @@ -0,0 +1,689 @@ +# Upload Progress Bars: Processing vs Transfer + Overflow Aggregation + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. 
+ +**Goal:** Show two summary bars (processing + transfer) during xet uploads and collapse overflow per-file bars into `[+ N files]`, matching the Python `huggingface_hub` CLI. + +**Architecture:** Add `transfer_bytes*` fields to `UploadEvent::Progress` so any library consumer gets both byte streams. Wire them from xet-core's `GroupProgressReport` in the polling loop. Update the CLI renderer to show two summary bars and use a fixed slot pool with overflow aggregation. + +**Tech Stack:** Rust, indicatif, indexmap + +**Spec:** `docs/superpowers/specs/2026-04-11-upload-progress-bars-design.md` + +--- + +### File Structure + +| File | Role | Change | +|------|------|--------| +| `huggingface_hub/src/types/progress.rs` | Library progress types | Add 3 fields to `UploadEvent::Progress`, update all tests | +| `huggingface_hub/src/xet.rs` | Xet upload polling | Populate new transfer fields from `GroupProgressReport` | +| `huggingface_hub/src/api/files.rs` | Non-xet upload emits | Add zeroed transfer fields to 3 emit sites | +| `huggingface_hub/src/bin/hfrs/progress.rs` | CLI renderer | Two summary bars, fixed slot pool, overflow aggregation | +| `huggingface_hub/Cargo.toml` | Dependencies | Add `indexmap` | + +--- + +### Task 1: Add transfer fields to `UploadEvent::Progress` + +**Files:** +- Modify: `huggingface_hub/src/types/progress.rs` + +- [ ] **Step 1: Add the three new fields to `UploadEvent::Progress`** + +In `huggingface_hub/src/types/progress.rs`, change the `Progress` variant from: + +```rust + Progress { + phase: UploadPhase, + bytes_completed: u64, + total_bytes: u64, + bytes_per_sec: Option, + files: Vec, + }, +``` + +to: + +```rust + Progress { + phase: UploadPhase, + bytes_completed: u64, + total_bytes: u64, + bytes_per_sec: Option, + transfer_bytes_completed: u64, + transfer_bytes: u64, + transfer_bytes_per_sec: Option, + files: Vec, + }, +``` + +- [ ] **Step 2: Fix all test sites that construct `UploadEvent::Progress`** + +There are 4 tests in 
`progress.rs` that construct `UploadEvent::Progress`. Add `transfer_bytes_completed: 0, transfer_bytes: 0, transfer_bytes_per_sec: None,` to each: + +1. `emit_records_events` test (~line 161): +```rust + emit( + &progress, + ProgressEvent::Upload(UploadEvent::Progress { + phase: UploadPhase::Uploading, + bytes_completed: 512, + total_bytes: 1024, + bytes_per_sec: Some(100.0), + transfer_bytes_completed: 0, + transfer_bytes: 0, + transfer_bytes_per_sec: None, + files: vec![], + }), + ); +``` + +2. `upload_phase_progression` test (~line 244): +```rust + emit( + &progress, + ProgressEvent::Upload(UploadEvent::Progress { + phase: phase.clone(), + bytes_completed: 0, + total_bytes: 100, + bytes_per_sec: None, + transfer_bytes_completed: 0, + transfer_bytes: 0, + transfer_bytes_per_sec: None, + files: vec![], + }), + ); +``` + +3. `upload_progress_with_per_file_data` test (~line 272): +```rust + emit( + &progress, + ProgressEvent::Upload(UploadEvent::Progress { + phase: UploadPhase::Uploading, + bytes_completed: 500, + total_bytes: 1000, + bytes_per_sec: Some(100.0), + transfer_bytes_completed: 250, + transfer_bytes: 800, + transfer_bytes_per_sec: Some(50.0), + files: vec![ + FileProgress { + filename: "model/weights.bin".to_string(), + bytes_completed: 300, + total_bytes: 600, + status: FileStatus::InProgress, + }, + FileProgress { + filename: "config.json".to_string(), + bytes_completed: 200, + total_bytes: 400, + status: FileStatus::InProgress, + }, + ], + }), + ); +``` + +Also update the assertion in that test to verify the new fields: +```rust + if let ProgressEvent::Upload(UploadEvent::Progress { + files, + transfer_bytes_completed, + transfer_bytes, + transfer_bytes_per_sec, + .. 
+ }) = &events[0] + { + assert_eq!(files.len(), 2); + assert_eq!(files[0].filename, "model/weights.bin"); + assert_eq!(files[1].filename, "config.json"); + assert_eq!(*transfer_bytes_completed, 250); + assert_eq!(*transfer_bytes, 800); + assert_eq!(*transfer_bytes_per_sec, Some(50.0)); + } else { + panic!("expected Upload(Progress)"); + } +``` + +- [ ] **Step 3: Verify compilation and tests pass** + +Run: +```bash +cargo test -p huggingface-hub --lib -- progress +``` + +Expected: All progress tests pass. + +- [ ] **Step 4: Commit** + +```bash +git add huggingface_hub/src/types/progress.rs +git commit -m "feat: add transfer byte fields to UploadEvent::Progress" +``` + +--- + +### Task 2: Wire transfer fields through emit sites + +**Files:** +- Modify: `huggingface_hub/src/xet.rs` +- Modify: `huggingface_hub/src/api/files.rs` + +- [ ] **Step 1: Update the polling loop emit in `xet.rs`** + +In `huggingface_hub/src/xet.rs`, in the polling loop (~line 577), change: + +```rust + handler.on_progress(&ProgressEvent::Upload(UploadEvent::Progress { + phase: UploadPhase::Uploading, + bytes_completed: report.total_bytes_completed, + total_bytes: report.total_bytes, + bytes_per_sec: report.total_bytes_completion_rate, + files: file_progress, + })); +``` + +to: + +```rust + handler.on_progress(&ProgressEvent::Upload(UploadEvent::Progress { + phase: UploadPhase::Uploading, + bytes_completed: report.total_bytes_completed, + total_bytes: report.total_bytes, + bytes_per_sec: report.total_bytes_completion_rate, + transfer_bytes_completed: report.total_transfer_bytes_completed, + transfer_bytes: report.total_transfer_bytes, + transfer_bytes_per_sec: report.total_transfer_bytes_completion_rate, + files: file_progress, + })); +``` + +- [ ] **Step 2: Update the final emit after commit in `xet.rs`** + +In `huggingface_hub/src/xet.rs`, the final emit (~line 606), change: + +```rust + progress::emit( + progress, + ProgressEvent::Upload(UploadEvent::Progress { + phase: 
UploadPhase::Uploading, + bytes_completed: results.progress.total_bytes_completed, + total_bytes: results.progress.total_bytes, + bytes_per_sec: results.progress.total_bytes_completion_rate, + files: final_files, + }), + ); +``` + +to: + +```rust + progress::emit( + progress, + ProgressEvent::Upload(UploadEvent::Progress { + phase: UploadPhase::Uploading, + bytes_completed: results.progress.total_bytes_completed, + total_bytes: results.progress.total_bytes, + bytes_per_sec: results.progress.total_bytes_completion_rate, + transfer_bytes_completed: results.progress.total_transfer_bytes_completed, + transfer_bytes: results.progress.total_transfer_bytes, + transfer_bytes_per_sec: results.progress.total_transfer_bytes_completion_rate, + files: final_files, + }), + ); +``` + +- [ ] **Step 3: Update the three non-uploading phase emits in `files.rs`** + +In `huggingface_hub/src/api/files.rs`, there are three `UploadEvent::Progress` emits for Preparing (~line 1150), CheckingUploadMode (~line 1164), and Committing (~line 1231). Add zeroed transfer fields to each. For example, the Preparing emit changes from: + +```rust + progress::emit( + &params.progress, + ProgressEvent::Upload(UploadEvent::Progress { + phase: UploadPhase::Preparing, + bytes_completed: 0, + total_bytes, + bytes_per_sec: None, + files: vec![], + }), + ); +``` + +to: + +```rust + progress::emit( + &params.progress, + ProgressEvent::Upload(UploadEvent::Progress { + phase: UploadPhase::Preparing, + bytes_completed: 0, + total_bytes, + bytes_per_sec: None, + transfer_bytes_completed: 0, + transfer_bytes: 0, + transfer_bytes_per_sec: None, + files: vec![], + }), + ); +``` + +Apply the same pattern to the CheckingUploadMode and Committing emits. + +- [ ] **Step 4: Verify compilation** + +Run: +```bash +cargo clippy -p huggingface-hub --all-features -- -D warnings +``` + +Expected: No errors or warnings. 
+ +- [ ] **Step 5: Commit** + +```bash +git add huggingface_hub/src/xet.rs huggingface_hub/src/api/files.rs +git commit -m "feat: populate transfer byte fields from GroupProgressReport" +``` + +--- + +### Task 3: Add `indexmap` dependency and update CLI renderer state + +**Files:** +- Modify: `huggingface_hub/Cargo.toml` +- Modify: `huggingface_hub/src/bin/hfrs/progress.rs` + +- [ ] **Step 1: Add `indexmap` to `huggingface_hub/Cargo.toml`** + +Add `indexmap` in the `[dependencies]` section (alphabetical order): + +```toml +indexmap = "2" +``` + +- [ ] **Step 2: Replace upload state fields in `ProgressState`** + +In `huggingface_hub/src/bin/hfrs/progress.rs`, update the imports at the top: + +```rust +use std::collections::{HashMap, HashSet, VecDeque}; +use std::io::Write; +use std::sync::Mutex; + +use huggingface_hub::{ + DownloadEvent, FileProgress, FileStatus, ProgressEvent, ProgressHandler, UploadEvent, UploadPhase, +}; +use indexmap::IndexMap; +use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; +``` + +Replace `ProgressState` with: + +```rust +struct ProgressState { + // Download state (unchanged) + files_bar: Option, + bytes_bar: Option, + file_bars: HashMap, + download_queue: VecDeque<(String, u64)>, + total_files: usize, + // Upload state (new) + processing_bar: Option, + transfer_bar: Option, + upload_file_slots: Vec>, + upload_active_files: IndexMap, + upload_known_files: HashSet, + upload_completed_files: HashSet, + last_upload_phase: Option, + spinner: Option, + upload_total_files: usize, +} +``` + +Update the `CliProgressHandler::new()` constructor to initialize the new fields: + +```rust + pub fn new(multi: MultiProgress) -> Self { + Self { + multi, + state: Mutex::new(ProgressState { + files_bar: None, + bytes_bar: None, + file_bars: HashMap::new(), + download_queue: VecDeque::new(), + total_files: 0, + processing_bar: None, + transfer_bar: None, + upload_file_slots: Vec::new(), + upload_active_files: IndexMap::new(), + upload_known_files: 
HashSet::new(), + upload_completed_files: HashSet::new(), + last_upload_phase: None, + spinner: None, + upload_total_files: 0, + }), + } + } +``` + +- [ ] **Step 3: Verify it compiles (will have dead code warnings, that's expected)** + +Run: +```bash +cargo check -p huggingface-hub --features cli +``` + +Expected: Compiles (may have warnings about unused fields — fixed in next task). + +- [ ] **Step 4: Commit** + +```bash +git add huggingface_hub/Cargo.toml huggingface_hub/src/bin/hfrs/progress.rs +git commit -m "refactor: replace upload progress state with fixed slot pool and indexmap" +``` + +--- + +### Task 4: Rewrite `handle_upload` with two summary bars and overflow aggregation + +**Files:** +- Modify: `huggingface_hub/src/bin/hfrs/progress.rs` + +This is the main rendering task. Replace the entire `handle_upload` and `process_upload_file_progress` methods. + +- [ ] **Step 1: Replace `handle_upload` method** + +Replace the existing `handle_upload` method (and remove `process_upload_file_progress`) with: + +```rust + fn handle_upload(&self, event: &UploadEvent) { + let mut state = self.state.lock().unwrap_or_else(|e| e.into_inner()); + match event { + UploadEvent::Start { + total_files, + total_bytes: _, + } => { + state.upload_total_files = *total_files; + if *total_files > 1 { + let bar = self.multi.add(ProgressBar::new(*total_files as u64)); + bar.set_style(files_style()); + bar.set_message(format!("Upload {} files", total_files)); + state.files_bar = Some(bar); + } + }, + UploadEvent::Progress { + phase, + bytes_completed, + total_bytes, + bytes_per_sec, + transfer_bytes_completed, + transfer_bytes, + transfer_bytes_per_sec, + files, + } => { + if state.last_upload_phase.as_ref() != Some(phase) { + if let Some(ref spinner) = state.spinner { + spinner.finish_and_clear(); + self.multi.remove(spinner); + state.spinner = None; + } + match phase { + UploadPhase::Preparing => { + let bar = self.multi.add(ProgressBar::new_spinner()); + 
bar.set_style(spinner_style()); + bar.set_message("Preparing files..."); + bar.enable_steady_tick(std::time::Duration::from_millis(100)); + state.spinner = Some(bar); + }, + UploadPhase::CheckingUploadMode => { + let bar = self.multi.add(ProgressBar::new_spinner()); + bar.set_style(spinner_style()); + bar.set_message("Checking upload mode..."); + bar.enable_steady_tick(std::time::Duration::from_millis(100)); + state.spinner = Some(bar); + }, + UploadPhase::Uploading => { + // Spinner already cleared above + }, + UploadPhase::Committing => { + self.cleanup_upload_bars(&mut state); + let bar = self.multi.add(ProgressBar::new_spinner()); + bar.set_style(spinner_style()); + bar.set_message("Creating commit..."); + bar.enable_steady_tick(std::time::Duration::from_millis(100)); + state.spinner = Some(bar); + }, + } + state.last_upload_phase = Some(phase.clone()); + } + + if *phase == UploadPhase::Uploading { + // Update or create the processing bar + let completed_count = state.upload_completed_files.len(); + let total_count = state.upload_total_files; + if state.processing_bar.is_none() && *total_bytes > 0 { + let bar = self.multi.add(ProgressBar::new(*total_bytes)); + bar.set_style(bytes_style()); + state.processing_bar = Some(bar); + } + if let Some(ref bar) = state.processing_bar { + bar.set_length(*total_bytes); + bar.set_position(*bytes_completed); + bar.set_message(format!( + "Processing Files ({} / {})", + completed_count, total_count + )); + } + + // Update or create the transfer bar + if state.transfer_bar.is_none() && *transfer_bytes > 0 { + let bar = self.multi.add(ProgressBar::new(*transfer_bytes)); + bar.set_style(bytes_style()); + bar.set_message("New Data Upload"); + state.transfer_bar = Some(bar); + } + if let Some(ref bar) = state.transfer_bar { + bar.set_length(*transfer_bytes); + bar.set_position(*transfer_bytes_completed); + } + + // Update per-file progress + for fp in files { + state.upload_known_files.insert(fp.filename.clone()); + + if 
fp.bytes_completed == 0 { + continue; + } + + if fp.status == FileStatus::Complete { + state.upload_completed_files.insert(fp.filename.clone()); + } + + state + .upload_active_files + .insert(fp.filename.clone(), fp.clone()); + } + + // Evict completed files from active map when we need room + if state.upload_active_files.len() > MAX_VISIBLE_UPLOAD_BARS { + let completed: Vec<String> = state + .upload_active_files + .keys() + .filter(|k| state.upload_completed_files.contains(*k)) + .cloned() + .collect(); + for name in completed { + state.upload_active_files.swap_remove(&name); + if state.upload_active_files.len() <= MAX_VISIBLE_UPLOAD_BARS { + break; + } + } + } + + // Render the fixed slot pool + self.render_upload_file_slots(&mut state); + } + + // Suppress unused variable warnings for rate fields + let _ = bytes_per_sec; + let _ = transfer_bytes_per_sec; + }, + UploadEvent::FileComplete { .. } => { + // File completion is tracked via FileStatus::Complete in Progress events + }, + UploadEvent::Complete => { + self.cleanup_upload_bars(&mut state); + if let Some(spinner) = state.spinner.take() { + spinner.finish_and_clear(); + self.multi.remove(&spinner); + } + if let Some(bar) = state.files_bar.take() { + bar.finish_and_clear(); + self.multi.remove(&bar); + } + }, + } + } +``` + +- [ ] **Step 2: Add `render_upload_file_slots` method** + +Add this method to `CliProgressHandler`: + +```rust + fn render_upload_file_slots(&self, state: &mut ProgressState) { + let active_count = state.upload_active_files.len(); + let max_individual = if active_count > MAX_VISIBLE_UPLOAD_BARS { + MAX_VISIBLE_UPLOAD_BARS - 1 + } else { + active_count + }; + + // Ensure we have enough slots allocated + while state.upload_file_slots.len() < MAX_VISIBLE_UPLOAD_BARS { + state.upload_file_slots.push(None); + } + + let mut bar_idx = 0; + let mut overflow_bytes_completed: u64 = 0; + let mut overflow_total_bytes: u64 = 0; + let mut overflow_count: usize = 0; + + for (_name, fp) in 
&state.upload_active_files { + if bar_idx < max_individual { + // Individual file bar + let slot = &mut state.upload_file_slots[bar_idx]; + if let Some(ref bar) = slot { + bar.set_message(truncate_filename(&fp.filename, 40)); + bar.set_length(fp.total_bytes); + bar.set_position(fp.bytes_completed); + } else { + let bar = self.multi.add(ProgressBar::new(fp.total_bytes)); + bar.set_style(bytes_style()); + bar.set_message(truncate_filename(&fp.filename, 40)); + bar.set_position(fp.bytes_completed); + *slot = Some(bar); + } + } else { + // Overflow: accumulate into last slot + overflow_bytes_completed += fp.bytes_completed; + overflow_total_bytes += fp.total_bytes; + overflow_count += 1; + } + bar_idx += 1; + } + + // Render the overflow slot if needed + if overflow_count > 0 { + let slot_idx = MAX_VISIBLE_UPLOAD_BARS - 1; + let slot = &mut state.upload_file_slots[slot_idx]; + if let Some(ref bar) = slot { + bar.set_message(format!("[+ {} files]", overflow_count)); + bar.set_length(overflow_total_bytes); + bar.set_position(overflow_bytes_completed); + } else { + let bar = self.multi.add(ProgressBar::new(overflow_total_bytes)); + bar.set_style(bytes_style()); + bar.set_message(format!("[+ {} files]", overflow_count)); + bar.set_position(overflow_bytes_completed); + *slot = Some(bar); + } + } + + // Clear any slots beyond what we currently need + let needed_slots = if overflow_count > 0 { + MAX_VISIBLE_UPLOAD_BARS + } else { + active_count + }; + for i in needed_slots..state.upload_file_slots.len() { + if let Some(bar) = state.upload_file_slots[i].take() { + bar.finish_and_clear(); + self.multi.remove(&bar); + } + } + } +``` + +- [ ] **Step 3: Add `cleanup_upload_bars` method** + +Add this method to `CliProgressHandler`: + +```rust + fn cleanup_upload_bars(&self, state: &mut ProgressState) { + for slot in &mut state.upload_file_slots { + if let Some(bar) = slot.take() { + bar.finish_and_clear(); + self.multi.remove(&bar); + } + } + state.upload_active_files.clear(); + 
state.upload_known_files.clear(); + state.upload_completed_files.clear(); + if let Some(bar) = state.processing_bar.take() { + bar.finish_and_clear(); + self.multi.remove(&bar); + } + if let Some(bar) = state.transfer_bar.take() { + bar.finish_and_clear(); + self.multi.remove(&bar); + } + } +``` + +- [ ] **Step 4: Remove the old `process_upload_file_progress` method** + +Delete the `process_upload_file_progress` method entirely (it's been replaced by `render_upload_file_slots`). + +- [ ] **Step 5: Verify compilation and lint** + +Run: +```bash +cargo clippy -p huggingface-hub --all-features -- -D warnings +``` + +Expected: No errors or warnings. + +- [ ] **Step 6: Format** + +Run: +```bash +cargo +nightly fmt +``` + +- [ ] **Step 7: Build the release binary** + +Run: +```bash +cargo build -p huggingface-hub --release --features cli +``` + +- [ ] **Step 8: Commit** + +```bash +git add huggingface_hub/src/bin/hfrs/progress.rs +git commit -m "feat: two summary bars (processing + transfer) and overflow [+ N files] aggregation" +``` diff --git a/docs/superpowers/specs/2026-04-05-progress-tracking-design.md b/docs/superpowers/specs/2026-04-05-progress-tracking-design.md new file mode 100644 index 0000000..1936f4c --- /dev/null +++ b/docs/superpowers/specs/2026-04-05-progress-tracking-design.md @@ -0,0 +1,586 @@ +# Progress Tracking for Upload and Download + +**Date:** 2026-04-05 +**Status:** Draft + +## Overview + +Add callback-based progress tracking to the huggingface-hub library's upload and download interfaces. The library emits structured `ProgressEvent` variants to a caller-provided `ProgressHandler` trait object. The first consumer is the hfrs CLI, which renders indicatif progress bars matching the Python `hf` CLI's tqdm style. 
+ +## Goals + +- Per-file byte-level progress for downloads +- Aggregate byte-level progress for uploads (xet provides aggregate, not per-file) +- Distinguish upload phases: preparing, checking upload mode, uploading, committing +- Utilize hf-xet's `GroupProgressReport` and `ItemProgressReport` via polling bridge +- hfrs CLI progress bars visually match the Python `hf` CLI +- Fully additive — no breaking changes to existing API + +## Non-Goals + +- Per-file byte-level upload progress (xet `GroupProgressReport` only provides aggregate) +- Progress for non-transfer operations (repo creation, deletion, listing) +- Persistent progress state or resume tracking +- Custom bar format configuration in the library (that's the consumer's job) + +--- + +## Core Types + +New file: `huggingface_hub/src/types/progress.rs` + +```rust +use std::sync::Arc; + +/// Trait implemented by consumers to receive progress updates. +/// Implementations must be fast — avoid blocking I/O in on_progress(). +pub trait ProgressHandler: Send + Sync { + fn on_progress(&self, event: &ProgressEvent); +} + +/// A clonable, optional handle to a progress handler. +pub type Progress = Option>; + +/// Top-level progress event — either an upload or download event. +#[derive(Debug, Clone)] +pub enum ProgressEvent { + Upload(UploadEvent), + Download(DownloadEvent), +} + +/// Progress events for upload operations. +/// +/// Every variant that represents an in-progress state carries the current +/// `UploadPhase`, so consumers always know the phase from any single event +/// without tracking state across events. +#[derive(Debug, Clone)] +pub enum UploadEvent { + /// Upload operation has started; total file count and bytes are known. + Start { total_files: usize, total_bytes: u64 }, + /// Aggregate byte-level progress (xet/LFS upload). + /// Phase is included so consumers always know the current phase. 
+ Progress { + phase: UploadPhase, + bytes_completed: u64, + total_bytes: u64, + bytes_per_sec: Option, + }, + /// One or more individual files completed. Batched for efficiency + /// during multi-file uploads (upload_folder). + FileComplete { + files: Vec, + phase: UploadPhase, + }, + /// Entire upload operation finished (all files, commit created). + Complete, +} + +/// Progress events for download operations. +#[derive(Debug, Clone)] +pub enum DownloadEvent { + /// Download operation has started; file count and total bytes known. + Start { total_files: usize, total_bytes: u64 }, + /// Per-file progress update. Only includes files whose state changed + /// since the last event (delta, not full snapshot). Batched for + /// efficiency during multi-file downloads (snapshot_download). + Progress { files: Vec }, + /// Aggregate byte-level progress for xet batch transfers. + /// Separate from per-file Progress because xet provides aggregate + /// stats, not per-file byte counts. + AggregateProgress { + bytes_completed: u64, + total_bytes: u64, + bytes_per_sec: Option, + }, + /// All downloads finished. + Complete, +} + +/// Per-file progress info, used inside `DownloadEvent::Progress`. +#[derive(Debug, Clone)] +pub struct FileProgress { + pub filename: String, + pub bytes_completed: u64, + pub total_bytes: u64, + pub status: FileStatus, +} + +/// Lifecycle status of a single file within a transfer. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum FileStatus { + Started, + InProgress, + Complete, +} + +/// Phases of an upload operation, in order. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum UploadPhase { + /// Scanning local files and computing sizes. + Preparing, + /// Calling preupload API to classify files as LFS vs regular. + CheckingUploadMode, + /// Transferring file data (xet or inline). + Uploading, + /// Creating the commit on the Hub. + Committing, +} + +/// Emit a progress event if a handler is present. 
+pub(crate) fn emit(handler: &Progress, event: ProgressEvent) { + if let Some(h) = handler { + h.on_progress(&event); + } +} +``` + +### Design decisions + +- **Nested enum** (`ProgressEvent::Upload(UploadEvent)` / `Download(DownloadEvent)`) — separates upload and download concerns at the type level. Consumers can match on the outer enum to route, then match the inner enum for specifics. +- **Phase on every upload event** (not a separate `PhaseChange` variant) — each in-progress upload event carries the current `UploadPhase`, so consumers always know the phase from any single event. No state tracking needed, no lost-event problems. Phase transitions are detected by observing the phase field change across events. +- **Delta-only download progress** — `DownloadEvent::Progress` only includes files whose state changed since the last event, not a full snapshot of all in-flight files. Keeps event payloads small during large snapshot downloads. +- **Batched file events** — both `DownloadEvent::Progress` and `UploadEvent::FileComplete` carry `Vec` payloads, allowing multiple file updates in a single event. This supports condensed reporting for `upload_folder` and `snapshot_download` without flooding the handler. +- **Separate `AggregateProgress`** for xet downloads — xet provides aggregate byte stats, not per-file byte counts. Keeping this as a distinct variant from per-file `Progress` avoids conflating two different data sources. +- **`Progress` type alias** (`Option<Arc<dyn ProgressHandler>>`) keeps params structs clean. +- **`bytes_per_sec` is `Option`** because xet requires >=4 observations before reporting rates. +- **`emit()` helper** avoids `if let Some(h) = &progress { ... }` at every call site. 
+ +--- + +## Library Integration Points + +### Params struct changes + +Add `progress: Progress` with `#[builder(default)]` to these existing structs in `repository.rs`: + +- `RepoUploadFileParams` +- `RepoUploadFolderParams` +- `RepoCreateCommitParams` +- `RepoDownloadFileParams` +- `RepoSnapshotDownloadParams` + +Existing callers are unaffected — `progress` defaults to `None`. + +### Upload flow (`api/files.rs` — `create_commit`) + +Event emission points in the existing flow: + +``` +1. Collect operations, compute total sizes + → emit Upload(Start { total_files, total_bytes }) + +2. Call preupload_and_upload_lfs_files() to classify files + → emit Upload(Progress { phase: CheckingUploadMode, bytes_completed: 0, ... }) + +3. Upload LFS files via xet_upload() + → spawn xet polling task (100ms interval) + → emit Upload(Progress { phase: Uploading, bytes_completed, total_bytes, bytes_per_sec }) + → as xet reports per-item completions: + → emit Upload(FileComplete { files: [...], phase: Uploading }) + +4. POST /api/.../commit + → emit Upload(Progress { phase: Committing, bytes_completed: total_bytes, ... }) + +5. Return CommitInfo + → emit Upload(Complete) +``` + +Phase transitions are implicit — the consumer sees `phase` change from `Preparing` → `CheckingUploadMode` → `Uploading` → `Committing` across successive events. No dedicated phase-change event needed. 
+ +### Xet polling bridge (`xet.rs`) + +Internal helper that polls xet progress and forwards to the handler: + +```rust +fn spawn_xet_upload_poller( + commit: XetUploadCommit, + handler: Arc, +) -> tokio::task::JoinHandle<()> { + tokio::spawn(async move { + let mut prev_completed: HashSet = HashSet::new(); + loop { + let report = commit.progress(); + // Emit aggregate byte progress + handler.on_progress(&ProgressEvent::Upload(UploadEvent::Progress { + phase: UploadPhase::Uploading, + bytes_completed: report.total_bytes_completed, + total_bytes: report.total_bytes, + bytes_per_sec: report.total_bytes_completion_rate, + })); + // Emit FileComplete for newly completed files (delta) + let newly_completed: Vec = report.items.iter() + .filter(|item| item.is_complete && !prev_completed.contains(&item.name)) + .map(|item| item.name.clone()) + .collect(); + if !newly_completed.is_empty() { + prev_completed.extend(newly_completed.iter().cloned()); + handler.on_progress(&ProgressEvent::Upload(UploadEvent::FileComplete { + files: newly_completed, + phase: UploadPhase::Uploading, + })); + } + tokio::time::sleep(Duration::from_millis(100)).await; + } + }) +} +``` + +A similar `spawn_xet_download_poller` emits `DownloadEvent::AggregateProgress` events from `XetFileDownloadGroup::progress()`, plus `DownloadEvent::Progress` with `FileStatus::Complete` entries for newly finished files. + +The polling task is aborted when the transfer completes (the caller holds the `JoinHandle` and calls `.abort()` after `.commit().await` / `.finish().await` returns). + +### Download flow — single file (`api/files.rs`) + +**Non-xet path:** + +``` +1. HEAD request → get ETag, content-length + → emit Download(Start { total_files: 1, total_bytes }) + +2. GET request → wrap response byte stream in ProgressStream adapter + → emit Download(Progress { files: [FileProgress { status: InProgress, ... }] }) + on each chunk (single-element vec) + +3. 
File written + → emit Download(Progress { files: [FileProgress { status: Complete, ... }] }) + → emit Download(Complete) +``` + +**Xet path:** + +``` +1. HEAD request → detect X-Xet-Hash + → emit Download(Start { total_files: 1, total_bytes }) + +2. xet_download_to_local_dir / xet_download_to_blob + → spawn xet download poller + → emit Download(AggregateProgress { bytes_completed, total_bytes, bytes_per_sec }) + +3. Complete + → emit Download(Progress { files: [FileProgress { status: Complete, ... }] }) + → emit Download(Complete) +``` + +### Download flow — snapshot (`api/files.rs`) + +``` +1. List files via get_paths_info / list_tree + → emit Download(Start { total_files, total_bytes }) + +2. Non-xet files: parallel workers (buffer_unordered, max_workers) + Each worker emits Download(Progress { files: [FileProgress { ... }] }) + with a single-element vec per file per chunk. + On completion: FileProgress { status: Complete } + (handler is Arc, shared safely across tasks) + +3. Xet files: xet_download_batch + → spawn xet download poller + → emit Download(AggregateProgress { ... }) for aggregate bytes + → emit Download(Progress { files: [...] }) with FileStatus::Complete + for newly finished files (delta — only files completed since last poll) + +4. All done → emit Download(Complete) +``` + +### ProgressStream adapter + +A thin stream wrapper for non-xet HTTP downloads that emits per-chunk progress. Uses `pin-project-lite` for pin projection (already a transitive dependency via tokio). + +```rust +pin_project_lite::pin_project! 
{ + struct ProgressStream { + #[pin] + inner: S, + handler: Arc, + filename: String, + total_bytes: u64, + bytes_read: u64, + } +} + +impl>> Stream for ProgressStream { + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + match this.inner.poll_next(cx) { + Poll::Ready(Some(Ok(chunk))) => { + *this.bytes_read += chunk.len() as u64; + let status = if *this.bytes_read >= *this.total_bytes { + FileStatus::Complete + } else { + FileStatus::InProgress + }; + this.handler.on_progress(&ProgressEvent::Download(DownloadEvent::Progress { + files: vec![FileProgress { + filename: this.filename.clone(), + bytes_completed: *this.bytes_read, + total_bytes: *this.total_bytes, + status, + }], + })); + Poll::Ready(Some(Ok(chunk))) + } + other => other, + } + } +} +``` + +### Xet function signature changes + +All internal xet functions accept an additional `progress: &Progress` parameter: + +- `xet_upload(api, files, repo_id, repo_type, revision, progress)` +- `xet_download_batch(api, repo_id, repo_type, revision, files, progress)` +- `xet_download_to_local_dir(api, repo_id, ..., head_response, progress)` +- `xet_download_to_blob(api, repo_id, ..., file_hash, file_size, path, progress)` + +When `progress` is `None`, no polling task is spawned and the functions behave identically to current code. + +--- + +## hfrs CLI Progress Bars + +### Dependency + +Add to `Cargo.toml`: + +```toml +indicatif = { version = "0.17", optional = true } + +[features] +cli = [ + # ... existing deps ... + "dep:indicatif", +] +``` + +### Visual targets (matching Python `hf` CLI) + +The Python CLI uses standard tqdm formatting. 
The indicatif bars match this style: + +**Single file download:** +``` +config.json: 100%|██████████████████████████████████| 665/665 [00:00<00:00, 764kB/s] +``` + +**Multi-file download (two simultaneous bars):** +``` +Fetching 23 files: 65%|████████████████ | 15/23 [01:02<00:33, 4.12s/it] +model.safetensors: 45%|████████ | 2.22G/4.93G [01:02<01:16, 36.4MB/s] +``` + +**Multi-file upload (two bars):** +``` +Upload 5 LFS files: 40%|████████ | 2/5 [00:12<00:18, 6.1s/it] +model.safetensors: 100%|███████████████████████████| 4.93G/4.93G [02:15<00:00, 36.4MB/s] +``` + +**Formatting rules:** +- Filenames truncated at 40 chars with `(…)` prefix +- Byte bars use auto-scaled units (kB, MB, GB) +- File-count bars show items/sec +- All bars show `[elapsed, +} + +struct ProgressState { + /// Overall files bar (for multi-file operations) + files_bar: Option, + /// Aggregate bytes bar (for xet transfers or overall bytes tracking) + bytes_bar: Option, + /// Per-file bars for individual file byte progress + file_bars: HashMap, +} +``` + +### Event-to-bar mapping + +| Event | Bar action | +|---|---| +| `Download(Start { total_files > 1, .. })` | Create files bar: `"Fetching {n} files"` | +| `Download(Start { total_files == 1, .. })` | No files bar | +| `Download(Progress { files })` where status = `Started` | Create per-file bytes bar | +| `Download(Progress { files })` where status = `InProgress` | Update per-file bar position | +| `Download(Progress { files })` where status = `Complete` | Finish + remove per-file bar, increment files bar | +| `Download(AggregateProgress { .. })` | Update aggregate bytes bar (xet batch) | +| `Download(Complete)` | Finish all bars | +| `Upload(Start { total_files, .. })` | Create files bar: `"Upload {n} LFS files"` if >1 | +| `Upload(Progress { phase: Uploading, .. })` | Create/update aggregate bytes bar | +| `Upload(Progress { phase: Committing, .. })` | Finish bytes bar, show spinner: `"Creating commit..."` | +| `Upload(FileComplete { files, .. 
})` | Increment files bar by `files.len()` | +| `Upload(Complete)` | Finish all bars | + +The CLI handler detects phase transitions by comparing the `phase` field against its last-seen phase. When the phase changes (e.g., `Uploading` → `Committing`), it triggers bar transitions (finish bytes bar, start spinner). + +### indicatif style templates + +```rust +// Byte-level bar (matching tqdm default with unit="B", unit_scale=True) +let bytes_style = ProgressStyle::with_template( + "{msg}: {percent}%|{wide_bar}| {bytes}/{total_bytes} [{elapsed}<{eta}, {bytes_per_sec}]" +); + +// File-count bar (matching tqdm default with items) +let files_style = ProgressStyle::with_template( + "{msg}: {percent}%|{wide_bar}| {pos}/{len} [{elapsed}<{eta}, {per_sec}]" +); + +// Spinner for phases without byte progress +let spinner_style = ProgressStyle::with_template("{spinner} {msg}"); +``` + +### CLI wiring + +In `commands/download.rs` and `commands/upload.rs`: + +```rust +let progress: Progress = if args.quiet || env_var_disables_progress() { + None +} else { + Some(Arc::new(CliProgressHandler::new())) +}; +``` + +The `env_var_disables_progress()` check reads `HF_HUB_DISABLE_PROGRESS_BARS` to match Python behavior. + +### Multi-file download bar behavior + +For `snapshot_download` with mixed xet and non-xet files: + +- **Non-xet files**: each gets its own per-file bytes bar (created on `FileStatus::Started`, updated on `InProgress`, removed on `Complete`). Up to `max_workers` (default 8) per-file bars visible simultaneously. Each `Download(Progress)` event carries a single-element `files` vec per chunk. +- **Xet batch files**: a single aggregate bytes bar from `Download(AggregateProgress)` events. Individual completions arrive as `Download(Progress { files })` with `FileStatus::Complete` entries to increment the files counter. +- **Files count bar** stays at the top, counting completions from both paths. 
+ +--- + +## Testing Strategy + +### Unit tests + +**Location:** `huggingface_hub/src/types/progress.rs` (in `#[cfg(test)]` module) + +A `RecordingHandler` captures events for assertions: + +```rust +struct RecordingHandler { + events: Mutex>, +} +impl ProgressHandler for RecordingHandler { + fn on_progress(&self, event: &ProgressEvent) { + self.events.lock().unwrap().push(event.clone()); + } +} +``` + +Tests: +- **Event ordering**: upload emits `Upload(Start) → Upload(Progress { phase: Preparing }) → ... → Upload(Complete)` +- **Phase progression**: phases advance monotonically across events (Preparing → CheckingUploadMode → Uploading → Committing) +- **Download file lifecycle**: `Download(Progress { status: Started }) → Download(Progress { status: InProgress, increasing bytes }) → Download(Progress { status: Complete })` +- **Delta-only delivery**: download progress events only contain files that changed +- **Batched FileComplete**: upload FileComplete can carry multiple filenames +- **None handler is no-op**: `progress: None` doesn't panic or change behavior +- **Handler is Send + Sync**: compile-time check that `Arc` satisfies bounds + +### ProgressStream adapter tests + +```rust +#[tokio::test] +async fn test_progress_stream_emits_events() { + // Create synthetic byte stream, wrap in ProgressStream, + // consume all chunks, assert DownloadFileProgress events + // have monotonically increasing bytes_completed +} +``` + +### Integration tests + +**Location:** `huggingface_hub/tests/integration_test.rs` + +Gated on `HF_TOKEN` (and `HF_TEST_WRITE=1` for upload tests): + +- **Download with progress**: download a known small file, attach `RecordingHandler`, verify `Download(Start)`, at least one `Download(Progress)` with `InProgress`, and a final `Download(Progress)` with `Complete` + correct filename +- **Snapshot download with progress**: download 2-3 files, verify `total_files` matches, per-file progress events for each filename, all `Complete` before 
`Download(Complete)` +- **Upload with progress** (`HF_TEST_WRITE=1`): upload a small file, verify `Upload(Start)` through `Upload(Complete)` with phase progression across events + +### CLI manual verification + +No automated CLI rendering tests. Manual steps: + +1. `hfrs download gpt2 config.json` — one bytes bar appears and completes +2. `hfrs download gpt2` — files count bar + per-file bytes bars +3. `hfrs download gpt2 --quiet` — no bars, only path printed +4. `hfrs upload ./small-file.txt` — phase progression visible +5. `hfrs upload ./test-dir/` — files bar + bytes bar +6. `HF_HUB_DISABLE_PROGRESS_BARS=1 hfrs download gpt2 config.json` — bars suppressed +7. Large xet-backed repo — aggregate bytes bar with rate display + +--- + +## Module Structure + +### New files + +| File | Purpose | +|---|---| +| `huggingface_hub/src/types/progress.rs` | Core types: `ProgressEvent`, `UploadPhase`, `ProgressHandler`, `Progress`, `emit()` | +| `huggingface_hub/src/bin/hfrs/progress.rs` | `CliProgressHandler` with indicatif | + +### Modified files + +| File | Changes | +|---|---| +| `huggingface_hub/src/types/mod.rs` | Add `pub mod progress;` and re-export. `lib.rs` already does `pub use types::*` so progress types are automatically publicly re-exported — no changes needed in `lib.rs`. 
| +| `huggingface_hub/src/repository.rs` | Add `progress: Progress` to 5 params structs | +| `huggingface_hub/src/api/files.rs` | Emit events in `create_commit`, `download_file`, `snapshot_download`; add `ProgressStream` | +| `huggingface_hub/src/xet.rs` | Accept `progress: &Progress`, add xet polling helpers | +| `huggingface_hub/Cargo.toml` | Add `indicatif` optional dep under `cli` feature | +| `huggingface_hub/src/bin/hfrs/commands/download.rs` | Create handler, pass to params | +| `huggingface_hub/src/bin/hfrs/commands/upload.rs` | Create handler, pass to params | +| `huggingface_hub/src/bin/hfrs/main.rs` | Add `mod progress;` | +| `huggingface_hub/tests/integration_test.rs` | Add progress-tracking test variants | + +**Note:** The blocking API (`HFClientSync`, `HFRepositorySync`) is generated by `sync_api!` macros in `huggingface_hub/src/macros.rs`. Since this design passes `progress` as a field on the existing params structs (which are taken by `&` reference), the macro-generated blocking wrappers forward it transparently — no changes to `macros.rs` or `blocking.rs` are needed. + +### Methods not covered + +The following methods are **not** in scope for this design but could be added later: + +- `download_file_stream` — returns a raw byte stream `(Option, Box>)` where the consumer controls consumption. Progress could be added via a `ProgressStream` wrapper, but the caller already has direct access to chunk-level data. +- `download_file_to_bytes` — thin wrapper around `download_file_stream` that collects to `Bytes`. Would inherit progress if `download_file_stream` gained it. 
+ +### Public API additions + +All additive, no breaking changes: + +```rust +pub trait ProgressHandler: Send + Sync { + fn on_progress(&self, event: &ProgressEvent); +} +pub enum ProgressEvent { Upload(UploadEvent), Download(DownloadEvent) } +pub enum UploadEvent { Start, Progress, FileComplete, Complete } +pub enum DownloadEvent { Start, Progress, AggregateProgress, Complete } +pub struct FileProgress { filename, bytes_completed, total_bytes, status } +pub enum FileStatus { Started, InProgress, Complete } +pub enum UploadPhase { Preparing, CheckingUploadMode, Uploading, Committing } +pub type Progress = Option>; +``` + +### What does NOT change + +- Return types of all existing methods +- Behavior when `progress: None` (the default) +- Blocking wrapper (`HFClientSync`, etc.) — blocking methods are generated by `sync_api!` macros in `macros.rs` which forward params by reference; `progress` passes through as part of the params struct with no macro changes needed +- Feature flags — no new flag; `indicatif` is under existing `cli` feature diff --git a/docs/superpowers/specs/2026-04-11-upload-progress-bars-design.md b/docs/superpowers/specs/2026-04-11-upload-progress-bars-design.md new file mode 100644 index 0000000..37a3b7c --- /dev/null +++ b/docs/superpowers/specs/2026-04-11-upload-progress-bars-design.md @@ -0,0 +1,119 @@ +# Upload Progress Bars: Processing vs Transfer + Overflow Aggregation + +**Date:** 2026-04-11 +**Branch:** `assaf/progress` +**Status:** Design approved + +## Problem + +The Rust `hfrs` CLI shows a single aggregate bytes bar during xet uploads, but xet-core's `GroupProgressReport` actually tracks two distinct byte streams: + +1. **Processing bytes** (`total_bytes` / `total_bytes_completed`) — dedup/chunking work +2. **Transfer bytes** (`total_transfer_bytes` / `total_transfer_bytes_completed`) — actual network upload + +The Python `huggingface_hub` library shows both as separate bars via `XetProgressReporter`. 
Our Rust implementation currently ignores the transfer fields entirely. + +Additionally, when more than 10 files are being uploaded, overflow files are silently queued with no visual indication. Python collapses overflow into a `[+ N files]` aggregate bar in the last visible slot. + +## Goals + +- Expose processing and transfer byte progress through the library's `UploadEvent::Progress` type so any consumer can use both +- Render two summary bars in the CLI matching Python: "Processing Files (M / N)" and "New Data Upload" +- Show `[+ N files]` aggregate bar when active files exceed `MAX_VISIBLE_UPLOAD_BARS` (10) + +## Non-Goals + +- Changing download progress (unrelated) +- Changing `FileProgress`, `FileStatus`, or `ProgressHandler` types +- Notebook/console environment detection (Rust is CLI-only) + +## Design + +### 1. Library types — `UploadEvent::Progress` + +Add three fields to the `Progress` variant in `types/progress.rs`: + +```rust +Progress { + phase: UploadPhase, + // Processing/dedup bytes (existing) + bytes_completed: u64, + total_bytes: u64, + bytes_per_sec: Option, + // Actual network transfer bytes (new) + transfer_bytes_completed: u64, + transfer_bytes: u64, + transfer_bytes_per_sec: Option, + // Per-file progress (existing) + files: Vec, +} +``` + +The existing `bytes_completed` / `total_bytes` fields retain their names and map to xet-core's processing/dedup counters. The new `transfer_bytes*` fields map to `GroupProgressReport::total_transfer_bytes*`. + +### 2. Emit layer — `xet.rs` + `api/files.rs` + +**Polling loop (`xet.rs`):** Read both `total_bytes*` and `total_transfer_bytes*` from `GroupProgressReport` and pass through to the `Progress` event. + +**Final emit after commit (`xet.rs`):** Same — populate both sets of fields from `results.progress`. + +**Non-uploading phases (`files.rs`):** Emit `0` for all transfer fields (Preparing, CheckingUploadMode, Committing phases have no transfer activity). + +### 3. 
CLI renderer — `progress.rs` + +#### Two summary bars + +Replace the single `bytes_bar` with two bars during upload: + +- **`processing_bar`**: Style matches Python's "Processing Files (M / N)" — shows `bytes_completed / total_bytes` with `bytes_per_sec` rate. The description updates with completed file count vs total file count. +- **`transfer_bar`**: "New Data Upload" — shows `transfer_bytes_completed / transfer_bytes` with `transfer_bytes_per_sec` rate. + +Both bars are shown for all uploads (single-file and multi-file). They are created during the first `Uploading` phase progress event (not during `Start`, since transfer totals aren't known yet). + +#### Overflow aggregation + +Replace the current dynamic create/remove bar approach with a fixed slot pool: + +- `MAX_VISIBLE_UPLOAD_BARS` (10) bar slots, initially empty +- Each tick, iterate through active (non-complete) files in insertion order +- If active files <= 10, each gets its own slot with description and position updated in-place +- If active files > 10, first 9 get individual bars (slots 0-8), and slot 9 becomes `[+ N files]` showing combined bytes across all overflow files +- Completed files are removed from the active state only when `active_count > MAX_VISIBLE_UPLOAD_BARS`, to make room for incoming files (matches Python's eviction logic) +- Bar slots are reused by overwriting description/total/position rather than creating and removing bars from `MultiProgress` + +#### State changes + +`ProgressState` changes: +- Remove: `bytes_bar`, `upload_file_bars: HashMap`, `upload_completed_files: HashSet` +- Add: `processing_bar: Option`, `transfer_bar: Option` +- Add: `upload_file_slots: Vec>` (fixed size 10) +- Add: `upload_active_files: OrderedMap` (insertion-ordered, like Python's `OrderedDict`) +- Add: `upload_known_files: HashSet`, `upload_completed_files: HashSet` (for file counting) + +Note: Rust's `IndexMap` crate (or `BTreeMap` with insertion index) provides ordered map semantics. 
Since `indexmap` is already used transitively, it's a natural choice. + +#### Cleanup + +- `Committing` phase transition: close and remove both summary bars and all file slot bars +- `Complete` event: same cleanup as safety net + +### 4. Files changed + +| File | Change | +|------|--------| +| `huggingface_hub/src/types/progress.rs` | Add 3 fields to `UploadEvent::Progress`, update tests | +| `huggingface_hub/src/xet.rs` | Polling loop + final emit read both byte streams | +| `huggingface_hub/src/api/files.rs` | Non-uploading phase emits include transfer fields (zeroed) | +| `huggingface_hub/src/bin/hfrs/progress.rs` | Two summary bars, fixed slot pool, overflow aggregation | +| `huggingface_hub/Cargo.toml` | Add `indexmap` dependency if not already present | + +### 5. Testing + +- Unit tests in `types/progress.rs` updated for new fields +- Manual test: upload a folder with >10 files and verify: + - Two summary bars appear ("Processing Files" and "New Data Upload") + - Per-file bars show for first 9 active files + - 10th slot shows `[+ N files]` with aggregate bytes + - Completed files are evicted to make room for new ones + - Both summary bars finish cleanly on commit +- Manual test: single-file upload shows both summary bars diff --git a/huggingface_hub/Cargo.toml b/huggingface_hub/Cargo.toml index 31d9f81..506ea28 100644 --- a/huggingface_hub/Cargo.toml +++ b/huggingface_hub/Cargo.toml @@ -33,6 +33,7 @@ sha2 = { version = "0.10", optional = true } # cli deps clap = { version = "4", features = ["derive", "env", "color"], optional = true } +indicatif = { version = "0.17", optional = true } owo-colors = { version = "4", optional = true } comfy-table = { version = "7", optional = true } anyhow = { version = "1", optional = true } @@ -50,6 +51,7 @@ cli = [ "tokio/macros", "tokio/rt-multi-thread", "dep:clap", + "dep:indicatif", "dep:owo-colors", "dep:comfy-table", "dep:anyhow", @@ -106,6 +108,10 @@ required-features = ["spaces"] name = "download_upload" path = 
"examples/download_upload.rs" +[[example]] +name = "progress" +path = "examples/progress.rs" + [[example]] name = "blocking_read" path = "examples/blocking_read.rs" diff --git a/huggingface_hub/examples/progress.rs b/huggingface_hub/examples/progress.rs new file mode 100644 index 0000000..9163b5b --- /dev/null +++ b/huggingface_hub/examples/progress.rs @@ -0,0 +1,76 @@ +//! Download a file with progress tracking. +//! +//! Demonstrates implementing the `ProgressHandler` trait to receive +//! real-time progress callbacks during file transfers. +//! +//! Run: cargo run -p huggingface-hub --example progress + +use std::sync::Arc; + +use huggingface_hub::{DownloadEvent, FileStatus, HFClient, ProgressEvent, ProgressHandler, RepoDownloadFileParams}; + +struct PrintProgressHandler; + +impl ProgressHandler for PrintProgressHandler { + fn on_progress(&self, event: &ProgressEvent) { + match event { + ProgressEvent::Download(dl) => match dl { + DownloadEvent::Start { + total_files, + total_bytes, + } => { + println!("Starting download: {total_files} file(s), {total_bytes} bytes"); + }, + DownloadEvent::Progress { files } => { + for f in files { + let pct = if f.total_bytes > 0 { + f.bytes_completed * 100 / f.total_bytes + } else { + 0 + }; + let status = match f.status { + FileStatus::Started => "started", + FileStatus::InProgress => "downloading", + FileStatus::Complete => "complete", + }; + println!(" {}: {pct}% ({}/{}) [{status}]", f.filename, f.bytes_completed, f.total_bytes); + } + }, + DownloadEvent::AggregateProgress { + bytes_completed, + total_bytes, + .. 
+ } => { + println!(" aggregate: {bytes_completed}/{total_bytes}"); + }, + DownloadEvent::Complete => { + println!("Download complete."); + }, + }, + ProgressEvent::Upload(ul) => { + println!("Upload event: {ul:?}"); + }, + } + } +} + +#[tokio::main] +async fn main() -> huggingface_hub::Result<()> { + let api = HFClient::new()?; + let model = api.model("openai-community", "gpt2"); + + let tmp_dir = tempfile::tempdir().expect("failed to create tempdir"); + + let path = model + .download_file( + &RepoDownloadFileParams::builder() + .filename("config.json") + .local_dir(tmp_dir.path().to_path_buf()) + .progress(Some(Arc::new(PrintProgressHandler))) + .build(), + ) + .await?; + + println!("File saved to: {}", path.display()); + Ok(()) +} diff --git a/huggingface_hub/src/api/files.rs b/huggingface_hub/src/api/files.rs index 983b42f..11f7a7c 100644 --- a/huggingface_hub/src/api/files.rs +++ b/huggingface_hub/src/api/files.rs @@ -12,6 +12,9 @@ use url::Url; use crate::error::{HFError, Result}; use crate::repository::HFRepository; +use crate::types::progress::{ + self, DownloadEvent, FileProgress, FileStatus, Progress, ProgressEvent, UploadEvent, UploadPhase, +}; use crate::types::{ AddSource, CommitInfo, CommitOperation, RepoCreateCommitParams, RepoDeleteFileParams, RepoDeleteFolderParams, RepoDownloadFileParams, RepoDownloadFileStreamParams, RepoDownloadFileToBytesParams, RepoGetPathsInfoParams, @@ -106,6 +109,21 @@ impl HFRepository { /// /// Endpoint: GET {endpoint}/{prefix}{repo_id}/resolve/{revision}/{filename} pub async fn download_file(&self, params: &RepoDownloadFileParams) -> Result { + progress::emit( + &params.progress, + ProgressEvent::Download(DownloadEvent::Start { + total_files: 1, + total_bytes: 0, + }), + ); + let result = self.download_file_inner(params).await; + if result.is_ok() { + progress::emit(&params.progress, ProgressEvent::Download(DownloadEvent::Complete)); + } + result + } + + async fn download_file_inner(&self, params: &RepoDownloadFileParams) -> 
Result { if params.local_dir.is_some() { self.download_file_to_local_dir(params).await } else { @@ -256,6 +274,7 @@ impl HFRepository { ) .await?; + let file_size = extract_file_size(&head_response).unwrap_or(0); let has_xet_hash = head_response.headers().get(constants::HEADER_X_XET_HASH).is_some(); #[cfg(feature = "xet")] @@ -263,7 +282,7 @@ impl HFRepository { if has_xet_hash { let local_dir = params.local_dir.as_ref().unwrap(); return self - .xet_download_to_local_dir(revision, &params.filename, local_dir, &head_response) + .xet_download_to_local_dir(revision, &params.filename, local_dir, &head_response, &params.progress) .await; } } @@ -301,7 +320,25 @@ impl HFRepository { tokio::fs::create_dir_all(parent).await?; } - stream_response_to_file(response, &dest_path).await?; + stream_response_to_file_with_progress( + response, + &dest_path, + &params.progress, + Some(&params.filename), + file_size, + ) + .await?; + progress::emit( + &params.progress, + ProgressEvent::Download(DownloadEvent::Progress { + files: vec![FileProgress { + filename: params.filename.clone(), + bytes_completed: file_size, + total_bytes: file_size, + status: FileStatus::Complete, + }], + }), + ); Ok(dest_path) } @@ -451,7 +488,6 @@ impl HFRepository { let commit_hash = extract_commit_hash(&head_response); let xet_hash = extract_xet_hash(&head_response); let has_xet_hash = xet_hash.is_some(); - #[cfg(feature = "xet")] let file_size: u64 = extract_file_size(&head_response).unwrap_or_else(|| { tracing::warn!(url = %url, "missing or invalid Content-Length/X-Linked-Size header, defaulting file size to 0"); 0 @@ -482,7 +518,8 @@ impl HFRepository { } let _lock = cache::acquire_lock(cache_dir, repo_folder, &etag).await?; - self.xet_download_to_blob(revision, &xet_hash, file_size, &blob).await?; + self.xet_download_to_blob(revision, &params.filename, &xet_hash, file_size, &blob, &params.progress) + .await?; } return finalize_cached_file(cache_dir, repo_folder, revision, &commit_hash, &params.filename, &etag).await; @@ -500,6 +537,17 @@ 
impl HFRepository { let blob = cache::blob_path(cache_dir, repo_folder, &etag); if blob.exists() && !force_download { + progress::emit( + &params.progress, + ProgressEvent::Download(DownloadEvent::Progress { + files: vec![FileProgress { + filename: params.filename.clone(), + bytes_completed: file_size, + total_bytes: file_size, + status: FileStatus::Complete, + }], + }), + ); return finalize_cached_file(cache_dir, repo_folder, revision, &commit_hash, &params.filename, &etag).await; } @@ -516,7 +564,25 @@ impl HFRepository { .headers(self.hf_client.auth_headers()) .send() .await?; - stream_response_to_file(response, &incomplete_path).await?; + stream_response_to_file_with_progress( + response, + &incomplete_path, + &params.progress, + Some(&params.filename), + file_size, + ) + .await?; + progress::emit( + &params.progress, + ProgressEvent::Download(DownloadEvent::Progress { + files: vec![FileProgress { + filename: params.filename.clone(), + bytes_completed: file_size, + total_bytes: file_size, + status: FileStatus::Complete, + }], + }), + ); tokio::fs::rename(&incomplete_path, &blob).await?; finalize_cached_file(cache_dir, repo_folder, revision, &commit_hash, &params.filename, &etag).await @@ -603,10 +669,40 @@ impl HFRepository { .list_filtered_files(&commit_hash, params.allow_patterns.as_ref(), params.ignore_patterns.as_ref()) .await?; + let total_files = filenames.len(); let force = params.force_download == Some(true); + let mut cached_filenames = Vec::new(); if !force && params.local_dir.is_none() { - filenames.retain(|f| !cache::snapshot_path(cache_dir, &repo_folder, &commit_hash, f).exists()); + filenames.retain(|f| { + if cache::snapshot_path(cache_dir, &repo_folder, &commit_hash, f).exists() { + cached_filenames.push(f.clone()); + false + } else { + true + } + }); + } + + progress::emit( + &params.progress, + ProgressEvent::Download(DownloadEvent::Start { + total_files, + total_bytes: 0, + }), + ); + for f in &cached_filenames { + progress::emit( + &params.progress, + 
ProgressEvent::Download(DownloadEvent::Progress { + files: vec![FileProgress { + filename: f.clone(), + bytes_completed: 0, + total_bytes: 0, + status: FileStatus::Complete, + }], + }), + ); } #[cfg(feature = "xet")] @@ -682,9 +778,11 @@ impl HFRepository { let mut non_xet_filenames = Vec::new(); if let Some(ref local_dir) = params.local_dir { + let mut local_cached = Vec::new(); for meta in file_metas { let dest = local_dir.join(&meta.filename); if dest.exists() && !force { + local_cached.push(meta.filename); continue; } if meta.xet_hash.is_some() { @@ -693,6 +791,19 @@ impl HFRepository { non_xet_filenames.push(meta.filename); } } + for f in &local_cached { + progress::emit( + &params.progress, + ProgressEvent::Download(DownloadEvent::Progress { + files: vec![FileProgress { + filename: f.clone(), + bytes_completed: 0, + total_bytes: 0, + status: FileStatus::Complete, + }], + }), + ); + } let xet_batch_fut = async { if xet_metas.is_empty() { @@ -704,9 +815,11 @@ impl HFRepository { hash: m.xet_hash.as_ref().unwrap().clone(), file_size: m.file_size, path: local_dir.join(&m.filename), + filename: m.filename.clone(), }) .collect(); - self.xet_download_batch(&commit_hash, &batch_files).await + self.xet_download_batch(&commit_hash, &batch_files, &params.progress).await?; + Ok(()) }; let non_xet_dl_params = build_download_params( @@ -716,6 +829,7 @@ impl HFRepository { &commit_hash, params.force_download, Some(local_dir.clone()), + &params.progress, ); let non_xet_fut = async { download_concurrently(self, &non_xet_dl_params, max_workers).await?; @@ -723,6 +837,7 @@ impl HFRepository { }; tokio::try_join!(xet_batch_fut, non_xet_fut)?; + progress::emit(&params.progress, ProgressEvent::Download(DownloadEvent::Complete)); return Ok(local_dir.clone()); } @@ -738,6 +853,17 @@ impl HFRepository { &meta.etag, ) .await?; + progress::emit( + &params.progress, + ProgressEvent::Download(DownloadEvent::Progress { + files: vec![FileProgress { + filename: meta.filename.clone(), + bytes_completed: 
meta.file_size, + total_bytes: meta.file_size, + status: FileStatus::Complete, + }], + }), + ); continue; } if meta.xet_hash.is_some() { @@ -761,9 +887,10 @@ impl HFRepository { hash: m.xet_hash.as_ref().unwrap().clone(), file_size: m.file_size, path: cache::blob_path(cache_dir, &repo_folder, &m.etag), + filename: m.filename.clone(), }) .collect(); - self.xet_download_batch(&commit_hash, &batch_files).await?; + self.xet_download_batch(&commit_hash, &batch_files, &params.progress).await?; for m in &xet_metas { cache::create_pointer_symlink(cache_dir, &repo_folder, &m.commit_hash, &m.filename, &m.etag) .await?; @@ -779,6 +906,7 @@ impl HFRepository { &commit_hash, params.force_download, None, + &params.progress, ); let non_xet_fut = async { download_concurrently(self, &non_xet_dl_params, max_workers).await?; @@ -799,8 +927,10 @@ impl HFRepository { &commit_hash, params.force_download, Some(local_dir.clone()), + &params.progress, ); download_concurrently(self, &dl_params, max_workers).await?; + progress::emit(&params.progress, ProgressEvent::Download(DownloadEvent::Complete)); return Ok(local_dir.clone()); } @@ -811,6 +941,7 @@ impl HFRepository { &commit_hash, params.force_download, None, + &params.progress, ); download_concurrently(self, &dl_params, max_workers).await?; } @@ -819,6 +950,7 @@ impl HFRepository { cache::write_ref(cache_dir, &repo_folder, revision, &commit_hash).await?; } + progress::emit(&params.progress, ProgressEvent::Download(DownloadEvent::Complete)); Ok(cache_dir.join(&repo_folder).join("snapshots").join(&commit_hash)) } } @@ -904,6 +1036,7 @@ fn build_download_params( commit_hash: &str, force_download: Option, local_dir: Option, + progress: &Progress, ) -> Vec { filenames .iter() @@ -913,6 +1046,7 @@ fn build_download_params( revision: Some(commit_hash.to_string()), force_download, local_files_only: None, + progress: progress.clone(), }) .collect() } @@ -922,18 +1056,49 @@ async fn download_concurrently( params: &[RepoDownloadFileParams], max_workers: usize, ) -> 
Result> { - futures::stream::iter(params.iter().map(|p| api.download_file(p))) + futures::stream::iter(params.iter().map(|p| api.download_file_inner(p))) .buffer_unordered(max_workers) .try_collect() .await } -async fn stream_response_to_file(response: reqwest::Response, dest: &Path) -> Result<()> { +async fn stream_response_to_file_with_progress( + response: reqwest::Response, + dest: &Path, + handler: &Progress, + filename: Option<&str>, + total_bytes: u64, +) -> Result<()> { let mut file = tokio::fs::File::create(dest).await?; let mut stream = response.bytes_stream(); + let mut bytes_read: u64 = 0; + + if let (Some(h), Some(fname)) = (handler, filename) { + h.on_progress(&ProgressEvent::Download(DownloadEvent::Progress { + files: vec![FileProgress { + filename: fname.to_string(), + bytes_completed: 0, + total_bytes, + status: FileStatus::Started, + }], + })); + } + while let Some(chunk) = stream.next().await { let chunk = chunk?; file.write_all(&chunk).await?; + bytes_read += chunk.len() as u64; + + if let (Some(h), Some(fname)) = (handler, filename) { + h.on_progress(&ProgressEvent::Download(DownloadEvent::Progress { + files: vec![FileProgress { + filename: fname.to_string(), + bytes_completed: bytes_read, + total_bytes, + status: FileStatus::InProgress, + }], + })); + } } file.flush().await?; Ok(()) @@ -955,9 +1120,61 @@ impl HFRepository { let revision = params.revision.as_deref().unwrap_or(constants::DEFAULT_REVISION); let url = format!("{}/commit/{}", self.hf_client.api_url(Some(self.repo_type), &self.repo_path()), revision); + let add_ops_count = params + .operations + .iter() + .filter(|op| matches!(op, CommitOperation::Add { .. })) + .count(); + let total_bytes: u64 = { + let mut total = 0u64; + for op in &params.operations { + if let CommitOperation::Add { source, .. 
} = op { + total += match source { + AddSource::Bytes(b) => b.len() as u64, + AddSource::File(p) => tokio::fs::metadata(p).await.map(|m| m.len()).unwrap_or(0), + }; + } + } + total + }; + + progress::emit( + &params.progress, + ProgressEvent::Upload(UploadEvent::Start { + total_files: add_ops_count, + total_bytes, + }), + ); + progress::emit( + &params.progress, + ProgressEvent::Upload(UploadEvent::Progress { + phase: UploadPhase::Preparing, + bytes_completed: 0, + total_bytes, + bytes_per_sec: None, + transfer_bytes_completed: 0, + transfer_bytes: 0, + transfer_bytes_per_sec: None, + files: vec![], + }), + ); + // Determine which files should be uploaded via xet (LFS) vs. inline // (regular). Files uploaded via xet are referenced by their SHA256 OID // in the commit NDJSON. + progress::emit( + &params.progress, + ProgressEvent::Upload(UploadEvent::Progress { + phase: UploadPhase::CheckingUploadMode, + bytes_completed: 0, + total_bytes, + bytes_per_sec: None, + transfer_bytes_completed: 0, + transfer_bytes: 0, + transfer_bytes_per_sec: None, + files: vec![], + }), + ); let lfs_uploaded: HashMap = self.preupload_and_upload_lfs_files(params, revision).await?; @@ -1015,6 +1232,20 @@ impl HFRepository { }) .collect(); + progress::emit( + &params.progress, + ProgressEvent::Upload(UploadEvent::Progress { + phase: UploadPhase::Committing, + bytes_completed: total_bytes, + total_bytes, + bytes_per_sec: None, + transfer_bytes_completed: 0, + transfer_bytes: 0, + transfer_bytes_per_sec: None, + files: vec![], + }), + ); + let mut headers = self.hf_client.auth_headers(); headers.insert(reqwest::header::CONTENT_TYPE, "application/x-ndjson".parse().unwrap()); @@ -1030,6 +1261,8 @@ impl HFRepository { .hf_client .check_response(response, Some(&repo_path), crate::error::NotFoundContext::Repo) .await?; + + progress::emit(&params.progress, ProgressEvent::Upload(UploadEvent::Complete)); Ok(response.json().await?) 
} @@ -1067,6 +1300,7 @@ impl HFRepository { revision: params.revision.clone(), create_pr: params.create_pr, parent_commit: params.parent_commit.clone(), + progress: params.progress.clone(), }) .await } @@ -1116,6 +1350,7 @@ impl HFRepository { revision: params.revision.clone(), create_pr: params.create_pr, parent_commit: None, + progress: params.progress.clone(), }) .await } @@ -1136,6 +1371,7 @@ impl HFRepository { revision: params.revision.clone(), create_pr: params.create_pr, parent_commit: None, + progress: None, }) .await } @@ -1180,6 +1416,7 @@ impl HFRepository { revision: Some(revision.to_string()), create_pr: params.create_pr, parent_commit: None, + progress: None, }) .await } @@ -1334,7 +1571,7 @@ impl HFRepository { /// Compute SHA256, negotiate LFS batch transfer, and upload via xet. async fn upload_lfs_files_via_xet( &self, - _params: &RepoCreateCommitParams, + params: &RepoCreateCommitParams, revision: &str, lfs_files: &[&(String, u64, Vec, &AddSource)], ) -> Result> { @@ -1371,7 +1608,7 @@ impl HFRepository { .map(|(path, _, _, source)| (path.clone(), (*source).clone())) .collect(); - self.xet_upload(&xet_files, revision).await?; + self.xet_upload(&xet_files, revision, &params.progress).await?; let result: HashMap = lfs_with_sha .into_iter() diff --git a/huggingface_hub/src/bin/hfrs/cli.rs b/huggingface_hub/src/bin/hfrs/cli.rs index a5fda90..18c8c7d 100644 --- a/huggingface_hub/src/bin/hfrs/cli.rs +++ b/huggingface_hub/src/bin/hfrs/cli.rs @@ -25,6 +25,10 @@ pub struct Cli { #[arg(long, global = true)] pub no_color: bool, + /// Disable progress bars + #[arg(long, global = true)] + pub disable_progress_bars: bool, + #[command(subcommand)] pub command: Command, } diff --git a/huggingface_hub/src/bin/hfrs/commands/download.rs b/huggingface_hub/src/bin/hfrs/commands/download.rs index f6310a9..b5a332c 100644 --- a/huggingface_hub/src/bin/hfrs/commands/download.rs +++ b/huggingface_hub/src/bin/hfrs/commands/download.rs @@ -1,11 +1,13 @@ use std::path::PathBuf; 
+use std::sync::Arc; use anyhow::Result; use clap::Args as ClapArgs; -use huggingface_hub::{HFClient, RepoDownloadFileParams, RepoSnapshotDownloadParams}; +use huggingface_hub::{HFClient, Progress, RepoDownloadFileParams, RepoSnapshotDownloadParams}; use crate::cli::RepoTypeArg; use crate::output::CommandResult; +use crate::progress::CliProgressHandler; /// Download files from the Hub #[derive(ClapArgs)] @@ -49,10 +51,18 @@ pub struct Args { pub quiet: bool, } -pub async fn execute(api: &HFClient, args: Args) -> Result { +pub async fn execute(api: &HFClient, args: Args, multi: Option) -> Result { let repo_type: huggingface_hub::RepoType = args.r#type.into(); let repo = crate::util::make_repo(api, &args.repo_id, repo_type); + let handler: Progress = if args.quiet { + None + } else if let Some(multi) = multi { + Some(Arc::new(CliProgressHandler::new(multi))) + } else { + None + }; + let path = if args.filenames.len() == 1 && args.include.is_empty() && args.exclude.is_empty() { let params = RepoDownloadFileParams { filename: args.filenames.into_iter().next().unwrap(), @@ -60,6 +70,7 @@ pub async fn execute(api: &HFClient, args: Args) -> Result { revision: args.revision, force_download: if args.force_download { Some(true) } else { None }, local_files_only: None, + progress: handler.clone(), }; repo.download_file(&params).await? } else { @@ -83,6 +94,7 @@ pub async fn execute(api: &HFClient, args: Args) -> Result { force_download: if args.force_download { Some(true) } else { None }, local_files_only: None, max_workers: None, + progress: handler.clone(), }; repo.snapshot_download(&params).await? 
}; diff --git a/huggingface_hub/src/bin/hfrs/commands/repos/delete_files.rs b/huggingface_hub/src/bin/hfrs/commands/repos/delete_files.rs index b57fe87..9c76b9b 100644 --- a/huggingface_hub/src/bin/hfrs/commands/repos/delete_files.rs +++ b/huggingface_hub/src/bin/hfrs/commands/repos/delete_files.rs @@ -51,6 +51,7 @@ pub async fn execute(api: &HFClient, args: Args) -> Result { revision: args.revision, create_pr: if args.create_pr { Some(true) } else { None }, parent_commit: None, + progress: None, }; let result = repo.create_commit(&params).await?; let url = result.commit_url.or(result.pr_url).unwrap_or_default(); diff --git a/huggingface_hub/src/bin/hfrs/commands/upload.rs b/huggingface_hub/src/bin/hfrs/commands/upload.rs index d5b8afa..0d1f9c8 100644 --- a/huggingface_hub/src/bin/hfrs/commands/upload.rs +++ b/huggingface_hub/src/bin/hfrs/commands/upload.rs @@ -1,12 +1,14 @@ use std::path::PathBuf; +use std::sync::Arc; use anyhow::{Result, bail}; use clap::Args as ClapArgs; -use huggingface_hub::{AddSource, CreateRepoParams, HFClient, RepoUploadFileParams, RepoUploadFolderParams}; +use huggingface_hub::{AddSource, CreateRepoParams, HFClient, Progress, RepoUploadFileParams, RepoUploadFolderParams}; use tracing::info; use crate::cli::RepoTypeArg; use crate::output::CommandResult; +use crate::progress::CliProgressHandler; /// Upload files to the Hub #[derive(ClapArgs)] @@ -61,7 +63,7 @@ pub struct Args { pub quiet: bool, } -pub async fn execute(api: &HFClient, args: Args) -> Result { +pub async fn execute(api: &HFClient, args: Args, multi: Option) -> Result { let repo_type: huggingface_hub::RepoType = args.r#type.into(); let local_path = args.local_path.unwrap_or_else(|| PathBuf::from(".")); let repo = crate::util::make_repo(api, &args.repo_id, repo_type); @@ -79,6 +81,14 @@ pub async fn execute(api: &HFClient, args: Args) -> Result { api.create_repo(&create_params).await?; } + let handler: Progress = if args.quiet { + None + } else if let Some(multi) = multi { + 
Some(Arc::new(CliProgressHandler::new(multi))) + } else { + None + }; + let commit_info = if local_path.is_file() { let path_in_repo = args.path_in_repo.unwrap_or_else(|| { local_path @@ -94,6 +104,7 @@ pub async fn execute(api: &HFClient, args: Args) -> Result { commit_description: args.commit_description, create_pr: if args.create_pr { Some(true) } else { None }, parent_commit: None, + progress: handler.clone(), }; repo.upload_file(&params).await? } else if local_path.is_dir() { @@ -122,6 +133,7 @@ pub async fn execute(api: &HFClient, args: Args) -> Result { allow_patterns, ignore_patterns, delete_patterns, + progress: handler.clone(), }; repo.upload_folder(&params).await? } else { diff --git a/huggingface_hub/src/bin/hfrs/main.rs b/huggingface_hub/src/bin/hfrs/main.rs index eed3d5c..3274793 100644 --- a/huggingface_hub/src/bin/hfrs/main.rs +++ b/huggingface_hub/src/bin/hfrs/main.rs @@ -1,6 +1,7 @@ mod cli; mod commands; mod output; +mod progress; mod util; use std::io::IsTerminal; @@ -19,7 +20,14 @@ async fn main() -> ExitCode { let cli = Cli::parse(); let color = should_use_color(cli.no_color); - init_logging(color); + let progress_disabled = + cli.disable_progress_bars || progress::progress_disabled_by_env() || !std::io::stderr().is_terminal(); + let multi = if progress_disabled { + None + } else { + Some(indicatif::MultiProgress::with_draw_target(indicatif::ProgressDrawTarget::stderr_with_hz(10))) + }; + init_logging(color, multi.as_ref()); let mut builder = HFClientBuilder::new(); if let Some(t) = cli.token { @@ -50,11 +58,11 @@ async fn main() -> ExitCode { Command::Auth(args) => commands::auth::execute(&api, args).await, Command::Cache(args) => commands::cache::execute(&api, args).await, Command::Datasets(args) => commands::datasets::execute(&api, args).await, - Command::Download(args) => commands::download::execute(&api, args).await, + Command::Download(args) => commands::download::execute(&api, args, multi.clone()).await, Command::Models(args) => 
commands::models::execute(&api, args).await, Command::Repos(args) => commands::repos::execute(&api, args).await, Command::Spaces(args) => commands::spaces::execute(&api, args).await, - Command::Upload(args) => commands::upload::execute(&api, args).await, + Command::Upload(args) => commands::upload::execute(&api, args, multi.clone()).await, Command::Env(args) => commands::env::execute(args).await, Command::Version(args) => commands::version::execute(args).await, }; @@ -244,7 +252,7 @@ fn format_hf_error(err: &HFError) -> String { const XET_CRATES: &[&str] = &["hf_xet", "xet_client", "xet_core_structures", "xet_data", "xet_runtime"]; -fn init_logging(color: bool) { +fn init_logging(color: bool, multi: Option<&indicatif::MultiProgress>) { let mut filter_str = if let Ok(level) = std::env::var("HF_LOG_LEVEL") { level } else if std::env::var("HF_DEBUG").is_ok() { @@ -264,10 +272,20 @@ fn init_logging(color: bool) { EnvFilter::new("off") }); - tracing_subscriber::fmt() - .with_env_filter(filter) - .with_target(true) - .with_ansi(color) - .with_writer(std::io::stderr) - .init(); + if let Some(multi) = multi { + let writer = progress::MultiProgressWriter::new(multi.clone()); + tracing_subscriber::fmt() + .with_env_filter(filter) + .with_target(true) + .with_ansi(color) + .with_writer(move || writer.clone()) + .init(); + } else { + tracing_subscriber::fmt() + .with_env_filter(filter) + .with_target(true) + .with_ansi(color) + .with_writer(std::io::stderr) + .init(); + } } diff --git a/huggingface_hub/src/bin/hfrs/progress.rs b/huggingface_hub/src/bin/hfrs/progress.rs new file mode 100644 index 0000000..a71582a --- /dev/null +++ b/huggingface_hub/src/bin/hfrs/progress.rs @@ -0,0 +1,405 @@ +use std::collections::{HashMap, HashSet, VecDeque}; +use std::io::Write; +use std::sync::Mutex; + +use huggingface_hub::{ + DownloadEvent, FileProgress, FileStatus, ProgressEvent, ProgressHandler, UploadEvent, UploadPhase, +}; +use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; + 
+/// Renders indicatif progress bars in the terminal for download and upload operations. +const MAX_VISIBLE_FILE_BARS: usize = 10; +const MAX_VISIBLE_UPLOAD_BARS: usize = 10; + +pub struct CliProgressHandler { + multi: MultiProgress, + state: Mutex<ProgressState>, +} + +struct ProgressState { + // Download state + files_bar: Option<ProgressBar>, + bytes_bar: Option<ProgressBar>, + file_bars: HashMap<String, ProgressBar>, + download_queue: VecDeque<(String, u64)>, + total_files: usize, + // Upload state + processing_bar: Option<ProgressBar>, + transfer_bar: Option<ProgressBar>, + upload_file_bars: HashMap<String, ProgressBar>, + upload_queue: VecDeque<(String, u64)>, + upload_completed_files: HashSet<String>, + last_upload_phase: Option<UploadPhase>, + spinner: Option<ProgressBar>, + upload_total_files: usize, +} + +fn bytes_style() -> ProgressStyle { + ProgressStyle::with_template( + "{msg}: {percent}%|{wide_bar:.cyan/blue}| {bytes}/{total_bytes} [{elapsed}<{eta}, {bytes_per_sec}]", + ) + .expect("hardcoded template") + .progress_chars("##-") +} + +fn files_style() -> ProgressStyle { + ProgressStyle::with_template("{msg}: {percent}%|{wide_bar:.green/blue}| {pos}/{len} [{elapsed}<{eta}]") + .expect("hardcoded template") + .progress_chars("##-") +} + +fn spinner_style() -> ProgressStyle { + ProgressStyle::with_template("{spinner:.green} {msg}").expect("hardcoded template") +} + +fn truncate_filename(name: &str, max_len: usize) -> String { + if name.len() <= max_len { + return name.to_string(); + } + let suffix = &name[name.len() - (max_len - 1)..]; + format!("…{suffix}") +} + +impl CliProgressHandler { + pub fn new(multi: MultiProgress) -> Self { + Self { + multi, + state: Mutex::new(ProgressState { + files_bar: None, + bytes_bar: None, + file_bars: HashMap::new(), + download_queue: VecDeque::new(), + total_files: 0, + processing_bar: None, + transfer_bar: None, + upload_file_bars: HashMap::new(), + upload_queue: VecDeque::new(), + upload_completed_files: HashSet::new(), + last_upload_phase: None, + spinner: None, + upload_total_files: 0, + }), + } + } + + fn handle_download(&self, event: &DownloadEvent) { 
+ let mut state = self.state.lock().unwrap_or_else(|e| e.into_inner()); + match event { + DownloadEvent::Start { + total_files, + total_bytes, + } => { + state.total_files = *total_files; + if *total_files > 1 { + let bar = self.multi.add(ProgressBar::new(*total_files as u64)); + bar.set_style(files_style()); + bar.set_message(format!("Fetching {} files", total_files)); + state.files_bar = Some(bar); + } + if *total_bytes > 0 && *total_files == 1 { + let bar = self.multi.add(ProgressBar::new(*total_bytes)); + bar.set_style(bytes_style()); + bar.set_message("Downloading"); + state.bytes_bar = Some(bar); + } + }, + DownloadEvent::Progress { files } => { + for fp in files { + match fp.status { + FileStatus::Started => { + if state.total_files == 1 && state.bytes_bar.is_none() && fp.total_bytes > 0 { + let bar = self.multi.add(ProgressBar::new(fp.total_bytes)); + bar.set_style(bytes_style()); + bar.set_message("Downloading"); + state.bytes_bar = Some(bar); + } else if state.file_bars.len() < MAX_VISIBLE_FILE_BARS { + let bar = self.multi.add(ProgressBar::new(fp.total_bytes)); + bar.set_style(bytes_style()); + bar.set_message(truncate_filename(&fp.filename, 40)); + state.file_bars.insert(fp.filename.clone(), bar); + } else { + state.download_queue.push_back((fp.filename.clone(), fp.total_bytes)); + } + }, + FileStatus::InProgress => { + if let Some(bar) = state.file_bars.get(&fp.filename) { + bar.set_position(fp.bytes_completed); + } else if state.file_bars.len() < MAX_VISIBLE_FILE_BARS { + let bar = self.multi.add(ProgressBar::new(fp.total_bytes)); + bar.set_style(bytes_style()); + bar.set_message(truncate_filename(&fp.filename, 40)); + bar.set_position(fp.bytes_completed); + state.file_bars.insert(fp.filename.clone(), bar); + state.download_queue.retain(|(n, _)| n != &fp.filename); + } else if let Some(ref bar) = state.bytes_bar { + bar.set_position(fp.bytes_completed); + } + }, + FileStatus::Complete => { + if let Some(bar) = state.file_bars.remove(&fp.filename) { + 
bar.finish_and_clear(); + self.multi.remove(&bar); + } + state.download_queue.retain(|(n, _)| n != &fp.filename); + if let Some(ref bar) = state.bytes_bar { + bar.set_position(fp.bytes_completed); + } + if let Some(ref bar) = state.files_bar { + bar.inc(1); + } + while state.file_bars.len() < MAX_VISIBLE_FILE_BARS { + if let Some((name, total)) = state.download_queue.pop_front() { + let bar = self.multi.add(ProgressBar::new(total)); + bar.set_style(bytes_style()); + bar.set_message(truncate_filename(&name, 40)); + state.file_bars.insert(name, bar); + } else { + break; + } + } + }, + } + } + }, + DownloadEvent::AggregateProgress { + bytes_completed, + total_bytes, + .. + } => { + if state.bytes_bar.is_none() { + let bar = self.multi.add(ProgressBar::new(*total_bytes)); + bar.set_style(bytes_style()); + bar.set_message("Downloading"); + state.bytes_bar = Some(bar); + } + if let Some(ref bar) = state.bytes_bar { + bar.set_length(*total_bytes); + bar.set_position(*bytes_completed); + } + }, + DownloadEvent::Complete => { + if let Some(ref bar) = state.files_bar { + bar.finish_and_clear(); + } + if let Some(ref bar) = state.bytes_bar { + bar.finish_and_clear(); + } + for (_, bar) in state.file_bars.drain() { + bar.finish_and_clear(); + } + state.download_queue.clear(); + }, + } + } + + fn handle_upload(&self, event: &UploadEvent) { + let mut state = self.state.lock().unwrap_or_else(|e| e.into_inner()); + match event { + UploadEvent::Start { + total_files, + total_bytes: _, + } => { + state.upload_total_files = *total_files; + }, + UploadEvent::Progress { + phase, + bytes_completed, + total_bytes, + transfer_bytes_completed, + transfer_bytes, + files, + .. 
+ } => { + if state.last_upload_phase.as_ref() != Some(phase) { + if let Some(ref spinner) = state.spinner { + spinner.finish_and_clear(); + self.multi.remove(spinner); + state.spinner = None; + } + match phase { + UploadPhase::Preparing => { + let bar = self.multi.add(ProgressBar::new_spinner()); + bar.set_style(spinner_style()); + bar.set_message("Preparing files..."); + bar.enable_steady_tick(std::time::Duration::from_millis(100)); + state.spinner = Some(bar); + }, + UploadPhase::CheckingUploadMode => { + let bar = self.multi.add(ProgressBar::new_spinner()); + bar.set_style(spinner_style()); + bar.set_message("Checking upload mode..."); + bar.enable_steady_tick(std::time::Duration::from_millis(100)); + state.spinner = Some(bar); + }, + UploadPhase::Uploading => { + let pbar = self.multi.add(ProgressBar::new(0)); + pbar.set_style(bytes_style()); + pbar.set_message(format!("Processing Files (0 / {})", state.upload_total_files)); + state.processing_bar = Some(pbar); + + let tbar = self.multi.add(ProgressBar::new(0)); + tbar.set_style(bytes_style()); + tbar.set_message("New Data Upload"); + state.transfer_bar = Some(tbar); + }, + UploadPhase::Committing => { + self.cleanup_upload_bars(&mut state); + let bar = self.multi.add(ProgressBar::new_spinner()); + bar.set_style(spinner_style()); + bar.set_message("Creating commit..."); + bar.enable_steady_tick(std::time::Duration::from_millis(100)); + state.spinner = Some(bar); + }, + } + state.last_upload_phase = Some(phase.clone()); + } + + if *phase == UploadPhase::Uploading { + let completed_count = state.upload_completed_files.len(); + let total_count = state.upload_total_files; + if let Some(ref bar) = state.processing_bar { + bar.set_length(*total_bytes); + bar.set_position(*bytes_completed); + bar.set_message(format!("Processing Files ({} / {})", completed_count, total_count)); + } + + if let Some(ref bar) = state.transfer_bar { + bar.set_length(*transfer_bytes); + bar.set_position(*transfer_bytes_completed); + } + + 
for fp in files { + self.process_upload_file_progress(&mut state, fp); + } + } + }, + UploadEvent::FileComplete { .. } => {}, + UploadEvent::Complete => { + self.cleanup_upload_bars(&mut state); + if let Some(spinner) = state.spinner.take() { + spinner.finish_and_clear(); + self.multi.remove(&spinner); + } + }, + } + } + + fn process_upload_file_progress(&self, state: &mut ProgressState, fp: &FileProgress) { + if state.upload_completed_files.contains(&fp.filename) { + return; + } + match fp.status { + FileStatus::Started => { + if !state.upload_file_bars.contains_key(&fp.filename) { + if state.upload_file_bars.len() < MAX_VISIBLE_UPLOAD_BARS { + let bar = self.multi.add(ProgressBar::new(fp.total_bytes)); + bar.set_style(bytes_style()); + bar.set_message(truncate_filename(&fp.filename, 40)); + state.upload_file_bars.insert(fp.filename.clone(), bar); + } else { + state.upload_queue.push_back((fp.filename.clone(), fp.total_bytes)); + } + } + }, + FileStatus::InProgress => { + if let Some(bar) = state.upload_file_bars.get(&fp.filename) { + bar.set_position(fp.bytes_completed); + } else if state.upload_file_bars.len() < MAX_VISIBLE_UPLOAD_BARS { + let bar = self.multi.add(ProgressBar::new(fp.total_bytes)); + bar.set_style(bytes_style()); + bar.set_message(truncate_filename(&fp.filename, 40)); + bar.set_position(fp.bytes_completed); + state.upload_file_bars.insert(fp.filename.clone(), bar); + state.upload_queue.retain(|(n, _)| n != &fp.filename); + } + }, + FileStatus::Complete => { + if state.upload_completed_files.insert(fp.filename.clone()) { + if let Some(bar) = state.upload_file_bars.remove(&fp.filename) { + bar.finish_and_clear(); + self.multi.remove(&bar); + } + state.upload_queue.retain(|(n, _)| n != &fp.filename); + if let Some(ref bar) = state.files_bar { + bar.inc(1); + } + while state.upload_file_bars.len() < MAX_VISIBLE_UPLOAD_BARS { + if let Some((name, total)) = state.upload_queue.pop_front() { + let bar = self.multi.add(ProgressBar::new(total)); + 
bar.set_style(bytes_style()); + bar.set_message(truncate_filename(&name, 40)); + state.upload_file_bars.insert(name, bar); + } else { + break; + } + } + } + }, + } + } + + fn cleanup_upload_bars(&self, state: &mut ProgressState) { + for (_, bar) in state.upload_file_bars.drain() { + bar.finish_and_clear(); + self.multi.remove(&bar); + } + state.upload_queue.clear(); + state.upload_completed_files.clear(); + if let Some(bar) = state.processing_bar.take() { + bar.finish_and_clear(); + self.multi.remove(&bar); + } + if let Some(bar) = state.transfer_bar.take() { + bar.finish_and_clear(); + self.multi.remove(&bar); + } + } +} + +impl ProgressHandler for CliProgressHandler { + fn on_progress(&self, event: &ProgressEvent) { + match event { + ProgressEvent::Download(dl) => self.handle_download(dl), + ProgressEvent::Upload(ul) => self.handle_upload(ul), + } + } +} + +pub fn progress_disabled_by_env() -> bool { + std::env::var("HF_HUB_DISABLE_PROGRESS_BARS").is_ok_and(|v| v == "1" || v.eq_ignore_ascii_case("true")) +} + +/// An `io::Write` adapter that routes output through `MultiProgress::println()`, +/// ensuring log lines appear above progress bars without visual corruption. 
+#[derive(Clone)] +pub struct MultiProgressWriter { + multi: MultiProgress, + buf: Vec, +} + +impl MultiProgressWriter { + pub fn new(multi: MultiProgress) -> Self { + Self { multi, buf: Vec::new() } + } +} + +impl Write for MultiProgressWriter { + fn write(&mut self, data: &[u8]) -> std::io::Result { + self.buf.extend_from_slice(data); + while let Some(pos) = self.buf.iter().position(|&b| b == b'\n') { + let line = String::from_utf8_lossy(&self.buf[..pos]).into_owned(); + self.multi.println(&line).map_err(std::io::Error::other)?; + self.buf.drain(..=pos); + } + Ok(data.len()) + } + + fn flush(&mut self) -> std::io::Result<()> { + if !self.buf.is_empty() { + let line = String::from_utf8_lossy(&self.buf).into_owned(); + self.multi.println(&line).map_err(std::io::Error::other)?; + self.buf.clear(); + } + Ok(()) + } +} diff --git a/huggingface_hub/src/types/mod.rs b/huggingface_hub/src/types/mod.rs index 9d57962..61bdeba 100644 --- a/huggingface_hub/src/types/mod.rs +++ b/huggingface_hub/src/types/mod.rs @@ -1,6 +1,7 @@ pub mod cache; pub mod commit; pub mod params; +pub mod progress; pub mod repo; pub mod repo_params; pub mod user; @@ -10,6 +11,7 @@ pub mod spaces; pub use commit::*; pub use params::*; +pub use progress::*; pub use repo::*; pub use repo_params::*; #[cfg(feature = "spaces")] diff --git a/huggingface_hub/src/types/progress.rs b/huggingface_hub/src/types/progress.rs new file mode 100644 index 0000000..3656e5e --- /dev/null +++ b/huggingface_hub/src/types/progress.rs @@ -0,0 +1,348 @@ +use std::sync::Arc; + +/// Trait implemented by consumers to receive progress updates. +/// Implementations must be fast — avoid blocking I/O in on_progress(). +pub trait ProgressHandler: Send + Sync { + /// Called by the library each time progress changes. + /// Receives a reference to avoid allocation; clone if you need to store it. + fn on_progress(&self, event: &ProgressEvent); +} + +/// A clonable, optional handle to a progress handler. 
+pub type Progress = Option>; + +/// Top-level progress event — either an upload or download event. +#[derive(Debug, Clone)] +pub enum ProgressEvent { + Upload(UploadEvent), + Download(DownloadEvent), +} + +/// Progress events for upload operations. +/// +/// Every variant that represents an in-progress state carries the current +/// `UploadPhase`, so consumers always know the phase from any single event +/// without tracking state across events. +#[derive(Debug, Clone)] +pub enum UploadEvent { + /// Upload operation has started; total file count and bytes are known. + Start { total_files: usize, total_bytes: u64 }, + /// Byte-level progress during xet/LFS upload. + /// `files` contains per-file progress for xet uploads (may be empty + /// for phases without per-file granularity). + Progress { + phase: UploadPhase, + bytes_completed: u64, + total_bytes: u64, + bytes_per_sec: Option, + transfer_bytes_completed: u64, + transfer_bytes: u64, + transfer_bytes_per_sec: Option, + files: Vec, + }, + /// One or more individual files completed. Batched for efficiency + /// during multi-file uploads (upload_folder). + FileComplete { files: Vec, phase: UploadPhase }, + /// Entire upload operation finished (all files, commit created). + Complete, +} + +/// Progress events for download operations. +#[derive(Debug, Clone)] +pub enum DownloadEvent { + /// Download operation has started; file count and total bytes known. + Start { total_files: usize, total_bytes: u64 }, + /// Per-file progress update. Only includes files whose state changed + /// since the last event (delta, not full snapshot). Batched for + /// efficiency during multi-file downloads (snapshot_download). + Progress { files: Vec }, + /// Aggregate byte-level progress for xet batch transfers. + /// Separate from per-file Progress because xet provides aggregate + /// stats, not per-file byte counts. 
/// Per-file progress info, carried in [`DownloadEvent::Progress`] and in the
/// `files` list of [`UploadEvent::Progress`].
#[derive(Debug, Clone)]
pub struct FileProgress {
    /// Name identifying the file within the transfer; handlers use it as the
    /// key to correlate successive records for the same file.
    pub filename: String,
    /// Bytes transferred so far for this file.
    pub bytes_completed: u64,
    /// Total size of the file in bytes. Producers may leave both byte counters
    /// at 0 on a final [`FileStatus::Complete`] record, so do not rely on
    /// `bytes_completed == total_bytes` to detect completion; use `status`.
    pub total_bytes: u64,
    /// Where the file currently is in its transfer lifecycle.
    pub status: FileStatus,
}

/// Lifecycle status of a single file within a transfer (download or upload).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FileStatus {
    /// The file is part of the transfer but no bytes have been transferred yet.
    Started,
    /// Bytes are actively being transferred.
    InProgress,
    /// All bytes of the file have been transferred.
    Complete,
}
+ Committing, +} + +pub(crate) fn emit(handler: &Progress, event: ProgressEvent) { + if let Some(h) = handler { + h.on_progress(&event); + } +} + +#[cfg(test)] +mod tests { + use std::sync::{Arc, Mutex}; + + use super::*; + + struct RecordingHandler { + events: Mutex>, + } + + impl RecordingHandler { + fn new() -> Self { + Self { + events: Mutex::new(Vec::new()), + } + } + + fn events(&self) -> Vec { + self.events.lock().unwrap().clone() + } + } + + impl ProgressHandler for RecordingHandler { + fn on_progress(&self, event: &ProgressEvent) { + self.events.lock().unwrap().push(event.clone()); + } + } + + #[test] + fn handler_is_send_sync() { + fn assert_send_sync() {} + assert_send_sync::>(); + } + + #[test] + fn emit_with_none_is_noop() { + let progress: Progress = None; + emit(&progress, ProgressEvent::Download(DownloadEvent::Complete)); + } + + #[test] + fn emit_records_events() { + let handler = Arc::new(RecordingHandler::new()); + let progress: Progress = Some(handler.clone()); + + emit( + &progress, + ProgressEvent::Upload(UploadEvent::Start { + total_files: 2, + total_bytes: 1024, + }), + ); + emit( + &progress, + ProgressEvent::Upload(UploadEvent::Progress { + phase: UploadPhase::Uploading, + bytes_completed: 512, + total_bytes: 1024, + bytes_per_sec: Some(100.0), + transfer_bytes_completed: 0, + transfer_bytes: 0, + transfer_bytes_per_sec: None, + files: vec![], + }), + ); + emit(&progress, ProgressEvent::Upload(UploadEvent::Complete)); + + let events = handler.events(); + assert_eq!(events.len(), 3); + assert!(matches!(events[0], ProgressEvent::Upload(UploadEvent::Start { .. }))); + assert!(matches!(events[1], ProgressEvent::Upload(UploadEvent::Progress { .. 
}))); + assert!(matches!(events[2], ProgressEvent::Upload(UploadEvent::Complete))); + } + + #[test] + fn download_file_lifecycle() { + let handler = Arc::new(RecordingHandler::new()); + let progress: Progress = Some(handler.clone()); + + emit( + &progress, + ProgressEvent::Download(DownloadEvent::Start { + total_files: 1, + total_bytes: 1000, + }), + ); + emit( + &progress, + ProgressEvent::Download(DownloadEvent::Progress { + files: vec![FileProgress { + filename: "file.bin".to_string(), + bytes_completed: 0, + total_bytes: 1000, + status: FileStatus::Started, + }], + }), + ); + emit( + &progress, + ProgressEvent::Download(DownloadEvent::Progress { + files: vec![FileProgress { + filename: "file.bin".to_string(), + bytes_completed: 500, + total_bytes: 1000, + status: FileStatus::InProgress, + }], + }), + ); + emit( + &progress, + ProgressEvent::Download(DownloadEvent::Progress { + files: vec![FileProgress { + filename: "file.bin".to_string(), + bytes_completed: 1000, + total_bytes: 1000, + status: FileStatus::Complete, + }], + }), + ); + emit(&progress, ProgressEvent::Download(DownloadEvent::Complete)); + + let events = handler.events(); + assert_eq!(events.len(), 5); + } + + #[test] + fn upload_phase_progression() { + let handler = Arc::new(RecordingHandler::new()); + let progress: Progress = Some(handler.clone()); + + let phases = [ + UploadPhase::Preparing, + UploadPhase::CheckingUploadMode, + UploadPhase::Uploading, + UploadPhase::Committing, + ]; + + for phase in &phases { + emit( + &progress, + ProgressEvent::Upload(UploadEvent::Progress { + phase: phase.clone(), + bytes_completed: 0, + total_bytes: 100, + bytes_per_sec: None, + transfer_bytes_completed: 0, + transfer_bytes: 0, + transfer_bytes_per_sec: None, + files: vec![], + }), + ); + } + + let events = handler.events(); + assert_eq!(events.len(), 4); + for (i, phase) in phases.iter().enumerate() { + if let ProgressEvent::Upload(UploadEvent::Progress { phase: p, .. 
}) = &events[i] { + assert_eq!(p, phase); + } else { + panic!("expected Upload(Progress) at index {i}"); + } + } + } + + #[test] + fn upload_progress_with_per_file_data() { + let handler = Arc::new(RecordingHandler::new()); + let progress: Progress = Some(handler.clone()); + + emit( + &progress, + ProgressEvent::Upload(UploadEvent::Progress { + phase: UploadPhase::Uploading, + bytes_completed: 500, + total_bytes: 1000, + bytes_per_sec: Some(100.0), + transfer_bytes_completed: 250, + transfer_bytes: 800, + transfer_bytes_per_sec: Some(50.0), + files: vec![ + FileProgress { + filename: "model/weights.bin".to_string(), + bytes_completed: 300, + total_bytes: 600, + status: FileStatus::InProgress, + }, + FileProgress { + filename: "config.json".to_string(), + bytes_completed: 200, + total_bytes: 400, + status: FileStatus::InProgress, + }, + ], + }), + ); + + let events = handler.events(); + assert_eq!(events.len(), 1); + if let ProgressEvent::Upload(UploadEvent::Progress { + files, + transfer_bytes_completed, + transfer_bytes, + transfer_bytes_per_sec, + .. + }) = &events[0] + { + assert_eq!(files.len(), 2); + assert_eq!(files[0].filename, "model/weights.bin"); + assert_eq!(files[1].filename, "config.json"); + assert_eq!(*transfer_bytes_completed, 250); + assert_eq!(*transfer_bytes, 800); + assert_eq!(*transfer_bytes_per_sec, Some(50.0)); + } else { + panic!("expected Upload(Progress)"); + } + } + + #[test] + fn batched_file_complete() { + let handler = Arc::new(RecordingHandler::new()); + let progress: Progress = Some(handler.clone()); + + emit( + &progress, + ProgressEvent::Upload(UploadEvent::FileComplete { + files: vec!["a.bin".to_string(), "b.bin".to_string(), "c.bin".to_string()], + phase: UploadPhase::Uploading, + }), + ); + + let events = handler.events(); + assert_eq!(events.len(), 1); + if let ProgressEvent::Upload(UploadEvent::FileComplete { files, .. 
}) = &events[0] { + assert_eq!(files.len(), 3); + } else { + panic!("expected FileComplete"); + } + } +} diff --git a/huggingface_hub/src/types/repo_params.rs b/huggingface_hub/src/types/repo_params.rs index 3a63672..d56507e 100644 --- a/huggingface_hub/src/types/repo_params.rs +++ b/huggingface_hub/src/types/repo_params.rs @@ -4,6 +4,7 @@ use serde::Serialize; use typed_builder::TypedBuilder; use super::commit::{AddSource, CommitOperation}; +use super::progress::Progress; use super::repo::{GatedApprovalMode, GatedNotificationsMode}; #[derive(Default, TypedBuilder)] @@ -79,6 +80,9 @@ pub struct RepoDownloadFileParams { /// If `true`, only return the file if it is already cached locally; never make a network request. #[builder(default, setter(strip_option))] pub local_files_only: Option, + /// Optional progress handler for tracking download progress. + #[builder(default)] + pub progress: Progress, } #[derive(TypedBuilder)] @@ -120,6 +124,9 @@ pub struct RepoSnapshotDownloadParams { /// Maximum number of concurrent file downloads. #[builder(default, setter(strip_option))] pub max_workers: Option, + /// Optional progress handler for tracking download progress. + #[builder(default)] + pub progress: Progress, } #[derive(TypedBuilder)] @@ -144,6 +151,9 @@ pub struct RepoUploadFileParams { /// Expected parent commit SHA. The upload fails if the branch head has moved past this commit. #[builder(default, setter(into, strip_option))] pub parent_commit: Option, + /// Optional progress handler for tracking upload progress. + #[builder(default)] + pub progress: Progress, } #[derive(TypedBuilder)] @@ -175,6 +185,9 @@ pub struct RepoUploadFolderParams { /// Glob patterns for remote files to delete that are not present locally. #[builder(default, setter(strip_option))] pub delete_patterns: Option>, + /// Optional progress handler for tracking upload progress. 
+ #[builder(default)] + pub progress: Progress, } #[derive(TypedBuilder)] @@ -228,6 +241,9 @@ pub struct RepoCreateCommitParams { /// Expected parent commit SHA. The commit fails if the branch head has moved past this commit. #[builder(default, setter(into, strip_option))] pub parent_commit: Option, + /// Optional progress handler for tracking upload progress. + #[builder(default)] + pub progress: Progress, } #[derive(Default, TypedBuilder)] diff --git a/huggingface_hub/src/xet.rs b/huggingface_hub/src/xet.rs index 618c986..6fccab5 100644 --- a/huggingface_hub/src/xet.rs +++ b/huggingface_hub/src/xet.rs @@ -4,15 +4,21 @@ //! When xet headers are detected during download/upload but the feature //! is not enabled, HFError::XetNotEnabled is returned at the call site. +use std::collections::HashMap; use std::path::PathBuf; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; use serde::Deserialize; -use xet::xet_session::{Sha256Policy, XetFileInfo, XetFileMetadata}; +use xet::xet_session::{Sha256Policy, XetFileDownload, XetFileInfo, XetFileMetadata, XetFileUpload}; use crate::client::HFClient; use crate::constants; use crate::error::{HFError, Result}; use crate::repository::HFRepository; +use crate::types::progress::{ + self, DownloadEvent, FileProgress, FileStatus, Progress, ProgressEvent, UploadEvent, UploadPhase, +}; use crate::types::{AddSource, GetXetTokenParams, RepoType}; #[derive(Debug, Deserialize)] @@ -86,10 +92,109 @@ fn is_session_poisoned(err: &xet::error::XetError) -> bool { ) } +pub(crate) struct TrackedDownload { + pub handle: XetFileDownload, + pub filename: String, + pub file_size: u64, + pub complete_emitted: AtomicBool, +} + +fn emit_remaining_completes(progress: &Progress, tracked: &[TrackedDownload]) { + for t in tracked { + if !t.complete_emitted.swap(true, Ordering::Relaxed) { + progress::emit( + progress, + ProgressEvent::Download(DownloadEvent::Progress { + files: vec![FileProgress { + filename: t.filename.clone(), + 
bytes_completed: t.file_size, + total_bytes: t.file_size, + status: FileStatus::Complete, + }], + }), + ); + } + } +} + +fn spawn_download_progress_poller( + progress: &Progress, + group: &xet::xet_session::XetFileDownloadGroup, + tracked: Arc>, +) -> Option> { + let handler = progress.as_ref()?.clone(); + let group = group.clone(); + Some(tokio::spawn(async move { + loop { + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + let report = group.progress(); + handler.on_progress(&ProgressEvent::Download(DownloadEvent::AggregateProgress { + bytes_completed: report.total_bytes_completed, + total_bytes: report.total_bytes, + bytes_per_sec: report.total_bytes_completion_rate, + })); + + let mut completed = Vec::new(); + let mut in_progress = Vec::new(); + + for t in tracked.iter() { + if t.complete_emitted.load(Ordering::Relaxed) { + continue; + } + if t.handle.result().is_some() { + if !t.complete_emitted.swap(true, Ordering::Relaxed) { + completed.push(FileProgress { + filename: t.filename.clone(), + bytes_completed: t.file_size, + total_bytes: t.file_size, + status: FileStatus::Complete, + }); + } + continue; + } + if let Some(item) = t.handle.progress() { + let total = item.total_bytes.max(t.file_size); + if item.bytes_completed >= total && total > 0 { + if !t.complete_emitted.swap(true, Ordering::Relaxed) { + completed.push(FileProgress { + filename: t.filename.clone(), + bytes_completed: total, + total_bytes: total, + status: FileStatus::Complete, + }); + } + } else { + let status = if item.bytes_completed > 0 { + FileStatus::InProgress + } else { + FileStatus::Started + }; + in_progress.push(FileProgress { + filename: t.filename.clone(), + bytes_completed: item.bytes_completed, + total_bytes: total, + status, + }); + } + } + } + + if !completed.is_empty() { + handler.on_progress(&ProgressEvent::Download(DownloadEvent::Progress { files: completed })); + } + if !in_progress.is_empty() { + 
handler.on_progress(&ProgressEvent::Download(DownloadEvent::Progress { files: in_progress })); + } + } + })) +} + pub(crate) struct XetBatchFile { pub hash: String, pub file_size: u64, pub path: PathBuf, + pub filename: String, } impl HFRepository { @@ -99,6 +204,7 @@ impl HFRepository { filename: &str, local_dir: &std::path::Path, head_response: &reqwest::Response, + progress: &Progress, ) -> Result { let repo_path = self.repo_path(); let repo_type = Some(self.repo_type); @@ -139,15 +245,25 @@ impl HFRepository { let file_info = XetFileInfo::new(file_hash, file_size); - group + let handle = group .download_file_to_path(file_info, dest_path.clone()) .await .map_err(|e| HFError::Other(format!("Xet download failed: {e}")))?; - group - .finish() - .await - .map_err(|e| HFError::Other(format!("Xet download failed: {e}")))?; + let tracked = Arc::new(vec![TrackedDownload { + handle, + filename: filename.to_string(), + file_size, + complete_emitted: AtomicBool::new(false), + }]); + let poll_handle = spawn_download_progress_poller(progress, &group, Arc::clone(&tracked)); + + let result = group.finish().await; + if let Some(h) = poll_handle { + h.abort(); + } + result.map_err(|e| HFError::Other(format!("Xet download failed: {e}")))?; + emit_remaining_completes(progress, &tracked); Ok(dest_path) } @@ -155,9 +271,11 @@ impl HFRepository { pub(crate) async fn xet_download_to_blob( &self, revision: &str, + filename: &str, file_hash: &str, file_size: u64, path: &std::path::Path, + progress: &Progress, ) -> Result<()> { let repo_path = self.repo_path(); let repo_type = Some(self.repo_type); @@ -193,21 +311,36 @@ impl HFRepository { let file_info = XetFileInfo::new(file_hash.to_string(), file_size); - group + let handle = group .download_file_to_path(file_info, incomplete_path.clone()) .await .map_err(|e| HFError::Other(format!("Xet download failed: {e}")))?; - group - .finish() - .await - .map_err(|e| HFError::Other(format!("Xet download failed: {e}")))?; + let tracked = 
Arc::new(vec![TrackedDownload { + handle, + filename: filename.to_string(), + file_size, + complete_emitted: AtomicBool::new(false), + }]); + let poll_handle = spawn_download_progress_poller(progress, &group, Arc::clone(&tracked)); + + let result = group.finish().await; + if let Some(h) = poll_handle { + h.abort(); + } + result.map_err(|e| HFError::Other(format!("Xet download failed: {e}")))?; + emit_remaining_completes(progress, &tracked); tokio::fs::rename(&incomplete_path, path).await?; Ok(()) } - pub(crate) async fn xet_download_batch(&self, revision: &str, files: &[XetBatchFile]) -> Result<()> { + pub(crate) async fn xet_download_batch( + &self, + revision: &str, + files: &[XetBatchFile], + progress: &Progress, + ) -> Result<()> { if files.is_empty() { return Ok(()); } @@ -238,6 +371,7 @@ impl HFRepository { .await .map_err(|e| HFError::Other(format!("Xet batch download failed: {e}")))?; + let mut tracked_vec = Vec::with_capacity(files.len()); let mut incomplete_paths = Vec::with_capacity(files.len()); for file in files { if let Some(parent) = file.path.parent() { @@ -248,18 +382,29 @@ impl HFRepository { let file_info = XetFileInfo::new(file.hash.clone(), file.file_size); - group + let handle = group .download_file_to_path(file_info, incomplete.clone()) .await .map_err(|e| HFError::Other(format!("Xet batch download failed: {e}")))?; + tracked_vec.push(TrackedDownload { + handle, + filename: file.filename.clone(), + file_size: file.file_size, + complete_emitted: AtomicBool::new(false), + }); incomplete_paths.push((incomplete, file.path.clone())); } - group - .finish() - .await - .map_err(|e| HFError::Other(format!("Xet batch download failed: {e}")))?; + let tracked = Arc::new(tracked_vec); + let poll_handle = spawn_download_progress_poller(progress, &group, Arc::clone(&tracked)); + + let result = group.finish().await; + if let Some(h) = poll_handle { + h.abort(); + } + result.map_err(|e| HFError::Other(format!("Xet batch download failed: {e}")))?; + 
emit_remaining_completes(progress, &tracked); for (incomplete, final_path) in &incomplete_paths { tokio::fs::rename(incomplete, final_path).await?; @@ -325,7 +470,12 @@ impl HFRepository { /// Upload files using the xet protocol. /// Fetches a write token and uses xet-session's UploadCommit. /// Returns the XetFileInfo (hash + size) for each uploaded file. - pub(crate) async fn xet_upload(&self, files: &[(String, AddSource)], revision: &str) -> Result> { + pub(crate) async fn xet_upload( + &self, + files: &[(String, AddSource)], + revision: &str, + progress: &Progress, + ) -> Result> { let repo_path = self.repo_path(); let repo_type = Some(self.repo_type); tracing::info!(repo = repo_path.as_str(), "fetching xet write token"); @@ -357,29 +507,119 @@ impl HFRepository { tracing::info!("xet upload commit built, queuing file uploads"); let mut task_ids_in_order = Vec::with_capacity(files.len()); + let mut handles: Vec = Vec::with_capacity(files.len()); + let mut item_name_to_repo_path: HashMap = HashMap::with_capacity(files.len()); for (path_in_repo, source) in files { tracing::info!(path = path_in_repo.as_str(), "queuing xet upload"); let handle = match source { - AddSource::File(path) => commit - .upload_from_path(path.clone(), Sha256Policy::Compute) - .await - .map_err(|e| HFError::Other(format!("Xet upload failed: {e}")))?, - AddSource::Bytes(bytes) => commit - .upload_bytes(bytes.clone(), Sha256Policy::Compute, None) - .await - .map_err(|e| HFError::Other(format!("Xet upload failed: {e}")))?, + AddSource::File(path) => { + // Mimic xet-core's `std::path::absolute()` logic to derive the + // item_name that will appear in ItemProgressReport. 
+ // See: xet-data upload_commit.rs XetUploadCommitInner::upload_from_path + if let Ok(abs) = std::path::absolute(path) { + if let Some(s) = abs.to_str() { + item_name_to_repo_path.insert(s.to_owned(), path_in_repo.clone()); + } else { + tracing::warn!(path = ?abs, "non-UTF-8 path; per-file progress unavailable"); + } + } + commit + .upload_from_path(path.clone(), Sha256Policy::Compute) + .await + .map_err(|e| HFError::Other(format!("Xet upload failed: {e}")))? + }, + AddSource::Bytes(bytes) => { + item_name_to_repo_path.insert(path_in_repo.clone(), path_in_repo.clone()); + commit + .upload_bytes(bytes.clone(), Sha256Policy::Compute, Some(path_in_repo.clone())) + .await + .map_err(|e| HFError::Other(format!("Xet upload failed: {e}")))? + }, }; task_ids_in_order.push(handle.task_id()); + handles.push(handle); } tracing::info!(file_count = files.len(), "committing xet uploads"); + let shared_handles: Arc> = Arc::new(handles); + let shared_name_map: Arc> = Arc::new(item_name_to_repo_path); + + let poll_handle = progress.as_ref().map(|handler| { + let handler = handler.clone(); + let commit = commit.clone(); + let poll_handles = Arc::clone(&shared_handles); + let poll_name_map = Arc::clone(&shared_name_map); + tokio::spawn(async move { + loop { + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + let report = commit.progress(); + let file_progress: Vec = poll_handles + .iter() + .filter_map(|h| { + let item = h.progress()?; + let repo_path = poll_name_map.get(&item.item_name)?; + let status = if item.bytes_completed >= item.total_bytes && item.total_bytes > 0 { + FileStatus::Complete + } else if item.bytes_completed > 0 { + FileStatus::InProgress + } else { + FileStatus::Started + }; + Some(FileProgress { + filename: repo_path.clone(), + bytes_completed: item.bytes_completed, + total_bytes: item.total_bytes, + status, + }) + }) + .collect(); + handler.on_progress(&ProgressEvent::Upload(UploadEvent::Progress { + phase: UploadPhase::Uploading, + 
bytes_completed: report.total_bytes_completed, + total_bytes: report.total_bytes, + bytes_per_sec: report.total_bytes_completion_rate, + transfer_bytes_completed: report.total_transfer_bytes_completed, + transfer_bytes: report.total_transfer_bytes, + transfer_bytes_per_sec: report.total_transfer_bytes_completion_rate, + files: file_progress, + })); + } + }) + }); let results = commit .commit() .await .map_err(|e| HFError::Other(format!("Xet upload failed: {e}")))?; + if let Some(h) = poll_handle { + h.abort(); + } tracing::info!("xet upload commit complete"); + let final_files: Vec = files + .iter() + .map(|(path_in_repo, _)| FileProgress { + filename: path_in_repo.clone(), + bytes_completed: 0, + total_bytes: 0, + status: FileStatus::Complete, + }) + .collect(); + + progress::emit( + progress, + ProgressEvent::Upload(UploadEvent::Progress { + phase: UploadPhase::Uploading, + bytes_completed: results.progress.total_bytes_completed, + total_bytes: results.progress.total_bytes, + bytes_per_sec: results.progress.total_bytes_completion_rate, + transfer_bytes_completed: results.progress.total_transfer_bytes_completed, + transfer_bytes: results.progress.total_transfer_bytes, + transfer_bytes_per_sec: results.progress.total_transfer_bytes_completion_rate, + files: final_files, + }), + ); + let mut xet_file_infos = Vec::with_capacity(files.len()); for task_id in &task_ids_in_order { let metadata: &XetFileMetadata = results diff --git a/huggingface_hub/tests/download_test.rs b/huggingface_hub/tests/download_test.rs index 3d84e45..05bfa4f 100644 --- a/huggingface_hub/tests/download_test.rs +++ b/huggingface_hub/tests/download_test.rs @@ -1,34 +1,83 @@ -//! Integration tests for downloading files from the Hub. +//! Integration tests for downloads and progress tracking. //! -//! Tests regular (non-xet) HTTP downloads of small files. -//! Requires HF_TOKEN environment variable. +//! Read-only tests (downloads from hardcoded repos) use **prod** (huggingface.co). +//! 
Write tests (upload progress) create temporary repos on **hub-ci** and require HF_TEST_WRITE=1. //! -//! Run: source ~/hf/prod_token && cargo test -p huggingface-hub --test download_test +//! Run read-only: HF_TOKEN=hf_xxx cargo test -p huggingface-hub --test download_test +//! Run all: HF_TOKEN=hf_xxx HF_TEST_WRITE=1 cargo test -p huggingface-hub --test download_test +//! +//! CI: read-only tests use HF_PROD_TOKEN, write tests use HF_CI_TOKEN against hub-ci. + +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; use futures::StreamExt; use huggingface_hub::repository::HFRepository; use huggingface_hub::test_utils::*; -use huggingface_hub::{HFClient, HFClientBuilder, RepoDownloadFileParams, RepoDownloadFileStreamParams}; +use huggingface_hub::{ + AddSource, CommitOperation, CreateRepoParams, DeleteRepoParams, DownloadEvent, FileStatus, HFClient, + HFClientBuilder, ProgressEvent, ProgressHandler, RepoCreateCommitParams, RepoDownloadFileParams, + RepoDownloadFileStreamParams, RepoSnapshotDownloadParams, RepoUploadFileParams, UploadEvent, UploadPhase, +}; use sha2::{Digest, Sha256}; -fn api() -> Option { +fn prod_api() -> Option { if is_ci() { let token = resolve_prod_token()?; - Some( - HFClientBuilder::new() - .token(token) - .endpoint(PROD_ENDPOINT) - .build() - .expect("Failed to create HFClient"), - ) + Some(build_client(&token, PROD_ENDPOINT)) } else { - if std::env::var(HF_TOKEN).is_err() { - return None; - } - Some(HFClientBuilder::new().build().expect("Failed to create HFClient")) + default_api() } } +fn hub_ci_api() -> Option { + if is_ci() { + let token = std::env::var(HF_CI_TOKEN).ok()?; + Some(build_client(&token, HUB_CI_ENDPOINT)) + } else { + default_api() + } +} + +fn default_api() -> Option { + let token = std::env::var(HF_TOKEN).ok()?; + let endpoint = std::env::var(HF_ENDPOINT).unwrap_or_else(|_| PROD_ENDPOINT.to_string()); + Some(build_client(&token, &endpoint)) +} + +fn build_client(token: &str, endpoint: &str) -> HFClient { + 
HFClientBuilder::new() + .token(token) + .endpoint(endpoint) + .build() + .expect("Failed to create HFClient") +} + +fn uuid_short() -> String { + format!("{:016x}", rand::random::()) +} + +async fn cached_username(api: &HFClient) -> String { + api.whoami().await.expect("whoami failed").username +} + +async fn create_test_repo(api: &HFClient) -> String { + let username = cached_username(api).await; + let repo_id = format!("{}/hfrs-progress-test-{}", username, uuid_short()); + let params = CreateRepoParams::builder() + .repo_id(&repo_id) + .private(true) + .exist_ok(false) + .build(); + api.create_repo(¶ms).await.expect("create_repo failed"); + repo_id +} + +async fn delete_test_repo(api: &HFClient, repo_id: &str) { + let params = DeleteRepoParams::builder().repo_id(repo_id).build(); + let _ = api.delete_repo(¶ms).await; +} + fn test_model_parts() -> (&'static str, &'static str) { ("openai-community", "gpt2") } @@ -47,7 +96,7 @@ fn dataset(api: &HFClient, owner: &str, name: &str) -> HFRepository { #[tokio::test] async fn test_download_small_json_file() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let dir = tempfile::tempdir().unwrap(); let (owner, name) = test_model_parts(); @@ -69,7 +118,7 @@ async fn test_download_small_json_file() { #[tokio::test] async fn test_download_preserves_subdirectory_structure() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let dir = tempfile::tempdir().unwrap(); let (owner, name) = test_model_parts(); @@ -89,7 +138,7 @@ async fn test_download_preserves_subdirectory_structure() { #[tokio::test] async fn test_download_with_specific_revision() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let dir = tempfile::tempdir().unwrap(); let (owner, name) = test_model_parts(); @@ -112,7 +161,7 @@ async fn test_download_with_specific_revision() { #[tokio::test] async fn test_download_dataset_file() { - let Some(api) = 
api() else { return }; + let Some(api) = prod_api() else { return }; let dir = tempfile::tempdir().unwrap(); let (owner, name) = test_dataset_parts(); @@ -133,7 +182,7 @@ async fn test_download_dataset_file() { #[tokio::test] async fn test_download_nonexistent_file_returns_error() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let dir = tempfile::tempdir().unwrap(); let (owner, name) = test_model_parts(); @@ -151,7 +200,7 @@ async fn test_download_nonexistent_file_returns_error() { #[tokio::test] async fn test_download_from_nonexistent_repo_returns_error() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let dir = tempfile::tempdir().unwrap(); let result = model(&api, "this-user-does-not-exist-99999", "this-repo-does-not-exist") @@ -168,7 +217,7 @@ async fn test_download_from_nonexistent_repo_returns_error() { #[tokio::test] async fn test_download_multiple_files_to_same_dir() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let dir = tempfile::tempdir().unwrap(); let (owner, name) = test_model_parts(); let repo = model(&api, owner, name); @@ -192,7 +241,7 @@ async fn test_download_multiple_files_to_same_dir() { #[tokio::test] async fn test_download_file_content_is_deterministic() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let dir1 = tempfile::tempdir().unwrap(); let dir2 = tempfile::tempdir().unwrap(); let (owner, name) = test_model_parts(); @@ -219,7 +268,7 @@ async fn test_download_file_content_is_deterministic() { #[tokio::test] async fn test_download_overwrites_existing_file() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let dir = tempfile::tempdir().unwrap(); let (owner, name) = test_model_parts(); @@ -245,7 +294,7 @@ async fn test_download_overwrites_existing_file() { #[tokio::test] async fn test_download_stream_full_file() { - let Some(api) = api() 
else { return }; + let Some(api) = prod_api() else { return }; let (owner, name) = test_model_parts(); let repo = model(&api, owner, name); @@ -268,7 +317,7 @@ async fn test_download_stream_full_file() { #[tokio::test] async fn test_download_stream_range_first_bytes() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let (owner, name) = test_model_parts(); let repo = model(&api, owner, name); @@ -295,7 +344,7 @@ async fn test_download_stream_range_first_bytes() { #[tokio::test] async fn test_download_stream_range_middle_bytes() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let (owner, name) = test_model_parts(); let repo = model(&api, owner, name); @@ -335,7 +384,7 @@ async fn test_download_stream_range_middle_bytes() { #[tokio::test] async fn test_download_stream_range_content_matches_full_download() { - let Some(api) = api() else { return }; + let Some(api) = prod_api() else { return }; let (owner, name) = test_model_parts(); let repo = model(&api, owner, name); let dir = tempfile::tempdir().unwrap(); @@ -372,3 +421,320 @@ async fn test_download_stream_range_content_matches_full_download() { assert_eq!(streamed, &full_bytes[..range_end as usize]); } + +// --- Progress tracking tests --- + +struct RecordingHandler { + events: Mutex>, +} + +impl RecordingHandler { + fn new() -> Self { + Self { + events: Mutex::new(Vec::new()), + } + } + + fn events(&self) -> Vec { + self.events.lock().unwrap().clone() + } +} + +impl ProgressHandler for RecordingHandler { + fn on_progress(&self, event: &ProgressEvent) { + self.events.lock().unwrap().push(event.clone()); + } +} + +#[tokio::test] +async fn test_download_file_with_progress_to_local_dir() { + let Some(api) = prod_api() else { return }; + let (owner, name) = test_model_parts(); + let repo = model(&api, owner, name); + + let handler = Arc::new(RecordingHandler::new()); + + let dir = tempfile::tempdir().unwrap(); + let params = 
RepoDownloadFileParams::builder() + .filename("config.json") + .local_dir(dir.path().to_path_buf()) + .progress(Some(handler.clone())) + .build(); + + let path = repo.download_file(¶ms).await.unwrap(); + assert!(path.exists()); + + let events = handler.events(); + assert!(!events.is_empty(), "should have received progress events"); + + // First event should be Download(Start) + assert!( + matches!(&events[0], ProgressEvent::Download(DownloadEvent::Start { total_files: 1, .. })), + "first event should be Download(Start), got {:?}", + &events[0] + ); + + // Last event should be Download(Complete) + assert!( + matches!(events.last().unwrap(), ProgressEvent::Download(DownloadEvent::Complete)), + "last event should be Download(Complete)" + ); + + // Should have at least one Progress event with InProgress or Complete + let has_progress = events + .iter() + .any(|e| matches!(e, ProgressEvent::Download(DownloadEvent::Progress { .. }))); + assert!(has_progress, "should have at least one Progress event"); + + // Should have a Complete file status + let has_file_complete = events.iter().any(|e| { + if let ProgressEvent::Download(DownloadEvent::Progress { files }) = e { + files.iter().any(|f| f.status == FileStatus::Complete) + } else { + false + } + }); + assert!(has_file_complete, "should have a file Complete status event"); +} + +#[tokio::test] +async fn test_download_file_with_progress_to_cache() { + let Some(api) = prod_api() else { return }; + let (owner, name) = test_model_parts(); + let repo = model(&api, owner, name); + + let handler = Arc::new(RecordingHandler::new()); + + let params = RepoDownloadFileParams::builder() + .filename("config.json") + .force_download(true) + .progress(Some(handler.clone())) + .build(); + + let path = repo.download_file(¶ms).await.unwrap(); + assert!(path.exists()); + + let events = handler.events(); + assert!(!events.is_empty(), "should have received progress events"); + + assert!( + matches!(&events[0], 
ProgressEvent::Download(DownloadEvent::Start { total_files: 1, .. })), + "first event should be Download(Start)" + ); + assert!( + matches!(events.last().unwrap(), ProgressEvent::Download(DownloadEvent::Complete)), + "last event should be Download(Complete)" + ); +} + +#[tokio::test] +async fn test_download_with_no_progress_handler() { + let Some(api) = prod_api() else { return }; + let (owner, name) = test_model_parts(); + let repo = model(&api, owner, name); + + let dir = tempfile::tempdir().unwrap(); + let params = RepoDownloadFileParams::builder() + .filename("config.json") + .local_dir(dir.path().to_path_buf()) + .build(); + + let path = repo.download_file(¶ms).await.unwrap(); + assert!(path.exists()); +} + +// --- Upload progress tests (write to hub-ci) --- + +fn repo_from_id(api: &HFClient, repo_id: &str) -> HFRepository { + let parts: Vec<&str> = repo_id.splitn(2, '/').collect(); + api.model(parts[0], parts[1]) +} + +#[tokio::test] +async fn test_upload_file_with_progress() { + let Some(api) = hub_ci_api() else { return }; + if !write_enabled() { + return; + } + let repo_id = create_test_repo(&api).await; + let repo = repo_from_id(&api, &repo_id); + + let handler = Arc::new(RecordingHandler::new()); + + let result = repo + .upload_file( + &RepoUploadFileParams::builder() + .source(AddSource::Bytes(b"hello from progress test".to_vec())) + .path_in_repo("progress_test.txt") + .commit_message("upload with progress tracking") + .progress(Some(handler.clone())) + .build(), + ) + .await; + + delete_test_repo(&api, &repo_id).await; + let commit = result.unwrap(); + assert!(commit.commit_oid.is_some()); + + let events = handler.events(); + assert!(!events.is_empty(), "should have received upload progress events"); + + assert!( + matches!(&events[0], ProgressEvent::Upload(UploadEvent::Start { total_files: 1, .. 
})), + "first event should be Upload(Start), got {:?}", + &events[0] + ); + + assert!( + matches!(events.last().unwrap(), ProgressEvent::Upload(UploadEvent::Complete)), + "last event should be Upload(Complete)" + ); + + let has_preparing = events.iter().any(|e| { + matches!( + e, + ProgressEvent::Upload(UploadEvent::Progress { + phase: UploadPhase::Preparing, + .. + }) + ) + }); + assert!(has_preparing, "should have a Preparing phase event"); + + let has_committing = events.iter().any(|e| { + matches!( + e, + ProgressEvent::Upload(UploadEvent::Progress { + phase: UploadPhase::Committing, + .. + }) + ) + }); + assert!(has_committing, "should have a Committing phase event"); +} + +#[tokio::test] +async fn test_create_commit_with_progress_multiple_files() { + let Some(api) = hub_ci_api() else { return }; + if !write_enabled() { + return; + } + let repo_id = create_test_repo(&api).await; + let repo = repo_from_id(&api, &repo_id); + + let handler = Arc::new(RecordingHandler::new()); + + let result = repo + .create_commit( + &RepoCreateCommitParams::builder() + .operations(vec![ + CommitOperation::Add { + path_in_repo: "file_a.txt".to_string(), + source: AddSource::Bytes(b"content a".to_vec()), + }, + CommitOperation::Add { + path_in_repo: "file_b.txt".to_string(), + source: AddSource::Bytes(b"content b".to_vec()), + }, + ]) + .commit_message("multi-file commit with progress") + .progress(Some(handler.clone())) + .build(), + ) + .await; + + delete_test_repo(&api, &repo_id).await; + let commit = result.unwrap(); + assert!(commit.commit_oid.is_some()); + + let events = handler.events(); + assert!(!events.is_empty(), "should have received upload progress events"); + + if let ProgressEvent::Upload(UploadEvent::Start { + total_files, + total_bytes, + }) = &events[0] + { + assert_eq!(*total_files, 2); + assert_eq!(*total_bytes, 18); // "content a" + "content b" = 9 + 9 + } else { + panic!("first event should be Upload(Start), got {:?}", &events[0]); + } + + assert!( + 
matches!(events.last().unwrap(), ProgressEvent::Upload(UploadEvent::Complete)), + "last event should be Upload(Complete)" + ); +} + +#[tokio::test] +async fn test_upload_with_no_progress_handler() { + let Some(api) = hub_ci_api() else { return }; + if !write_enabled() { + return; + } + let repo_id = create_test_repo(&api).await; + let repo = repo_from_id(&api, &repo_id); + + let result = repo + .upload_file( + &RepoUploadFileParams::builder() + .source(AddSource::Bytes(b"no handler test".to_vec())) + .path_in_repo("no_handler.txt") + .commit_message("upload without progress handler") + .build(), + ) + .await; + + delete_test_repo(&api, &repo_id).await; + result.unwrap(); +} + +#[tokio::test] +async fn test_snapshot_download_exactly_one_complete_per_file() { + let Some(api) = prod_api() else { return }; + let (owner, name) = test_model_parts(); + let repo = model(&api, owner, name); + + let handler = Arc::new(RecordingHandler::new()); + let dir = tempfile::tempdir().unwrap(); + + let params = RepoSnapshotDownloadParams::builder() + .local_dir(dir.path().to_path_buf()) + .allow_patterns(vec!["*.json".to_string()]) + .force_download(true) + .progress(Some(handler.clone())) + .build(); + + repo.snapshot_download(¶ms).await.unwrap(); + + let events = handler.events(); + + // Count Complete events per filename + let mut complete_counts: HashMap = HashMap::new(); + for event in &events { + if let ProgressEvent::Download(DownloadEvent::Progress { files }) = event { + for fp in files { + if fp.status == FileStatus::Complete { + *complete_counts.entry(fp.filename.clone()).or_default() += 1; + } + } + } + } + + assert!(!complete_counts.is_empty(), "should have at least one file Complete event"); + + for (filename, count) in &complete_counts { + assert_eq!(*count, 1, "file '{filename}' had {count} Complete events, expected exactly 1"); + } + + // The files_bar count should match: total_files from Start == number of distinct Complete files + if let 
Some(ProgressEvent::Download(DownloadEvent::Start { total_files, .. })) = events.first() { + assert_eq!( + *total_files, + complete_counts.len(), + "total_files in Start ({total_files}) should match number of completed files ({})", + complete_counts.len() + ); + } +}