Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ sha2 = "0.10"
shell-words = "1.1"
shellexpand = "3.1"
static_assertions = "1.1"
strum = { version = "0.26", features = ["derive"] }
statrs = "0.16"
sysinfo = "0.37"
tempfile = "3.20"
Expand Down
2 changes: 1 addition & 1 deletion cas_client/src/client_testing_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ pub trait ClientTestingUtils: Client + Send + Sync {

shard.add_cas_block(raw_xorb.cas_info.clone())?;

let serialized_xorb = SerializedCasObject::from_xorb(raw_xorb.clone(), None, true)?;
let serialized_xorb = SerializedCasObject::from_xorb(raw_xorb.clone(), true)?;

let upload_permit = self.acquire_upload_permit().await?;
self.upload_xorb("default", serialized_xorb, None, upload_permit).await?;
Expand Down
6 changes: 3 additions & 3 deletions cas_client/src/download_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,7 @@ mod tests {
then.status(200).json_body_obj(&response);
});

let client: Arc<dyn Client + Send + Sync> = RemoteClient::new(&server.base_url(), &None, &None, "", false, "");
let client: Arc<dyn Client + Send + Sync> = RemoteClient::new(&server.base_url(), &None, None, "", false, "");
let fetch_info = FetchInfo::new(MerkleHash::default(), file_range);

fetch_info.query(&client).await?;
Expand Down Expand Up @@ -667,7 +667,7 @@ mod tests {
then.status(200).json_body_obj(&response);
});

let client: Arc<dyn Client + Send + Sync> = RemoteClient::new(&server.base_url(), &None, &None, "", false, "");
let client: Arc<dyn Client + Send + Sync> = RemoteClient::new(&server.base_url(), &None, None, "", false, "");
let fetch_info = Arc::new(FetchInfo::new(MerkleHash::default(), file_range_to_refresh));

// Spawn multiple tasks each calling into refresh with a different delay in
Expand Down Expand Up @@ -733,7 +733,7 @@ mod tests {
then.status(403).delay(Duration::from_millis(100));
});

let client: Arc<dyn Client + Send + Sync> = RemoteClient::new(&server.base_url(), &None, &None, "", false, "");
let client: Arc<dyn Client + Send + Sync> = RemoteClient::new(&server.base_url(), &None, None, "", false, "");

let fetch_info = FetchInfo::new(MerkleHash::default(), file_range);

