Skip to content

Commit a4a02cf

Browse files
committed
feat(data): introduce a streaming client
1 parent 91f3efe commit a4a02cf

File tree

11 files changed

+605
-346
lines changed

11 files changed

+605
-346
lines changed

Cargo.lock

Lines changed: 2 additions & 23 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ debug = 1
3737
[workspace.dependencies]
3838
anyhow = "1"
3939
axum = "0.8"
40-
async-stream = "0.3"
4140
async-trait = "0.1"
4241
base64 = "0.22"
42+
futures-core = "0.3"
4343
bincode = "1.3"
4444
bitflags = { version = "2.9", features = ["serde"] }
4545
blake3 = "1.5"

data/Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ xet_runtime = { path = "../xet_runtime" }
4242

4343
futures = { workspace = true }
4444
anyhow = { workspace = true }
45-
async-stream = { workspace = true }
4645
async-trait = { workspace = true }
4746
bytes = { workspace = true }
4847
chrono = { workspace = true }

data/src/data_client.rs

Lines changed: 5 additions & 173 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,11 @@ use std::io::Read;
33
use std::path::{Path, PathBuf};
44
use std::sync::Arc;
55

6-
use async_stream::try_stream;
76
use bytes::Bytes;
87
use cas_client::remote_client::PREFIX_DEFAULT;
98
use cas_object::CompressionScheme;
10-
use cas_types::FileRange;
119
use deduplication::{Chunker, DeduplicationMetrics};
1210
use file_reconstruction::DataOutput;
13-
use futures::Stream;
1411
use lazy_static::lazy_static;
1512
use mdb_shard::Sha256;
1613
use merklehash::MerkleHash;
@@ -136,75 +133,6 @@ pub async fn upload_bytes_async(
136133
Ok(files)
137134
}
138135

