Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions crates/goose-server/src/routes/local_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,8 @@ pub struct SearchQuery {
pub struct RepoVariantsResponse {
    /// Quantization variants available for the requested Hugging Face repo.
    pub variants: Vec<HfQuantVariant>,
    /// Index into `variants` of the recommended choice, if one fits the
    /// available inference memory; `None` when no recommendation applies.
    pub recommended_index: Option<usize>,
    /// Memory budget (bytes) the recommendation was computed against
    /// (from `available_inference_memory_bytes`).
    pub available_memory_bytes: u64,
    /// Quantization names of variants from this repo that are already
    /// downloaded locally, per the model registry.
    pub downloaded_quants: Vec<String>,
}

#[utoipa::path(
Expand Down Expand Up @@ -232,9 +234,23 @@ pub async fn get_repo_files(
let available_memory = available_inference_memory_bytes(&state.inference_runtime);
let recommended_index = hf_models::recommend_variant(&variants, available_memory);

let downloaded_quants = {
let registry = get_registry()
.lock()
.map_err(|_| ErrorResponse::internal("Failed to acquire registry lock"))?;
registry
.list_models()
.iter()
.filter(|m| m.repo_id == repo_id && m.is_downloaded())
.map(|m| m.quantization.clone())
.collect()
};

Ok(Json(RepoVariantsResponse {
variants,
recommended_index,
available_memory_bytes: available_memory,
downloaded_quants,
}))
}

Expand Down
273 changes: 219 additions & 54 deletions crates/goose/src/download_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,19 @@ impl DownloadManager {
Ok(())
}

/// Maximum number of retry attempts for a single file download before
/// giving up (applies to connection errors, retryable HTTP statuses,
/// and mid-stream interruptions).
const MAX_RETRIES: u32 = 10;
/// First retry delay; doubled on each subsequent retry (exponential backoff).
const RETRY_BASE_DELAY: std::time::Duration = std::time::Duration::from_secs(2);
/// Upper bound on the backoff delay between retries.
const RETRY_MAX_DELAY: std::time::Duration = std::time::Duration::from_secs(60);

/// Returns `true` if the tracked download for `model_id` has been marked
/// `Cancelled`. A poisoned lock or an unknown `model_id` conservatively
/// reports `false` (not cancelled), matching best-effort cancellation checks.
fn is_cancelled(downloads: &DownloadMap, model_id: &str) -> bool {
    downloads.lock().map_or(false, |map| {
        map.get(model_id)
            .map_or(false, |entry| entry.status == DownloadStatus::Cancelled)
    })
}

async fn download_file(
url: &str,
destination: &PathBuf,
Expand All @@ -196,85 +209,237 @@ impl DownloadManager {
) -> Result<(), anyhow::Error> {
let client = reqwest::Client::builder()
.connect_timeout(std::time::Duration::from_secs(30))
.read_timeout(std::time::Duration::from_secs(60))
.read_timeout(std::time::Duration::from_secs(120))
.build()?;
let mut response = client.get(url).send().await?;

if !response.status().is_success() {
anyhow::bail!("Failed to download: HTTP {}", response.status());
}

let total_bytes = response.content_length().unwrap_or(0);
let partial_path = partial_path_for(destination);
let mut retries = 0u32;

{
if let Ok(mut downloads) = downloads.lock() {
if let Some(progress) = downloads.get_mut(model_id) {
progress.total_bytes = total_bytes;
// Check for existing partial file to resume
let mut bytes_downloaded: u64 = if partial_path.exists() {
tokio::fs::metadata(&partial_path).await?.len()
} else {
0
};

// Get total size with a HEAD request first (so we know even before first chunk)
let total_bytes = {
let head_resp = client
.head(url)
.send()
.await
.ok()
.and_then(|r| r.content_length());
head_resp.unwrap_or(0)
};

if let Ok(mut dl) = downloads.lock() {
if let Some(progress) = dl.get_mut(model_id) {
progress.total_bytes = total_bytes;
progress.bytes_downloaded = bytes_downloaded;
if total_bytes > 0 {
progress.progress_percent =
(bytes_downloaded as f64 / total_bytes as f64 * 100.0) as f32;
}
}
}

let partial_path = partial_path_for(destination);
let mut file = tokio::fs::File::create(&partial_path).await?;
let mut bytes_downloaded = 0u64;
// If already fully downloaded from a previous partial, just rename
if total_bytes > 0 && bytes_downloaded >= total_bytes {
tokio::fs::rename(&partial_path, destination).await?;
return Ok(());
}

let start_time = std::time::Instant::now();
// bytes_at_start tracks how many bytes we had when timing began (for speed calc)
let bytes_at_start = bytes_downloaded;

while let Some(chunk) = response.chunk().await? {
// Check if cancelled
let should_cancel = {
if let Ok(downloads) = downloads.lock() {
if let Some(progress) = downloads.get(model_id) {
progress.status == DownloadStatus::Cancelled
} else {
false
loop {
if Self::is_cancelled(downloads, model_id) {
let _ = tokio::fs::remove_file(&partial_path).await;
anyhow::bail!("Download cancelled");
}

// Build request with Range header for resume
let mut request = client.get(url);
if bytes_downloaded > 0 {
request = request.header("Range", format!("bytes={}-", bytes_downloaded));
}

let response = match request.send().await {
Ok(r) => r,
Err(e) => {
if retries >= Self::MAX_RETRIES {
anyhow::bail!("Download failed after {} retries: {}", retries, e);
}
} else {
false
retries += 1;
let delay = std::cmp::min(
Self::RETRY_BASE_DELAY * 2u32.saturating_pow(retries - 1),
Self::RETRY_MAX_DELAY,
);
info!(model_id = %model_id, retry = retries, delay_secs = ?delay.as_secs(), error = %e, "Retrying download after connection error");
tokio::time::sleep(delay).await;
continue;
}
};

if should_cancel {
let status = response.status();
if status == reqwest::StatusCode::RANGE_NOT_SATISFIABLE {
// Server can't satisfy range — file may be complete or something is off.
// If partial file is at least total_bytes, treat as done.
if total_bytes > 0 && bytes_downloaded >= total_bytes {
break;
}
// Otherwise restart from scratch
bytes_downloaded = 0;
let _ = tokio::fs::remove_file(&partial_path).await;
anyhow::bail!("Download cancelled");
continue;
}

file.write_all(&chunk).await?;
bytes_downloaded += chunk.len() as u64;
if !status.is_success() && status != reqwest::StatusCode::PARTIAL_CONTENT {
if retries >= Self::MAX_RETRIES {
anyhow::bail!("Failed to download: HTTP {}", status);
Comment on lines +299 to +301
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Avoid retrying permanent 4xx download failures

The retry block treats all non-success HTTP statuses as retryable, including permanent client errors like 401/403/404. In those cases the download will back off for multiple minutes before surfacing failure, which delays user feedback and ties up the download slot without any chance of recovery; retries should be limited to transient statuses (for example 408/429/5xx) and fail fast on persistent 4xx.

Useful? React with 👍 / 👎.

}
retries += 1;
let delay = std::cmp::min(
Self::RETRY_BASE_DELAY * 2u32.saturating_pow(retries - 1),
Self::RETRY_MAX_DELAY,
);
info!(model_id = %model_id, retry = retries, http_status = %status, "Retrying download after HTTP error");
tokio::time::sleep(delay).await;
continue;
}

// Update progress
let elapsed = start_time.elapsed().as_secs_f64();
let speed_bps = if elapsed > 0.0 {
Some((bytes_downloaded as f64 / elapsed) as u64)
} else {
None
};
// We sent a Range request but the server ignored it and returned
// the full body (200 OK instead of 206). Truncate and restart so
// we don't append a full copy onto the existing partial data.
if bytes_downloaded > 0 && status == reqwest::StatusCode::OK {
info!(model_id = %model_id, "Server ignored Range header, restarting download from scratch");
bytes_downloaded = 0;
let _ = tokio::fs::remove_file(&partial_path).await;
}

let eta_seconds = if let Some(speed) = speed_bps {
if speed > 0 && total_bytes > 0 {
Some(total_bytes.saturating_sub(bytes_downloaded) / speed)
// Update total_bytes from Content-Range or Content-Length if not yet known
if total_bytes == 0 {
let new_total = if bytes_downloaded > 0 {
// Parse Content-Range: bytes 1234-5678/9999
response
.headers()
.get("content-range")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.rsplit('/').next())
.and_then(|s| s.parse::<u64>().ok())
} else {
None
response.content_length()
};
if let Some(t) = new_total {
if let Ok(mut dl) = downloads.lock() {
if let Some(progress) = dl.get_mut(model_id) {
progress.total_bytes = t;
}
}
}
} else {
None
};
}

// Open file for appending (or create)
let mut file = tokio::fs::OpenOptions::new()
.create(true)
.append(true)
.open(&partial_path)
.await?;

// Truncate to bytes_downloaded in case file grew beyond our tracking
let file_len = tokio::fs::metadata(&partial_path).await?.len();
if file_len != bytes_downloaded {
file.set_len(bytes_downloaded).await?;
}

let mut stream_error = false;
let mut resp = response;

loop {
let chunk_result = resp.chunk().await;
match chunk_result {
Ok(Some(chunk)) => {
if Self::is_cancelled(downloads, model_id) {
let _ = tokio::fs::remove_file(&partial_path).await;
anyhow::bail!("Download cancelled");
}

if let Ok(mut downloads) = downloads.lock() {
if let Some(progress) = downloads.get_mut(model_id) {
progress.bytes_downloaded = bytes_downloaded;
progress.progress_percent = if total_bytes > 0 {
(bytes_downloaded as f64 / total_bytes as f64 * 100.0) as f32
} else {
0.0
};
progress.speed_bps = speed_bps;
progress.eta_seconds = eta_seconds;
file.write_all(&chunk).await?;
bytes_downloaded += chunk.len() as u64;

let elapsed = start_time.elapsed().as_secs_f64();
let bytes_this_session = bytes_downloaded.saturating_sub(bytes_at_start);
let speed_bps = if elapsed > 0.0 {
Some((bytes_this_session as f64 / elapsed) as u64)
} else {
None
};

let current_total = if let Ok(dl) = downloads.lock() {
dl.get(model_id)
.map(|p| p.total_bytes)
.unwrap_or(total_bytes)
} else {
total_bytes
};

let eta_seconds = if let Some(speed) = speed_bps {
if speed > 0 && current_total > 0 {
Some(current_total.saturating_sub(bytes_downloaded) / speed)
} else {
None
}
} else {
None
};

if let Ok(mut dl) = downloads.lock() {
if let Some(progress) = dl.get_mut(model_id) {
progress.bytes_downloaded = bytes_downloaded;
progress.progress_percent = if current_total > 0 {
(bytes_downloaded as f64 / current_total as f64 * 100.0) as f32
} else {
0.0
};
progress.speed_bps = speed_bps;
progress.eta_seconds = eta_seconds;
}
}
}
Ok(None) => break, // Stream finished
Err(e) => {
info!(model_id = %model_id, bytes = bytes_downloaded, error = %e, "Download stream interrupted, will retry");
stream_error = true;
break;
}
}
}

file.flush().await?;
drop(file);

if stream_error {
if retries >= Self::MAX_RETRIES {
anyhow::bail!(
"Download failed after {} retries due to stream interruption",
retries
);
}
retries += 1;
let delay = std::cmp::min(
Self::RETRY_BASE_DELAY * 2u32.saturating_pow(retries - 1),
Self::RETRY_MAX_DELAY,
);
info!(model_id = %model_id, retry = retries, delay_secs = ?delay.as_secs(), "Retrying download with resume");
tokio::time::sleep(delay).await;
continue;
}

break;
}

file.flush().await?;
drop(file);
tokio::fs::rename(&partial_path, destination).await?;
Ok(())
}
Expand Down
Loading
Loading