Skip to content

Commit b81bb38

Browse files
committed
feat(hf): retry requests by default
1 parent 14350e8 commit b81bb38

File tree

6 files changed

+82
-50
lines changed

6 files changed

+82
-50
lines changed

core/services/huggingface/src/backend.rs

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,15 @@ impl HfBuilder {
129129
self.config.xet = xet;
130130
self
131131
}
132+
133+
/// Set the maximum number of retries for commit operations.
134+
///
135+
/// Retries on commit conflicts (HTTP 412) and transient server
136+
/// errors (HTTP 5xx). Default is 3.
137+
pub fn max_retries(mut self, max_retries: usize) -> Self {
138+
self.config.max_retries = Some(max_retries);
139+
self
140+
}
132141
}
133142

134143
impl Builder for HfBuilder {
@@ -198,6 +207,7 @@ impl Builder for HfBuilder {
198207
root,
199208
token,
200209
endpoint,
210+
max_retries: self.config.max_retries.unwrap_or(3),
201211
#[cfg(feature = "xet")]
202212
xet_enabled: self.config.xet,
203213
}),
@@ -237,9 +247,8 @@ impl Access for HfBackend {
237247
}
238248

239249
async fn list(&self, path: &str, args: OpList) -> Result<(RpList, Self::Lister)> {
240-
let l = HfLister::new(self.core.clone(), path.to_string(), args.recursive());
241-
242-
Ok((RpList::default(), oio::PageLister::new(l)))
250+
let lister = HfLister::new(self.core.clone(), path.to_string(), args.recursive());
251+
Ok((RpList::default(), oio::PageLister::new(lister)))
243252
}
244253

245254
async fn write(&self, path: &str, args: OpWrite) -> Result<(RpWrite, Self::Writer)> {
@@ -248,12 +257,11 @@ impl Access for HfBackend {
248257
}
249258

250259
async fn delete(&self) -> Result<(RpDelete, Self::Deleter)> {
260+
let deleter = HfDeleter::new(self.core.clone());
261+
let delete_max_size = self.core.info.full_capability().delete_max_size;
251262
Ok((
252263
RpDelete::default(),
253-
oio::BatchDeleter::new(
254-
HfDeleter::new(self.core.clone()),
255-
self.core.info.full_capability().delete_max_size,
256-
),
264+
oio::BatchDeleter::new(deleter, delete_max_size),
257265
))
258266
}
259267
}

core/services/huggingface/src/config.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,11 @@ pub struct HfConfig {
6060
/// check for XET-backed files and use the XET protocol for
6161
/// downloading. Default is false.
6262
pub xet: bool,
63+
/// Maximum number of retries for commit operations.
64+
///
65+
/// Retries on commit conflicts (HTTP 412) and transient server
66+
/// errors (HTTP 5xx). Default is 3.
67+
pub max_retries: Option<usize>,
6368
}
6469

