Commit 4601a3d

Merge pull request #13 from sangshuduo/feat/sangshuduo/s3-bucket-downloader
feat(s3_bucket_downloader): add S3 bucket downloader tool
2 parents 7e373e5 + f1784e2

File tree: 3 files changed, +310, -0 lines

Cargo.toml (+1 line)
@@ -11,6 +11,7 @@ members = [
     "random_pairs_of_s3file",
     "find_log_processtime",
     "archive_dirs",
+    "s3_bucket_downloader",
     # Add other tools here
 ]
 resolver = "2" # Add this line to specify resolver version 2

s3_bucket_downloader/Cargo.toml (new file, +13 lines)
@@ -0,0 +1,13 @@
[package]
name = "s3_bucket_downloader"
version = "0.1.0"
edition = "2021"

[dependencies]
aws-config = "1.6.1"
aws-sdk-s3 = "1.82.0"
tokio = { version = "1.44.1", features = ["full"] }
rayon = "1.10"
futures = "0.3.31"
indicatif = "0.17.11"
clap = { version = "4.5.34", features = ["derive"] }

s3_bucket_downloader/src/main.rs (new file, +296 lines)
@@ -0,0 +1,296 @@
use aws_config::meta::region::RegionProviderChain;
use aws_config::retry::RetryConfig;
use aws_sdk_s3::primitives::ByteStream;
use aws_sdk_s3::Client;
use clap::Parser;
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use rayon::prelude::*;
use std::fs::{self, File};
use std::io::Result;
use std::io::{BufRead, BufReader, Write};
use std::path::PathBuf;
use std::sync::{
    atomic::{AtomicUsize, Ordering},
    Arc,
};
use std::time::Duration;
use tokio::runtime::Runtime;

const BINARY: bool = true;

fn format_size(size: u64, binary: bool) -> String {
    let units = if binary {
        ["B", "KiB", "MiB", "GiB", "TiB"]
    } else {
        ["B", "KB", "MB", "GB", "TB"]
    };
    let base = if binary { 1024.0 } else { 1000.0 };
    let mut size = size as f64;
    let mut unit_index = 0;

    while size >= base && unit_index < units.len() - 1 {
        size /= base;
        unit_index += 1;
    }

    format!("{:.2} {}", size, units[unit_index])
}
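
// Illustrative outputs (comment added for clarity; not part of the original
// commit): format_size(1_536, true) == "1.50 KiB";
// format_size(1_000_000, false) == "1.00 MB".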

/// S3 Downloader: Download all files from an S3 bucket with multiple threads.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// S3 bucket name
    #[arg(short, long)]
    bucket: String,

    /// Local directory to download to
    #[arg(short, long)]
    output: String,

    /// Number of worker threads
    #[arg(short, long, default_value_t = 4)]
    workers: usize,

    /// Maximum number of retries for failed downloads
    #[arg(short, long, default_value_t = 3)]
    retries: u32,

    /// File containing list of files to download (one per line)
    #[arg(short, long)]
    file_list: Option<String>,
}

#[tokio::main]
async fn main() {
    let args = Args::parse();

    let bucket_name = args.bucket;
    let local_dir = PathBuf::from(args.output);
    let num_workers = args.workers;
    let max_retries = args.retries;

    if !local_dir.exists() {
        fs::create_dir_all(&local_dir).expect("Failed to create output directory");
    }

    let region_provider = RegionProviderChain::default_provider().or_else("us-east-1");
    let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
        .region(region_provider)
        .retry_config(RetryConfig::standard().with_max_attempts(max_retries))
        .load()
        .await;
    let client = Arc::new(Client::new(&config));

    // Default file list name based on bucket
    let default_file_list = format!("{}.files.txt", bucket_name);

    // Get list of files to download
    let keys = if let Some(file_list) = args.file_list {
        println!("Reading file list from: {}", file_list);
        let file = File::open(&file_list).expect("Failed to open file list");
        let reader = BufReader::new(file);
        reader.lines().map_while(Result::ok).collect()
    } else if PathBuf::from(&default_file_list).exists() {
        println!("Reading cached file list from: {}", default_file_list);
        let file = File::open(&default_file_list).expect("Failed to open cached file list");
        let reader = BufReader::new(file);
        reader.lines().map_while(Result::ok).collect()
    } else {
        println!("Listing objects in bucket: {}", bucket_name);
        let keys = list_objects(&client, &bucket_name).await;

        // Always save the file list
        println!("Saving file list to: {}", default_file_list);
        let mut file = File::create(&default_file_list).expect("Failed to create file list");
        for key in &keys {
            writeln!(file, "{}", key).expect("Failed to write to file list");
        }
        println!("File list saved successfully");

        keys
    };

    println!(
        "Found {} files. Starting downloads with {} threads...",
        keys.len(),
        num_workers
    );
    rayon::ThreadPoolBuilder::new()
        .num_threads(num_workers)
        .build_global()
        .unwrap();

    let m = Arc::new(MultiProgress::new());
    let downloaded = Arc::new(AtomicUsize::new(0));
    let failed = Arc::new(AtomicUsize::new(0));
    let downloaded_size = Arc::new(AtomicUsize::new(0));

    // Create overall progress bar
    let total_pb = m.add(ProgressBar::new(keys.len() as u64));
    total_pb.set_style(
        ProgressStyle::with_template(
            "{spinner} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({percent}%) {msg}",
        )
        .unwrap(),
    );

    // Calculate files per thread
    let files_per_thread = keys.len().div_ceil(num_workers);

    // Create fixed progress bars for each thread
    let thread_pbs: Vec<_> = (0..num_workers)
        .map(|i| {
            let pb = m.add(ProgressBar::new(files_per_thread as u64));
            pb.set_style(
                ProgressStyle::with_template("[{thread}] {spinner} [{elapsed_precise}] [{bar:40.yellow/blue}] {pos}/{len} ({percent}%) {msg}")
                    .unwrap()
            );
            pb.set_message(format!("Thread {}: Starting", i + 1));
            pb
        })
        .collect();

    keys.par_iter().enumerate().for_each(|(i, key)| {
        let client = Arc::clone(&client);
        let bucket = bucket_name.clone();
        let dir = local_dir.clone();
        let key = key.clone();
        let downloaded = Arc::clone(&downloaded);
        let failed = Arc::clone(&failed);
        let downloaded_size = Arc::clone(&downloaded_size);
        let total_pb = total_pb.clone();
        let thread_num = i % num_workers;
        let thread_pb = thread_pbs[thread_num].clone();

        let rt = Runtime::new().unwrap();
        rt.block_on(async move {
            let local_path = dir.join(&key);
            if let Some(parent) = local_path.parent() {
                fs::create_dir_all(parent).expect("Failed to create parent directory");
            }

            match download_object_with_retry(&client, &bucket, &key, max_retries).await {
                Ok(bytes) => {
                    let mut file = File::create(&local_path).expect("Failed to create file");
                    file.write_all(&bytes).expect("Failed to write file");
                    downloaded.fetch_add(1, Ordering::SeqCst);
                    downloaded_size.fetch_add(bytes.len(), Ordering::SeqCst);
                    total_pb.inc(1);
                    thread_pb.inc(1);
                    thread_pb.set_message(format!(
                        "Thread {}: Downloaded {}/{} files",
                        thread_num + 1,
                        downloaded.load(Ordering::SeqCst),
                        files_per_thread
                    ));
                }
                Err(_e) => {
                    failed.fetch_add(1, Ordering::SeqCst);
                    total_pb.inc(1);
                    thread_pb.inc(1);
                    thread_pb.set_message(format!(
                        "Thread {}: Failed {}/{} files",
                        thread_num + 1,
                        failed.load(Ordering::SeqCst),
                        files_per_thread
                    ));
                }
            }
        });
    });

    // Clean up all progress bars
    for pb in thread_pbs {
        pb.finish_and_clear();
    }
    total_pb.finish_with_message("Download complete");
    println!(
        "✅ Total files downloaded: {}",
        downloaded.load(Ordering::SeqCst)
    );
    println!("❌ Total files failed: {}", failed.load(Ordering::SeqCst));
    println!(
        "📦 Total data downloaded: {}",
        format_size(downloaded_size.load(Ordering::SeqCst) as u64, BINARY)
    );
}

async fn list_objects(client: &Client, bucket: &str) -> Vec<String> {
    let mut keys = Vec::new();
    let mut continuation_token = None;

    loop {
        let mut req = client.list_objects_v2().bucket(bucket.to_string());
        if let Some(token) = continuation_token {
            req = req.continuation_token(token);
        }

        match req.send().await {
            Ok(resp) => {
                if let Some(contents) = resp.contents {
                    for obj in contents {
                        if let Some(key) = obj.key {
                            keys.push(key);
                        }
                    }
                }

                if resp.is_truncated.unwrap_or(false) {
                    continuation_token = resp.next_continuation_token;
                } else {
                    break;
                }
            }
            Err(_e) => {
                eprintln!("Failed to list objects: {:?}", _e);
                break;
            }
        }
    }

    keys
}

async fn download_object_with_retry(
    client: &Client,
    bucket: &str,
    key: &str,
    max_retries: u32,
) -> Result<Vec<u8>> {
    let mut retry_count = 0;
    let mut last_error = None;

    while retry_count <= max_retries {
        match download_object(client, bucket, key).await {
            Ok(bytes) => return Ok(bytes),
            Err(e) => {
                last_error = Some(e);
                retry_count += 1;
                if retry_count <= max_retries {
                    tokio::time::sleep(Duration::from_secs(2u64.pow(retry_count))).await;
                }
            }
        }
    }

    Err(last_error.unwrap_or_else(|| std::io::Error::other("Unknown error")))
}
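
// Note (comment added for clarity; not part of the original commit): the
// backoff above sleeps 2^retry_count seconds between attempts, so with the
// default max_retries = 3 an object is tried up to four times, waiting
// 2s, 4s, then 8s between attempts.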

async fn download_object(client: &Client, bucket: &str, key: &str) -> Result<Vec<u8>> {
    let resp = client
        .get_object()
        .bucket(bucket)
        .key(key)
        .send()
        .await
        .map_err(std::io::Error::other)?;
    let data: ByteStream = resp.body;
    let bytes = data
        .collect()
        .await
        .map_err(std::io::Error::other)?
        .into_bytes()
        .to_vec();
    Ok(bytes)
}
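
Usage (a minimal sketch, not part of the commit; my-bucket and ./downloads are placeholder values, and AWS credentials are assumed to come from the default provider chain):

cargo run -p s3_bucket_downloader -- --bucket my-bucket --output ./downloads --workers 8

On the first run the tool lists the bucket and caches the key list in my-bucket.files.txt; later runs reuse that cache, or --file-list can point at a custom list instead.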
