Skip to content

Commit 082ac8e

Browse files
committed
feat(hf): support multiple upload modes and more thorough testing
1 parent b81bb38 commit 082ac8e

File tree

8 files changed

+1258
-806
lines changed

8 files changed

+1258
-806
lines changed

core/services/huggingface/Cargo.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ xet = [
3737
"dep:xet-data",
3838
"dep:cas_types",
3939
"dep:xet-utils",
40-
"dep:tokio",
40+
"tokio/sync",
41+
"tokio/rt",
4142
"dep:futures",
4243
"dep:async-trait",
4344
]
@@ -53,6 +54,7 @@ serde = { workspace = true, features = ["derive"] }
5354
serde_json = { workspace = true }
5455
sha2 = "0.10"
5556
tempfile = "3"
57+
tokio = { workspace = true, features = ["time"] }
5658

5759
# XET storage protocol support (optional)
5860
async-trait = { version = "0.1", optional = true }
@@ -61,7 +63,6 @@ futures = { workspace = true, optional = true }
6163
reqwest = { version = "0.12", default-features = false, features = [
6264
"rustls-tls",
6365
], optional = true }
64-
tokio = { workspace = true, features = ["sync", "rt"], optional = true }
6566
xet-data = { package = "data", git = "https://github.com/kszucs/xet-core", branch = "download_bytes", optional = true }
6667
xet-utils = { package = "utils", git = "https://github.com/kszucs/xet-core", branch = "download_bytes", optional = true }
6768

@@ -70,5 +71,6 @@ futures = { workspace = true }
7071
opendal-core = { path = "../../core", version = "0.55.0", features = [
7172
"reqwest-rustls-tls",
7273
] }
74+
reqwest = { version = "0.12", default-features = false, features = ["rustls-tls"] }
7375
serde_json = { workspace = true }
7476
tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }

core/services/huggingface/src/backend.rs