139-
/// Downloads multiple files and returns their contents as byte streams.
140-
///
141-
/// Returns one stream per file. Each stream is lazy — the download starts only
142-
/// when the stream is first polled. The number of concurrent downloads is
143-
/// bounded by a global semaphore. `stream_buffer_size` controls how many chunks
144-
/// can be buffered in each stream before backpressure is applied.
145-
///
146-
/// **Error handling:** If a download fails after some chunks have already been
147-
/// yielded, the error surfaces only when the stream is polled past the last
148-
/// buffered chunk. Callers must consume the stream to completion (or until an
149-
/// error is returned) to detect download failures.
150-
#[allow(clippy::too_many_arguments)]
151-
#[instrument(skip_all, name = "data_client::download_bytes", fields(session_id = tracing::field::Empty, num_files=file_infos.len()))]
152-
pub async fn download_bytes_async(
153-
file_infos: Vec<XetFileInfo>,
154-
file_ranges: Option<Vec<Option<FileRange>>>,
155-
endpoint: Option<String>,
156-
token_info: Option<(String, u64)>,
157-
token_refresher: Option<Arc<dyn TokenRefresher>>,
158-
progress_updaters: Option<Vec<Arc<dyn TrackingProgressUpdater>>>,
159-
user_agent: String,
160-
stream_buffer_size: usize,
161-
) -> errors::Result<Vec<impl Stream<Item = errors::Result<Bytes>>>> {
162-
if let Some(updaters) = &progress_updaters
163-
&& updaters.len() != file_infos.len()
164-
{
165-
return Err(DataProcessingError::ParameterError("updaters are not same length as file_infos".to_string()));
166-
}
167-
if let Some(ranges) = &file_ranges
168-
&& ranges.len() != file_infos.len()
169-
{
170-
return Err(DataProcessingError::ParameterError("file_ranges are not same length as file_infos".to_string()));
171-
}
172-
173-
let config = default_config(
174-
endpoint.unwrap_or_else(|| xet_config().data.default_cas_endpoint.clone()),
175-
None,
176-
token_info,
177-
token_refresher,
178-
user_agent,
179-
)?;
180-
Span::current().record("session_id", &config.session_id);
181-
182-
let downloader = Arc::new(FileDownloader::new(config.into()).await?);
183-
let semaphore = XetRuntime::current().global_semaphore(*CONCURRENT_FILE_DOWNLOAD_LIMITER);
184-
let updaters = match progress_updaters {
185-
None => vec![None; file_infos.len()],
186-
Some(updaters) => updaters.into_iter().map(Some).collect(),
187-
};
188-
let ranges = match file_ranges {
189-
None => vec![None; file_infos.len()],
190-
Some(ranges) => ranges,
191-
};
192-
193-
let mut readers = Vec::with_capacity(updaters.len());
194-
for ((file_info, file_range), updater) in file_infos.into_iter().zip(ranges).zip(updaters) {
195-
readers.push(smudge_bytes(
196-
downloader.clone(),
197-
semaphore.clone(),
198-
file_info,
199-
file_range,
200-
updater,
201-
stream_buffer_size,
202-
)?);
203-
}
204-
205-
Ok(readers)
206-
}
207-
208136
// The sha256, if provided and valid, will be directly used in shard upload to avoid redundant computation.
209137
#[instrument(skip_all, name = "data_client::upload_files",
210138
fields(session_id = tracing::field::Empty,
@@ -480,50 +408,9 @@ async fn smudge_file(
480408
Ok(file_path.to_string())
481409
}
482410

483-
/// Downloads a file's content and returns it as a stream of byte chunks.
484-
///
485-
/// The download is lazy: nothing happens until the consumer starts polling the
486-
/// stream. On first poll, a `StreamingWriter` is created that sends resolved
487-
/// `Bytes` chunks directly through an async channel, avoiding the blocking
488-
/// thread and redundant copies of `SequentialWriter`. Backpressure is provided
489-
/// by the bounded channel buffer. A semaphore limits the number of concurrent
490-
/// downloads.
491-
fn smudge_bytes(
492-
downloader: Arc<FileDownloader>,
493-
semaphore: Arc<tokio::sync::Semaphore>,
494-
file_info: XetFileInfo,
495-
file_range: Option<FileRange>,
496-
progress_updater: Option<Arc<dyn TrackingProgressUpdater>>,
497-
stream_buffer_size: usize,
498-
) -> errors::Result<impl Stream<Item = errors::Result<Bytes>>> {
499-
let merkle_hash = file_info.merkle_hash()?;
500-
let file_hash = file_info.hash().into();
501-
502-
Ok(try_stream! {
503-
let progress_updater = progress_updater.map(ItemProgressUpdater::new);
504-
let (output, mut rx) = DataOutput::write_byte_stream(stream_buffer_size);
505-
506-
let handle = tokio::spawn(async move {
507-
let _permit = semaphore.acquire().await.map_err(|_| {
508-
DataProcessingError::InternalError("download semaphore closed".to_string())
509-
})?;
510-
downloader
511-
.smudge_file_from_hash(&merkle_hash, file_hash, output, file_range, progress_updater)
512-
.await
513-
});
514-
515-
while let Some(chunk) = rx.recv().await {
516-
yield chunk;
517-
}
518-
519-
handle.await??;
520-
})
521-
}
522-
523411
#[cfg(test)]
524412
mod tests {
525413
use dirs::home_dir;
526-
use futures::TryStreamExt;
527414
use serial_test::serial;
528415
use tempfile::tempdir;
529416
use utils::EnvVarGuard;
@@ -738,39 +625,16 @@ mod tests {
738625
assert_eq!(file_info1.file_size(), file_info3.file_size());
739626
}
740627

741-
fn test_contents() -> Vec<Vec<u8>> {
742-
vec![
743-
vec![],
744-
b"Hello, World!".to_vec(),
745-
(0..1_000_000).map(|i| (i % 256) as u8).collect(),
746-
]
747-
}
748-
749-
#[tokio::test]
750-
async fn test_bytes_round_trip() {
751-
let temp_dir = tempdir().unwrap();
752-
let endpoint = format!("local://{}", temp_dir.path().display());
753-
let contents = test_contents();
754-
755-
let file_infos = upload_bytes_async(contents.clone(), Some(endpoint.clone()), None, None, None, "test".into())
756-
.await
757-
.unwrap();
758-
759-
let readers = download_bytes_async(file_infos, None, Some(endpoint), None, None, None, "test".into(), 64)
760-
.await
761-
.unwrap();
762-
for (reader, expected) in readers.into_iter().zip(&contents) {
763-
let chunks: Vec<Bytes> = reader.try_collect().await.unwrap();
764-
assert_eq!(chunks.concat(), *expected);
765-
}
766-
}
767-
768628
#[tokio::test]
769629
async fn test_upload_bytes() {
770630
let temp_dir = tempdir().unwrap();
771631
let endpoint = format!("local://{}", temp_dir.path().display());
772632

773-
let contents = test_contents();
633+
let contents: Vec<Vec<u8>> = vec![
634+
vec![],
635+
b"Hello, World!".to_vec(),
636+
(0..1_000_000).map(|i| (i % 256) as u8).collect(),
637+
];
774638

775639
// Upload all as bytes.
776640
let file_infos = upload_bytes_async(contents.clone(), Some(endpoint.clone()), None, None, None, "test".into())
@@ -799,36 +663,4 @@ mod tests {
799663
assert_eq!(std::fs::read(path).unwrap(), *expected);
800664
}
801665
}
802-
803-
#[tokio::test]
804-
async fn test_download_bytes() {
805-
let temp_dir = tempdir().unwrap();
806-
let endpoint = format!("local://{}", temp_dir.path().display());
807-
808-
let contents = test_contents();
809-
810-
// Write contents to files and upload via upload_async.
811-
let upload_dir = tempdir().unwrap();
812-
let file_paths: Vec<String> = contents
813-
.iter()
814-
.enumerate()
815-
.map(|(i, content)| {
816-
let path = upload_dir.path().join(format!("{i}"));
817-
std::fs::write(&path, content).unwrap();
818-
path.to_str().unwrap().to_string()
819-
})
820-
.collect();
821-
let file_infos = upload_async(file_paths, None, Some(endpoint.clone()), None, None, None, "test".into())
822-
.await
823-
.unwrap();
824-
825-
// Download all as bytes and verify.
826-
let readers = download_bytes_async(file_infos, None, Some(endpoint), None, None, None, "test".into(), 64)
827-
.await
828-
.unwrap();
829-
for (reader, expected) in readers.into_iter().zip(&contents) {
830-
let chunks: Vec<Bytes> = reader.try_collect().await.unwrap();
831-
assert_eq!(chunks.concat(), *expected);
832-
}
833-
}
834666
}

data/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ mod prometheus_metrics;
1010
mod remote_client_interface;
1111
mod sha256;
1212
mod shard_interface;
13+
pub mod streaming;
1314
mod xet_file;
1415

1516
// Reexport this one for now

0 commit comments

Comments
 (0)