6570
impl Debug for HfConfig {

core/services/huggingface/src/core.rs

Lines changed: 59 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ use std::sync::Arc;
2121
use bytes::Buf;
2222
use bytes::Bytes;
2323
use http::Request;
24-
use http::StatusCode;
2524
use http::header;
2625
use serde::Deserialize;
2726

@@ -184,6 +183,7 @@ pub struct HfCore {
184183
pub root: String,
185184
pub token: Option<String>,
186185
pub endpoint: String,
186+
pub max_retries: usize,
187187

188188
#[cfg(feature = "xet")]
189189
pub xet_enabled: bool,
@@ -238,17 +238,42 @@ impl HfCore {
238238
///
239239
/// Returns the response parts (status, headers, etc.) alongside the
240240
/// deserialized body so callers can inspect headers when needed.
241+
///
242+
/// When `max_retries` > 1, retries on commit conflicts (HTTP 412) and
243+
/// transient server errors (HTTP 5xx), matching the behavior of the
244+
/// official HuggingFace Hub client.
241245
async fn send_request<T: serde::de::DeserializeOwned>(
242246
&self,
243247
req: Request<Buffer>,
248+
max_retries: usize,
244249
) -> Result<(http::response::Parts, T)> {
245-
let resp = self.info.http_client().send(req).await?;
246-
if !resp.status().is_success() {
247-
return Err(parse_error(resp));
250+
let client = self.info.http_client();
251+
let mut attempt = 0;
252+
loop {
253+
match client.send(req.clone()).await {
254+
Ok(resp) if resp.status().is_success() => {
255+
let (parts, body) = resp.into_parts();
256+
let parsed = serde_json::from_reader(body.reader())
257+
.map_err(new_json_deserialize_error)?;
258+
return Ok((parts, parsed));
259+
}
260+
Ok(resp) => {
261+
attempt += 1;
262+
let err = parse_error(resp);
263+
let retryable =
264+
err.kind() == ErrorKind::ConditionNotMatch || err.is_temporary();
265+
if attempt >= max_retries || !retryable {
266+
return Err(err);
267+
}
268+
}
269+
Err(err) => {
270+
attempt += 1;
271+
if attempt >= max_retries || !err.is_temporary() {
272+
return Err(err);
273+
}
274+
}
275+
}
248276
}
249-
let (parts, body) = resp.into_parts();
250-
let parsed = serde_json::from_reader(body.reader()).map_err(new_json_deserialize_error)?;
251-
Ok((parts, parsed))
252277
}
253278

254279
pub async fn path_info(&self, path: &str) -> Result<PathInfo> {
@@ -261,7 +286,7 @@ impl HfCore {
261286
.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
262287
.body(Buffer::from(Bytes::from(form_body)))
263288
.map_err(new_request_build_error)?;
264-
let (_, mut files) = self.send_request::<Vec<PathInfo>>(req).await?;
289+
let (_, mut files) = self.send_request::<Vec<PathInfo>>(req, 1).await?;
265290

266291
// NOTE: if the file is not found, the server will return 200 with an empty array
267292
if files.is_empty() {
@@ -284,7 +309,7 @@ impl HfCore {
284309
.request(http::Method::GET, &url, Operation::List)
285310
.body(Buffer::new())
286311
.map_err(new_request_build_error)?;
287-
let (parts, files) = self.send_request::<Vec<PathInfo>>(req).await?;
312+
let (parts, files) = self.send_request::<Vec<PathInfo>>(req, 1).await?;
288313

289314
let next_cursor = parts
290315
.headers
@@ -302,16 +327,17 @@ impl HfCore {
302327
.request(http::Method::GET, &url, Operation::Read)
303328
.body(Buffer::new())
304329
.map_err(new_request_build_error)?;
305-
let (_, token) = self.send_request(req).await?;
330+
let (_, token) = self.send_request(req, 1).await?;
306331
Ok(token)
307332
}
308333

309334
/// Issue a HEAD request and extract XET file info (hash and size).
310335
///
311-
/// Uses a custom HTTP client that does NOT follow redirects so we can
312-
/// inspect response headers (e.g. `X-Xet-Hash`) from the 302 response.
313-
///
314336
/// Returns `None` if the `X-Xet-Hash` header is absent or empty.
337+
///
338+
/// NOTE: Cannot use `send_request` here because we need a custom
339+
/// no-redirect HTTP client to inspect headers (e.g. `X-Xet-Hash`)
340+
/// from the 302 response, and the response is not JSON.
315341
#[cfg(feature = "xet")]
316342
pub(super) async fn get_xet_file(&self, path: &str) -> Result<Option<XetFile>> {
317343
let uri = self.repo.uri(&self.root, path);
@@ -330,7 +356,17 @@ impl HfCore {
330356
.body(Buffer::new())
331357
.map_err(new_request_build_error)?;
332358

333-
let resp = client.send(req).await?;
359+
// Retry on transient errors, same as send_request.
360+
let mut attempt = 0;
361+
let resp = loop {
362+
let resp = client.send(req.clone()).await?;
363+
364+
attempt += 1;
365+
let retryable = resp.status().is_server_error();
366+
if attempt >= self.max_retries || !retryable {
367+
break resp;
368+
}
369+
};
334370

335371
let hash = resp
336372
.headers()
@@ -385,11 +421,15 @@ impl HfCore {
385421
.body(Buffer::from(json_body))
386422
.map_err(new_request_build_error)?;
387423

388-
let (_, resp) = self.send_request(req).await?;
424+
let (_, resp) = self.send_request(req, 1).await?;
389425
Ok(resp)
390426
}
391427

392428
/// Commit file changes (uploads and/or deletions) to the repository.
429+
///
430+
/// Retries on commit conflicts (HTTP 412) and transient server errors
431+
/// (HTTP 5xx), matching the behavior of the official HuggingFace Hub
432+
/// client.
393433
pub(super) async fn commit_files(
394434
&self,
395435
regular_files: Vec<CommitFile>,
@@ -411,7 +451,6 @@ impl HfCore {
411451
.or_else(|| deleted_files.first().map(|f| f.path.as_str()))
412452
.ok_or_else(|| Error::new(ErrorKind::Unexpected, "no files to commit"))?;
413453

414-
let client = self.info.http_client();
415454
let uri = self.repo.uri(&self.root, first_path);
416455
let url = uri.commit_url(&self.endpoint);
417456

@@ -431,11 +470,9 @@ impl HfCore {
431470
.body(Buffer::from(json_body))
432471
.map_err(new_request_build_error)?;
433472

434-
let resp = client.send(req).await?;
435-
match resp.status() {
436-
StatusCode::OK | StatusCode::CREATED => Ok(()),
437-
_ => Err(parse_error(resp)),
438-
}
473+
self.send_request::<serde_json::Value>(req, self.max_retries)
474+
.await?;
475+
Ok(())
439476
}
440477
}
441478

@@ -537,6 +574,7 @@ pub(crate) mod test_utils {
537574
root: "/".to_string(),
538575
token: None,
539576
endpoint: endpoint.to_string(),
577+
max_retries: 3,
540578
#[cfg(feature = "xet")]
541579
xet_enabled: false,
542580
};

core/services/huggingface/src/deleter.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ mod tests {
8585
root: "/".to_string(),
8686
token: std::env::var("HF_OPENDAL_TOKEN").ok(),
8787
endpoint: "https://huggingface.co".to_string(),
88+
max_retries: 3,
8889
#[cfg(feature = "xet")]
8990
xet_enabled: false,
9091
}

core/services/huggingface/src/reader.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ mod tests {
166166
root: "/".to_string(),
167167
token: None,
168168
endpoint: "https://huggingface.co".to_string(),
169+
max_retries: 3,
169170
#[cfg(feature = "xet")]
170171
xet_enabled: _xet,
171172
}

core/services/huggingface/src/writer.rs

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -132,29 +132,7 @@ impl HfWriter {
132132
})
133133
}
134134

135-
/// Upload file and commit based on determined mode.
136-
///
137-
/// Retries on commit conflicts (HTTP 412) and transient server errors
138-
/// (HTTP 5xx), matching the behavior of the official HuggingFace Hub
139-
/// client.
140135
async fn upload_and_commit(&self, body: Buffer) -> Result<Metadata> {
141-
const MAX_RETRIES: usize = 3;
142-
143-
let mut last_err = None;
144-
for _ in 0..MAX_RETRIES {
145-
match self.try_upload_and_commit(body.clone()).await {
146-
Ok(meta) => return Ok(meta),
147-
Err(err) if err.kind() == ErrorKind::ConditionNotMatch || err.is_temporary() => {
148-
last_err = Some(err);
149-
continue;
150-
}
151-
Err(err) => return Err(err),
152-
}
153-
}
154-
Err(last_err.unwrap())
155-
}
156-
157-
async fn try_upload_and_commit(&self, body: Buffer) -> Result<Metadata> {
158136
#[cfg_attr(not(feature = "xet"), allow(unused_variables))]
159137
let mode = Self::determine_upload_mode(&self.core, &self.path, &body).await?;
160138

@@ -216,6 +194,7 @@ mod tests {
216194
root: "/".to_string(),
217195
token: std::env::var("HF_OPENDAL_TOKEN").ok(),
218196
endpoint: "https://huggingface.co".to_string(),
197+
max_retries: 3,
219198
#[cfg(feature = "xet")]
220199
xet_enabled: _xet,
221200
}

0 commit comments

Comments
 (0)