Lines changed: 217 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,17 @@ impl HfBuilder {
122122

123123
/// Enable XET storage protocol for reads.
124124
///
125-
/// When true and the `xet` feature is compiled in, reads will
126-
/// check for XET-backed files and use the XET protocol for
127-
/// downloading. Default is false.
128-
pub fn xet(mut self, xet: bool) -> Self {
129-
self.config.xet = xet;
125+
/// When the `xet` feature is compiled in, reads will check for
126+
/// XET-backed files and use the XET protocol for downloading.
127+
/// Default is disabled.
128+
pub fn enable_xet(mut self) -> Self {
129+
self.config.xet = true;
130+
self
131+
}
132+
133+
/// Disable XET storage protocol for reads.
134+
pub fn disable_xet(mut self) -> Self {
135+
self.config.xet = false;
130136
self
131137
}
132138

@@ -143,7 +149,6 @@ impl HfBuilder {
143149
impl Builder for HfBuilder {
144150
type Config = HfConfig;
145151

146-
/// Build a HfBackend.
147152
fn build(self) -> Result<impl Access> {
148153
debug!("backend build started: {:?}", &self);
149154

@@ -200,25 +205,31 @@ impl Builder for HfBuilder {
200205
am.into()
201206
};
202207

208+
let repo = HfRepo::new(repo_type, repo_id, Some(revision.clone()));
209+
debug!("backend repo uri: {:?}", repo.uri(&root, ""));
210+
211+
let max_retries = self.config.max_retries.unwrap_or(3);
212+
debug!("backend max_retries: {}", max_retries);
213+
203214
Ok(HfBackend {
204-
core: Arc::new(HfCore {
215+
core: Arc::new(HfCore::new(
205216
info,
206-
repo: HfRepo::new(repo_type, repo_id, Some(revision)),
217+
repo,
207218
root,
208219
token,
209220
endpoint,
210-
max_retries: self.config.max_retries.unwrap_or(3),
221+
max_retries,
211222
#[cfg(feature = "xet")]
212-
xet_enabled: self.config.xet,
213-
}),
223+
self.config.xet,
224+
)?),
214225
})
215226
}
216227
}
217228

218229
/// Backend for Huggingface service
219230
#[derive(Debug, Clone)]
220231
pub struct HfBackend {
221-
core: Arc<HfCore>,
232+
pub(crate) core: Arc<HfCore>,
222233
}
223234

224235
impl Access for HfBackend {
@@ -258,16 +269,108 @@ impl Access for HfBackend {
258269

259270
async fn delete(&self) -> Result<(RpDelete, Self::Deleter)> {
260271
let deleter = HfDeleter::new(self.core.clone());
261-
let delete_max_size = self.core.info.full_capability().delete_max_size;
272+
let max_batch_size = self.core.info.full_capability().delete_max_size;
262273
Ok((
263274
RpDelete::default(),
264-
oio::BatchDeleter::new(deleter, delete_max_size),
275+
oio::BatchDeleter::new(deleter, max_batch_size),
265276
))
266277
}
267278
}
268279

280+
#[cfg(test)]
281+
pub(super) mod test_utils {
282+
use super::HfBuilder;
283+
use opendal_core::Operator;
284+
use opendal_core::layers::HttpClientLayer;
285+
use opendal_core::raw::HttpClient;
286+
287+
/// Layer a fresh HTTP client onto the operator so parallel tests
288+
/// don't share the global static reqwest client (which causes
289+
/// "dispatch task is gone" errors when runtimes are dropped).
290+
fn finish_operator(op: Operator) -> Operator {
291+
let client = HttpClient::with(reqwest::Client::new());
292+
op.layer(HttpClientLayer::new(client))
293+
}
294+
295+
pub fn testing_credentials() -> (String, String) {
296+
let repo_id = std::env::var("HF_OPENDAL_DATASET").expect("HF_OPENDAL_DATASET must be set");
297+
let token = std::env::var("HF_OPENDAL_TOKEN").expect("HF_OPENDAL_TOKEN must be set");
298+
(repo_id, token)
299+
}
300+
301+
/// Operator for a private dataset; requires the HF_OPENDAL_DATASET and HF_OPENDAL_TOKEN environment variables to be set.
302+
/// Uses a higher max_retries (10 rather than the default of 3) to tolerate concurrent commit conflicts (412).
303+
pub fn testing_operator() -> Operator {
304+
let (repo_id, token) = testing_credentials();
305+
let op = Operator::new(
306+
HfBuilder::default()
307+
.repo_type("dataset")
308+
.repo_id(&repo_id)
309+
.token(&token)
310+
.max_retries(10),
311+
)
312+
.unwrap()
313+
.finish();
314+
finish_operator(op)
315+
}
316+
317+
#[cfg(feature = "xet")]
318+
pub fn testing_xet_operator() -> Operator {
319+
let (repo_id, token) = testing_credentials();
320+
let op = Operator::new(
321+
HfBuilder::default()
322+
.repo_type("dataset")
323+
.repo_id(&repo_id)
324+
.token(&token)
325+
.enable_xet()
326+
.max_retries(10),
327+
)
328+
.unwrap()
329+
.finish();
330+
finish_operator(op)
331+
}
332+
333+
pub fn gpt2_operator() -> Operator {
334+
let op = Operator::new(
335+
HfBuilder::default()
336+
.repo_type("model")
337+
.repo_id("openai-community/gpt2"),
338+
)
339+
.unwrap()
340+
.finish();
341+
finish_operator(op)
342+
}
343+
344+
pub fn mbpp_operator() -> Operator {
345+
let op = Operator::new(
346+
HfBuilder::default()
347+
.repo_type("dataset")
348+
.repo_id("google-research-datasets/mbpp"),
349+
)
350+
.unwrap()
351+
.finish();
352+
finish_operator(op)
353+
}
354+
355+
#[cfg(feature = "xet")]
356+
pub fn mbpp_xet_operator() -> Operator {
357+
let mut builder = HfBuilder::default()
358+
.repo_type("dataset")
359+
.repo_id("google-research-datasets/mbpp")
360+
.enable_xet();
361+
if let Ok(token) = std::env::var("HF_OPENDAL_TOKEN") {
362+
builder = builder.token(&token);
363+
}
364+
let op = Operator::new(builder).unwrap().finish();
365+
finish_operator(op)
366+
}
367+
}
368+
269369
#[cfg(test)]
270370
mod tests {
371+
use super::test_utils::mbpp_operator;
372+
#[cfg(feature = "xet")]
373+
use super::test_utils::mbpp_xet_operator;
271374
use super::*;
272375

273376
#[test]
@@ -311,16 +414,6 @@ mod tests {
311414
/// Parquet magic bytes: "PAR1"
312415
const PARQUET_MAGIC: &[u8] = b"PAR1";
313416

314-
fn mbpp_operator() -> Operator {
315-
let builder = HfBuilder::default()
316-
.repo_type("dataset")
317-
.repo_id("google-research-datasets/mbpp")
318-
.revision("main")
319-
.root("/");
320-
321-
Operator::new(builder).unwrap().finish()
322-
}
323-
324417
#[tokio::test]
325418
#[ignore = "requires network access"]
326419
async fn test_read_parquet_http() {
@@ -348,29 +441,11 @@ mod tests {
348441
assert_eq!(&footer.to_vec(), PARQUET_MAGIC);
349442
}
350443

351-
#[cfg(feature = "xet")]
352-
fn mbpp_operator_xet() -> Operator {
353-
let repo_id = std::env::var("HF_OPENDAL_DATASET")
354-
.unwrap_or_else(|_| "google-research-datasets/mbpp".to_string());
355-
let mut builder = HfBuilder::default()
356-
.repo_type("dataset")
357-
.repo_id(&repo_id)
358-
.revision("main")
359-
.root("/")
360-
.xet(true);
361-
362-
if let Ok(token) = std::env::var("HF_OPENDAL_TOKEN") {
363-
builder = builder.token(&token);
364-
}
365-
366-
Operator::new(builder).unwrap().finish()
367-
}
368-
369444
#[cfg(feature = "xet")]
370445
#[tokio::test]
371446
#[ignore = "requires network access"]
372447
async fn test_read_parquet_xet() {
373-
let op = mbpp_operator_xet();
448+
let op = mbpp_xet_operator();
374449
let path = "full/train-00000-of-00001.parquet";
375450

376451
// Full read via XET and verify parquet magic at both ends
@@ -380,4 +455,103 @@ mod tests {
380455
assert_eq!(&bytes[..4], PARQUET_MAGIC);
381456
assert_eq!(&bytes[bytes.len() - 4..], PARQUET_MAGIC);
382457
}
458+
459+
/// List files in a known dataset directory.
460+
#[tokio::test]
461+
#[ignore = "requires network access"]
462+
async fn test_list_directory() {
463+
let op = mbpp_operator();
464+
let entries = op.list("full/").await.expect("list should succeed");
465+
assert!(!entries.is_empty(), "directory should contain files");
466+
assert!(
467+
entries.iter().any(|e| e.path().ends_with(".parquet")),
468+
"should contain parquet files"
469+
);
470+
}
471+
472+
/// List files recursively from root.
473+
#[tokio::test]
474+
#[ignore = "requires network access"]
475+
async fn test_list_recursive() {
476+
let op = mbpp_operator();
477+
let entries = op
478+
.list_with("/")
479+
.recursive(true)
480+
.await
481+
.expect("recursive list should succeed");
482+
assert!(
483+
entries.len() > 1,
484+
"recursive listing should find multiple files"
485+
);
486+
}
487+
488+
/// Stat a known file and verify metadata fields.
489+
#[tokio::test]
490+
#[ignore = "requires network access"]
491+
async fn test_stat_known_file() {
492+
let op = mbpp_operator();
493+
let meta = op
494+
.stat("full/train-00000-of-00001.parquet")
495+
.await
496+
.expect("stat should succeed");
497+
assert!(meta.content_length() > 0);
498+
assert!(!meta.etag().unwrap_or_default().is_empty());
499+
}
500+
501+
/// Calling stat on a nonexistent path should return NotFound.
502+
#[tokio::test]
503+
#[ignore = "requires network access"]
504+
async fn test_stat_nonexistent() {
505+
let op = mbpp_operator();
506+
let err = op
507+
.stat("this/path/does/not/exist.txt")
508+
.await
509+
.expect_err("stat on nonexistent path should fail");
510+
assert_eq!(err.kind(), ErrorKind::NotFound);
511+
}
512+
513+
/// Reading a nonexistent file should return NotFound.
514+
#[tokio::test]
515+
#[ignore = "requires network access"]
516+
async fn test_read_nonexistent() {
517+
let op = mbpp_operator();
518+
let err = op
519+
.read("this/path/does/not/exist.txt")
520+
.await
521+
.expect_err("read on nonexistent path should fail");
522+
assert_eq!(err.kind(), ErrorKind::NotFound);
523+
}
524+
525+
/// Read a middle range of a known file.
526+
#[tokio::test]
527+
#[ignore = "requires network access"]
528+
async fn test_read_range_middle() {
529+
let op = mbpp_operator();
530+
let data = op
531+
.read_with("full/train-00000-of-00001.parquet")
532+
.range(100..200)
533+
.await
534+
.expect("range read should succeed");
535+
assert_eq!(data.to_bytes().len(), 100);
536+
}
537+
538+
/// Read the last N bytes of a file to exercise tail-range handling.
539+
#[tokio::test]
540+
#[ignore = "requires network access"]
541+
async fn test_read_range_tail() {
542+
let op = mbpp_operator();
543+
let path = "full/train-00000-of-00001.parquet";
544+
let meta = op.stat(path).await.expect("stat should succeed");
545+
let size = meta.content_length();
546+
547+
let data = op
548+
.read_with(path)
549+
.range(size - 100..size)
550+
.await
551+
.expect("tail range read should succeed");
552+
let bytes = data.to_bytes();
553+
assert_eq!(bytes.len(), 100);
554+
// Parquet files end with "PAR1" magic
555+
assert_eq!(&bytes[bytes.len() - 4..], PARQUET_MAGIC);
556+
}
383557
}

0 commit comments

Comments
 (0)