Expand Down
1 change: 0 additions & 1 deletion cas_client/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
pub use chunk_cache::CacheConfig;
pub use http_client::{Api, ResponseErrorLogger, RetryConfig, build_auth_http_client, build_http_client};
pub use interface::Client;
#[cfg(not(target_family = "wasm"))]
Expand Down
2 changes: 1 addition & 1 deletion cas_client/src/local_server/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ impl LocalTestServer {

tokio::time::sleep(Duration::from_millis(50)).await;

let remote_client = RemoteClient::new(&endpoint, &None, &None, "test-session", false, "test-agent");
let remote_client = RemoteClient::new(&endpoint, &None, None, "test-session", false, "test-agent");

Self {
endpoint,
Expand Down
22 changes: 12 additions & 10 deletions cas_client/src/remote_client.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::collections::HashMap;
use std::mem::take;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};

Expand All @@ -10,7 +11,7 @@ use cas_types::{
BatchQueryReconstructionResponse, CASReconstructionFetchInfo, CASReconstructionTerm, ChunkRange, FileRange,
HttpRange, Key, QueryReconstructionResponse, UploadShardResponse, UploadShardResponseType, UploadXorbResponse,
};
use chunk_cache::{CacheConfig, ChunkCache};
use chunk_cache::ChunkCache;
use error_printer::ErrorPrinter;
use http::HeaderValue;
use http::header::{CONTENT_LENGTH, RANGE};
Expand Down Expand Up @@ -184,19 +185,20 @@ impl RemoteClient {
pub fn new(
endpoint: &str,
auth: &Option<AuthConfig>,
cache_config: &Option<CacheConfig>,
cache_directory: Option<&PathBuf>,
session_id: &str,
dry_run: bool,
user_agent: &str,
) -> Arc<Self> {
// use disk cache if cache_config provided.
let chunk_cache = if let Some(cache_config) = cache_config {
if cache_config.cache_size == 0 {
// use disk cache if cache directory provided and cache size > 0.
let chunk_cache = if let Some(cache_dir) = cache_directory {
let cache_size = xet_config().chunk_cache.size_bytes;
if cache_size == 0 {
event!(INFORMATION_LOG_LEVEL, "Chunk cache size set to 0, disabling chunk cache");
None
} else {
event!(INFORMATION_LOG_LEVEL, cache.dir=?cache_config.cache_directory, cache.size=cache_config.cache_size,"Using disk cache");
chunk_cache::get_cache(cache_config)
event!(INFORMATION_LOG_LEVEL, cache.dir=?cache_dir, cache.size=cache_size, "Using disk cache");
chunk_cache::get_cache(cache_dir)
.log_error("failed to initialize cache, not using cache")
.ok()
}
Expand Down Expand Up @@ -1009,7 +1011,7 @@ mod tests {
let raw_xorb = build_raw_xorb(3, ChunkSize::Random(512, 10248));

let threadpool = XetRuntime::new().unwrap();
let client = RemoteClient::new(CAS_ENDPOINT, &None, &None, "", false, "");
let client = RemoteClient::new(CAS_ENDPOINT, &None, None, "", false, "");

let cas_object = build_and_verify_cas_object(raw_xorb, Some(CompressionScheme::LZ4));

Expand Down Expand Up @@ -1392,7 +1394,7 @@ mod tests {

// test reconstruct and sequential write
let test = test_case.clone();
let client = RemoteClient::new(endpoint, &None, &None, "", false, "");
let client = RemoteClient::new(endpoint, &None, None, "", false, "");
let buf = ThreadSafeBuffer::default();
let provider = SequentialOutput::from(buf.clone());
let resp = threadpool.external_run_async_task(async move {
Expand All @@ -1414,7 +1416,7 @@ mod tests {

// test reconstruct and parallel write
let test = test_case;
let client = RemoteClient::new(endpoint, &None, &None, "", false, "");
let client = RemoteClient::new(endpoint, &None, None, "", false, "");
let buf = ThreadSafeBuffer::default();
let provider = SeekingOutputProvider::from(buf.clone());
let resp = threadpool.external_run_async_task(async move {
Expand Down
1 change: 1 addition & 0 deletions cas_object/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ lz4_flex = { workspace = true }
more-asserts = { workspace = true }
rand = { workspace = true }
serde = { workspace = true }
strum = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true, features = ["time", "rt", "macros", "io-util"] }
tracing = { workspace = true }
Expand Down
15 changes: 13 additions & 2 deletions cas_object/src/cas_object_format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1296,7 +1296,17 @@ pub struct SerializedCasObject {

impl SerializedCasObject {
/// Builds the xorb from raw xorb data.
pub fn from_xorb(
/// Compression is selected from the configured xorb compression policy, read here from
/// `xet_config().xorb.compression_policy` (per the existing note this is driven by the
/// `HF_XET_XORB_COMPRESSION_POLICY` environment variable — NOTE(review): confirm the mapping).
///
/// The policy string is parsed with `CompressionScheme::parse_optional`; an empty value or
/// "auto" yields `None`, in which case the best compression is chosen based on data analysis.
/// The parsed scheme, `xorb` and `serialize_footer` are then forwarded unchanged to
/// `from_xorb_with_compression`, whose result is returned as-is.
///
/// Not compiled for wasm targets.
#[cfg(not(target_family = "wasm"))]
pub fn from_xorb(xorb: RawXorbData, serialize_footer: bool) -> Result<Self, CasObjectError> {
// Resolve the configured policy first; `None` means "let the serializer pick".
let compression_scheme = CompressionScheme::parse_optional(&xet_config().xorb.compression_policy);
Self::from_xorb_with_compression(xorb, compression_scheme, serialize_footer)
}

/// Builds the xorb from raw xorb data with an explicit compression scheme.
/// If compression_scheme is None, the best compression is chosen based on data analysis.
pub fn from_xorb_with_compression(
xorb: RawXorbData,
compression_scheme: Option<CompressionScheme>,
serialize_footer: bool,
Expand Down Expand Up @@ -1561,7 +1571,8 @@ pub mod test_utils {
xorb: RawXorbData,
compression_scheme: Option<CompressionScheme>,
) -> SerializedCasObject {
let cas_object = SerializedCasObject::from_xorb(xorb.clone(), compression_scheme, true).unwrap();
let cas_object =
SerializedCasObject::from_xorb_with_compression(xorb.clone(), compression_scheme, true).unwrap();

verify_serialized_cas_object(&xorb, compression_scheme, &cas_object);

Expand Down
60 changes: 55 additions & 5 deletions cas_object/src/compression_scheme.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use std::borrow::Cow;
use std::fmt::Display;
use std::io::{Cursor, Read, Write, copy};
use std::str::FromStr;
use std::time::Instant;

use anyhow::anyhow;
use lz4_flex::frame::{FrameDecoder, FrameEncoder};
use strum::{Display, EnumString};

use crate::byte_grouping::BG4Predictor;
use crate::byte_grouping::bg4::{bg4_regroup, bg4_split};
Expand All @@ -15,22 +16,47 @@ pub static mut BG4_REGROUP_RUNTIME: f64 = 0.;
pub static mut BG4_LZ4_COMPRESS_RUNTIME: f64 = 0.;
pub static mut BG4_LZ4_DECOMPRESS_RUNTIME: f64 = 0.;

/// Compression schemes for xorb data.
///
/// Discriminant values 65-90 (the ASCII capital letters) are deliberately disallowed as valid
/// `CompressionScheme` values.
///
/// String conversions are derived with `strum`: parsing (`EnumString`) is ASCII
/// case-insensitive, and `Display` emits the canonical lowercase name of each variant
/// ("none", "lz4", "bg4-lz4").
#[repr(u8)]
#[derive(Debug, PartialEq, Eq, Clone, Copy, Default, Display, EnumString)]
#[strum(serialize_all = "lowercase", ascii_case_insensitive)]
pub enum CompressionScheme {
    /// No compression; this is the default variant.
    #[default]
    #[strum(serialize = "none")]
    None = 0,

    /// LZ4 compression.
    #[strum(serialize = "lz4")]
    LZ4 = 1,

    /// Byte grouping (4 byte groups) followed by LZ4.
    // `to_string` pins the canonical `Display` text. Without it strum derives the text from
    // the longest `serialize` alias, and "bg4-lz4" / "bg4_lz4" tie in length, so the printed
    // form would depend on tie-breaking rather than on an explicit choice. The `to_string`
    // value is still accepted when parsing, so all three spellings keep working.
    #[strum(to_string = "bg4-lz4", serialize = "bg4_lz4", serialize = "bg4lz4")]
    ByteGrouping4LZ4 = 2, // 4 byte groups
}
pub const NUM_COMPRESSION_SCHEMES: usize = 3;

impl Display for CompressionScheme {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", Into::<&str>::into(self))
impl CompressionScheme {
/// Interprets a textual compression policy.
///
/// Surrounding whitespace is ignored. The result is:
/// - `None` when the text is empty/blank or equals "auto" (ASCII case-insensitive),
///   meaning the scheme should be auto-detected;
/// - `Some(scheme)` when the text is a recognised scheme name;
/// - `None`, after logging a warning, when the text is not a recognised scheme name.
pub fn parse_optional(s: &str) -> Option<Self> {
    let name = s.trim();

    // Blank and "auto" both request auto-detection and are not errors.
    let wants_auto = name.is_empty() || name.eq_ignore_ascii_case("auto");
    if wants_auto {
        return None;
    }

    if let Ok(scheme) = Self::from_str(name) {
        return Some(scheme);
    }

    // Unknown names are tolerated: warn and fall back to auto-detection.
    tracing::warn!(
        "Invalid compression scheme '{}'; using auto-detection. Valid values: none, lz4, bg4-lz4",
        name
    );
    None
}
}

impl From<&CompressionScheme> for &'static str {
fn from(value: &CompressionScheme) -> Self {
match value {
Expand Down Expand Up @@ -163,6 +189,7 @@ fn bg4_lz4_decompress_from_reader<R: Read, W: Write>(reader: &mut R, writer: &mu
#[cfg(test)]
mod tests {
use std::mem::size_of;
use std::str::FromStr;

use half::prelude::*;
use rand::Rng;
Expand All @@ -176,6 +203,29 @@ mod tests {
assert_eq!(Into::<&str>::into(CompressionScheme::ByteGrouping4LZ4), "bg4-lz4");
}

#[test]
fn test_from_str() {
    // Every accepted spelling, paired with the variant it must parse to.
    let accepted = [
        ("none", CompressionScheme::None),
        ("lz4", CompressionScheme::LZ4),
        ("LZ4", CompressionScheme::LZ4),
        ("bg4-lz4", CompressionScheme::ByteGrouping4LZ4),
        ("bg4_lz4", CompressionScheme::ByteGrouping4LZ4),
        ("BG4-LZ4", CompressionScheme::ByteGrouping4LZ4),
    ];
    for (text, expected) in accepted {
        assert_eq!(CompressionScheme::from_str(text), Ok(expected));
    }

    // Unknown names must be rejected rather than mapped to some variant.
    assert!(CompressionScheme::from_str("invalid").is_err());
}

#[test]
fn test_parse_optional() {
    // Input text paired with the expected outcome; `None` means auto-detection.
    let cases = [
        ("", None),
        ("auto", None),
        ("AUTO", None),
        (" ", None),
        ("lz4", Some(CompressionScheme::LZ4)),
        ("none", Some(CompressionScheme::None)),
        ("bg4-lz4", Some(CompressionScheme::ByteGrouping4LZ4)),
        // Unrecognised names also fall back to auto-detection.
        ("invalid", None),
    ];
    for (text, expected) in cases {
        assert_eq!(CompressionScheme::parse_optional(text), expected);
    }
}

#[test]
fn test_from_u8() {
assert_eq!(CompressionScheme::try_from(0u8), Ok(CompressionScheme::None));
Expand Down
8 changes: 2 additions & 6 deletions chunk_cache/src/bin/analysis.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::path::PathBuf;
use std::u64;

use chunk_cache::{CacheConfig, DiskCache};
use chunk_cache::DiskCache;
use clap::Parser;

#[derive(Debug, Parser)]
Expand All @@ -18,10 +18,6 @@ fn main() {
}

fn print_main(root: PathBuf) {
let cache = DiskCache::initialize(&CacheConfig {
cache_directory: root,
cache_size: u64::MAX,
})
.unwrap();
let cache = DiskCache::initialize_with_capacity(&root, u64::MAX).unwrap();
cache.print();
}
Loading
Loading