-
Notifications
You must be signed in to change notification settings - Fork 65
Expand file tree
/
Copy pathxtool.rs
More file actions
292 lines (259 loc) · 9.96 KB
/
xtool.rs
File metadata and controls
292 lines (259 loc) · 9.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
use std::fs::File;
use std::io::{BufWriter, Write};
use std::path::PathBuf;
use std::sync::Arc;
use anyhow::Result;
use cas_client::RemoteClient;
use cas_object::CompressionScheme;
use cas_types::{FileRange, QueryReconstructionResponse};
use clap::{Args, Parser, Subcommand};
use data::data_client::default_config;
use data::migration_tool::hub_client_token_refresher::HubClientTokenRefresher;
use data::migration_tool::migrate::migrate_files_impl;
use hub_client::{BearerCredentialHelper, HubClient, Operation, RepoInfo};
use humansize::{BINARY, DECIMAL, format_size};
use merklehash::MerkleHash;
use utils::auth::TokenRefresher;
use walkdir::WalkDir;
use xet_runtime::XetRuntime;
// Fallback Hub endpoint used when neither `--endpoint` nor env `HF_ENDPOINT` is given.
const DEFAULT_HF_ENDPOINT: &str = "https://huggingface.co";
// User-Agent header value sent on Hub/CAS requests, e.g. "xtool/1.2.3".
const USER_AGENT: &str = concat!("xtool", "/", env!("CARGO_PKG_VERSION"));
// Top-level CLI: global connection/repo overrides plus one subcommand.
// (Plain `//` comments are used deliberately: `///` doc comments on clap
// derive items become user-visible help text.)
#[derive(Parser)]
struct XCommand {
    // Options shared by every subcommand (endpoint, token, repo identity).
    #[clap(flatten)]
    overrides: CliOverrides,
    // The action to perform: dedup / query / compressed-size.
    #[clap(subcommand)]
    command: Command,
}
// Connection and repository settings shared by all subcommands.
// Resolution order for endpoint/token is: CLI flag, then env var, then default
// (see `XCommand::run`).
#[derive(Args)]
struct CliOverrides {
    /// HF Hub endpoint.
    #[clap(long)]
    endpoint: Option<String>, // if not specified we use env:HF_ENDPOINT
    /// HF Hub access token.
    #[clap(long)]
    token: Option<String>, // if not specified we use env:HF_TOKEN
    /// Type of the associated repo: "model", "dataset", or "space"
    #[clap(long)]
    repo_type: String,
    /// A namespace and a repo name separated by a '/'.
    #[clap(long)]
    repo_id: String,
}
impl XCommand {
    // Resolve endpoint/token (CLI flag > env var > default), build an
    // authenticated `HubClient` pinned to branch "main", and dispatch to the
    // selected subcommand.
    async fn run(self) -> Result<()> {
        // `unwrap_or_else` keeps the default-string allocation lazy: it only
        // happens when neither the flag nor the env var is present
        // (the eager `unwrap_or(DEFAULT_HF_ENDPOINT.to_owned())` allocated
        // even when HF_ENDPOINT was set — clippy `or_fun_call`).
        let endpoint = self.overrides.endpoint.unwrap_or_else(|| {
            std::env::var("HF_ENDPOINT").unwrap_or_else(|_| DEFAULT_HF_ENDPOINT.to_owned())
        });
        // Missing token resolves to "" — anonymous access.
        let token = self
            .overrides
            .token
            .unwrap_or_else(|| std::env::var("HF_TOKEN").unwrap_or_default());
        let cred_helper = BearerCredentialHelper::new(token, "");
        let hub_client = HubClient::new(
            &endpoint,
            RepoInfo::try_from(&self.overrides.repo_type, &self.overrides.repo_id)?,
            Some("main".to_owned()),
            USER_AGENT,
            "",
            cred_helper,
        )?;
        self.command.run(hub_client).await
    }
}
// The subcommands exposed by xtool; each variant carries its own arg struct.
// Field-level `///` docs below are clap help strings — kept verbatim.
#[derive(Subcommand)]
enum Command {
    /// Dry-run of file upload to get file info after dedup.
    Dedup(DedupArg),
    /// Queries reconstruction information about a file.
    Query(QueryArg),
    /// Calculates the compressed size of a xet-file by summing url_range sizes.
    CompressedSize(CompressedSizeArg),
}
// Arguments for the `dedup` subcommand (dry-run dedup analysis, or real
// upload when `--migrate` is passed).
#[derive(Args)]
struct DedupArg {
    /// Path to the file to dedup.
    files: Vec<String>,
    /// If the paths specified are directories, compute recursively for files
    /// under these directories.
    #[clap(short, long)]
    recursive: bool,
    /// Compute for files sequentially in the order as specified, or as enumerated
    /// from directory walking if in recursive mode. This can be helpful to study
    /// a set of files where there is a temporal relation.
    #[clap(short, long)]
    sequential: bool,
    /// If a file path is specified, write out the JSON formatted file reconstruction info
    /// to the file; otherwise write out to the stdout.
    #[clap(short, long)]
    output: Option<PathBuf>,
    /// The compression scheme to use on XORB upload. Choices are
    /// 0: no compression;
    /// 1: LZ4 compression;
    /// 2: 4 byte groups with LZ4 compression.
    /// If not specified, this will be determined by the repo type.
    #[clap(short, long)]
    compression: Option<u8>,
    /// Migrate the files by actually uploading them to the CAS server.
    #[clap(short, long)]
    migrate: bool,
}
// Arguments for the `query` subcommand.
#[derive(Args)]
struct QueryArg {
    /// Xet-hash of a file.
    hash: String,
    /// Query regarding a certain range in bytes: [start, end), specified
    /// in the format of "start-end".
    bytes_range: Option<FileRange>,
}
// Arguments for the `compressed-size` subcommand.
#[derive(Args)]
struct CompressedSizeArg {
    /// Xet-hash of a file.
    hash: String,
}
impl Command {
    // Execute the chosen subcommand against the authenticated Hub client.
    async fn run(self, hub_client: HubClient) -> Result<()> {
        match self {
            Command::Dedup(arg) => {
                let file_paths = walk_files(arg.files, arg.recursive);
                eprintln!("Dedupping {} files...", file_paths.len());
                let (all_file_info, clean_ret, total_bytes_trans) = migrate_files_impl(
                    file_paths,
                    arg.sequential,
                    hub_client,
                    None,
                    // NOTE: an invalid --compression value is silently dropped here and
                    // the repo-type default is used instead.
                    arg.compression.and_then(|c| CompressionScheme::try_from(c).ok()),
                    // `migrate_files_impl` takes a dry-run flag, so invert --migrate.
                    !arg.migrate,
                )
                .await?;
                // In dry-run mode, dump the JSON reconstruction info for offline analysis,
                // either to --output or to stdout.
                if !arg.migrate {
                    let mut writer: Box<dyn Write> = if let Some(path) = arg.output {
                        Box::new(BufWriter::new(File::options().create(true).write(true).truncate(true).open(path)?))
                    } else {
                        Box::new(std::io::stdout())
                    };
                    serde_json::to_writer(&mut writer, &all_file_info)?;
                    writer.flush()?;
                }
                eprintln!("\n\nClean results:");
                for (xf, new_bytes) in clean_ret {
                    println!("{}: {} bytes -> {} bytes", xf.hash(), xf.file_size(), new_bytes);
                }
                eprintln!("Transmitted {total_bytes_trans} bytes in total.");
                Ok(())
            },
            Command::Query(arg) => {
                let file_hash = MerkleHash::from_hex(&arg.hash)?;
                let ret = query_reconstruction(file_hash, arg.bytes_range, hub_client).await?;
                eprintln!("{ret:?}");
                Ok(())
            },
            Command::CompressedSize(arg) => {
                let file_hash = MerkleHash::from_hex(&arg.hash)?;
                // Query reconstruction for full file (no Range header)
                let ret = query_reconstruction(file_hash, None, hub_client).await?;
                match ret {
                    Some(response) => {
                        // Compressed size = sum of the url_range spans across all fetch infos.
                        let mut total_compressed_size = 0u64;
                        for fetch_infos in response.fetch_info.values() {
                            for fetch_info in fetch_infos {
                                let range_size = fetch_info.url_range.end - fetch_info.url_range.start;
                                total_compressed_size += range_size;
                            }
                        }
                        let total_uncompressed_size: u64 =
                            response.terms.iter().map(|term| term.unpacked_length as u64).sum();
                        // Count unique XORBs
                        let unique_xorbs: std::collections::HashSet<_> =
                            response.terms.iter().map(|term| &term.hash).collect();
                        println!("Compressed Size: {}", format_bytes_with_units(total_compressed_size));
                        println!("Uncompressed Size: {}", format_bytes_with_units(total_uncompressed_size));
                        // Guard the division: an empty file previously printed "NaN%".
                        if total_uncompressed_size > 0 {
                            println!(
                                "Compression Ratio: {:.2}%",
                                (total_compressed_size as f64 / total_uncompressed_size as f64) * 100.0
                            );
                        } else {
                            println!("Compression Ratio: n/a (empty file)");
                        }
                        println!("XORBs: {} unique", unique_xorbs.len());
                        Ok(())
                    },
                    None => {
                        eprintln!("No reconstruction information found for hash {}", arg.hash);
                        Ok(())
                    },
                }
            },
        }
    }
}
/// Expand the given paths into a flat list of file paths.
///
/// In recursive mode each entry is treated as a walk root: directories are
/// traversed (symlinks not followed) and git bookkeeping entries are pruned;
/// otherwise the input list is returned unchanged.
fn walk_files(files: Vec<String>, recursive: bool) -> Vec<String> {
    if !recursive {
        // The caller already listed exactly the files to process.
        return files;
    }
    let mut collected = Vec::new();
    for root in &files {
        // Prune git special entries before descending so their subtrees
        // (e.g. `.git/`) are never visited.
        let walker = WalkDir::new(root)
            .follow_links(false)
            .max_depth(usize::MAX)
            .into_iter()
            .filter_entry(|entry| !is_git_special_files(entry.file_name().to_str().unwrap_or_default()));
        // `flatten` drops walk errors (unreadable entries), matching best-effort behavior.
        for entry in walker.flatten() {
            let name = entry.file_name().to_str().unwrap_or_default();
            if entry.file_type().is_file() && !is_git_special_files(name) {
                // Non-UTF-8 paths are skipped.
                if let Some(path) = entry.path().to_str() {
                    collected.push(path.to_owned());
                }
            }
        }
    }
    collected
}
/// True for git bookkeeping entries that should be excluded from file walks.
fn is_git_special_files(path: &str) -> bool {
    [".git", ".gitignore", ".gitattributes"].contains(&path)
}
/// Render a byte count three ways on a single line: the raw number, the
/// binary-unit form (KiB, MiB, ...), and the decimal-unit form (kB, MB, ...).
fn format_bytes_with_units(bytes: u64) -> String {
    let in_binary = format_size(bytes, BINARY);
    let in_decimal = format_size(bytes, DECIMAL);
    format!("{bytes} bytes {in_binary} {in_decimal}")
}
/// Fetch reconstruction metadata for `file_hash` from the CAS service,
/// optionally restricted to `bytes_range`, authenticating via the Hub client.
/// Returns `None` when the service has no reconstruction info for the hash.
async fn query_reconstruction(
    file_hash: MerkleHash,
    bytes_range: Option<FileRange>,
    hub_client: HubClient,
) -> Result<Option<QueryReconstructionResponse>> {
    // Obtain a short-lived CAS JWT for download, plus a refresher so the
    // query can outlive the token's expiry.
    let op = Operation::Download;
    let jwt = hub_client.get_cas_jwt(op).await?;
    let refresher: Arc<dyn TokenRefresher> = Arc::new(HubClientTokenRefresher {
        operation: op,
        client: Arc::new(hub_client),
    });
    let config = default_config(
        jwt.cas_url.clone(),
        None,
        Some((jwt.access_token, jwt.exp)),
        Some(refresher),
        USER_AGENT.to_string(),
    )?;
    let data_cfg = &config.data_config;
    let client = RemoteClient::new(
        &jwt.cas_url,
        &data_cfg.auth,
        &Some(data_cfg.cache_config.clone()),
        Some(config.shard_config.cache_directory.clone()),
        "",
        true,
        &data_cfg.user_agent,
    );
    client
        .get_reconstruction(&file_hash, bytes_range)
        .await
        .map_err(anyhow::Error::from)
}
/// Entry point: parse the CLI and execute the chosen command on the Xet runtime.
fn main() -> Result<()> {
    let command = XCommand::parse();
    let runtime = XetRuntime::new()?;
    // First `?` surfaces runtime/join errors, second `?` the command's own result.
    runtime.external_run_async_task(async move { command.run().await })??;
    Ok(())
}