|
| 1 | +use aws_config::meta::region::RegionProviderChain; |
| 2 | +use aws_config::retry::RetryConfig; |
| 3 | +use aws_sdk_s3::primitives::ByteStream; |
| 4 | +use aws_sdk_s3::Client; |
| 5 | +use clap::Parser; |
| 6 | +use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; |
| 7 | +use rayon::prelude::*; |
| 8 | +use std::fs::{self, File}; |
| 9 | +use std::io::Result; |
| 10 | +use std::io::{BufRead, BufReader, Write}; |
| 11 | +use std::path::PathBuf; |
| 12 | +use std::sync::{ |
| 13 | + atomic::{AtomicUsize, Ordering}, |
| 14 | + Arc, |
| 15 | +}; |
| 16 | +use std::time::Duration; |
| 17 | +use tokio::runtime::Runtime; |
| 18 | + |
// When true, sizes are reported in binary units (KiB/MiB/…); false would
// switch the summary to decimal units (KB/MB/…).
const BINARY: bool = true;

/// Render a byte count as a human-readable string, e.g. `1.50 KiB`.
///
/// `binary` selects 1024-based units (B, KiB, MiB, GiB, TiB) versus
/// 1000-based units (B, KB, MB, GB, TB). The value is always formatted
/// with two decimal places and never exceeds the largest unit.
fn format_size(size: u64, binary: bool) -> String {
    let (units, base): (&[&str], f64) = if binary {
        (&["B", "KiB", "MiB", "GiB", "TiB"], 1024.0)
    } else {
        (&["B", "KB", "MB", "GB", "TB"], 1000.0)
    };

    let mut value = size as f64;
    let mut idx = 0usize;
    // Scale down until the value fits the unit, stopping at the last unit.
    while value >= base && idx + 1 < units.len() {
        value /= base;
        idx += 1;
    }

    format!("{:.2} {}", value, units[idx])
}
| 38 | + |
/// S3 Downloader: Download all files from an S3 bucket with multiple threads.
// NOTE: the `///` doc comments on the struct and its fields are semantic for
// clap — they become the --help text — so they must not be reworded casually.
// Short flags are derived by clap from the first letter of each field name:
// -b, -o, -w, -r, -f.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// S3 bucket name
    #[arg(short, long)]
    bucket: String,

    /// Local directory to download to
    #[arg(short, long)]
    output: String,

    /// Number of worker threads
    // Also reused as the round-robin modulus for the per-thread progress bars.
    #[arg(short, long, default_value_t = 4)]
    workers: usize,

    /// Maximum number of retries for failed downloads
    // Applied twice: as the SDK RetryConfig max attempts AND as the manual
    // retry budget in download_object_with_retry (see main).
    #[arg(short, long, default_value_t = 3)]
    retries: u32,

    /// File containing list of files to download (one per line)
    // When absent, a cached "<bucket>.files.txt" is used if present;
    // otherwise the bucket is listed live and the result cached.
    #[arg(short, long)]
    file_list: Option<String>,
}
| 63 | + |
| 64 | +#[tokio::main] |
| 65 | +async fn main() { |
| 66 | + let args = Args::parse(); |
| 67 | + |
| 68 | + let bucket_name = args.bucket; |
| 69 | + let local_dir = PathBuf::from(args.output); |
| 70 | + let num_workers = args.workers; |
| 71 | + let max_retries = args.retries; |
| 72 | + |
| 73 | + if !local_dir.exists() { |
| 74 | + fs::create_dir_all(&local_dir).expect("Failed to create output directory"); |
| 75 | + } |
| 76 | + |
| 77 | + let region_provider = RegionProviderChain::default_provider().or_else("us-east-1"); |
| 78 | + let config = aws_config::defaults(aws_config::BehaviorVersion::latest()) |
| 79 | + .region(region_provider) |
| 80 | + .retry_config(RetryConfig::standard().with_max_attempts(max_retries)) |
| 81 | + .load() |
| 82 | + .await; |
| 83 | + let client = Arc::new(Client::new(&config)); |
| 84 | + |
| 85 | + // Default file list name based on bucket |
| 86 | + let default_file_list = format!("{}.files.txt", bucket_name); |
| 87 | + |
| 88 | + // Get list of files to download |
| 89 | + let keys = if let Some(file_list) = args.file_list { |
| 90 | + println!("Reading file list from: {}", file_list); |
| 91 | + let file = File::open(&file_list).expect("Failed to open file list"); |
| 92 | + let reader = BufReader::new(file); |
| 93 | + reader.lines().map_while(Result::ok).collect() |
| 94 | + } else if PathBuf::from(&default_file_list).exists() { |
| 95 | + println!("Reading cached file list from: {}", default_file_list); |
| 96 | + let file = File::open(&default_file_list).expect("Failed to open cached file list"); |
| 97 | + let reader = BufReader::new(file); |
| 98 | + reader.lines().map_while(Result::ok).collect() |
| 99 | + } else { |
| 100 | + println!("Listing objects in bucket: {}", bucket_name); |
| 101 | + let keys = list_objects(&client, &bucket_name).await; |
| 102 | + |
| 103 | + // Always save the file list |
| 104 | + println!("Saving file list to: {}", default_file_list); |
| 105 | + let mut file = File::create(&default_file_list).expect("Failed to create file list"); |
| 106 | + for key in &keys { |
| 107 | + writeln!(file, "{}", key).expect("Failed to write to file list"); |
| 108 | + } |
| 109 | + println!("File list saved successfully"); |
| 110 | + |
| 111 | + keys |
| 112 | + }; |
| 113 | + |
| 114 | + println!( |
| 115 | + "Found {} files. Starting downloads with {} threads...", |
| 116 | + keys.len(), |
| 117 | + num_workers |
| 118 | + ); |
| 119 | + rayon::ThreadPoolBuilder::new() |
| 120 | + .num_threads(num_workers) |
| 121 | + .build_global() |
| 122 | + .unwrap(); |
| 123 | + |
| 124 | + let m = Arc::new(MultiProgress::new()); |
| 125 | + let downloaded = Arc::new(AtomicUsize::new(0)); |
| 126 | + let failed = Arc::new(AtomicUsize::new(0)); |
| 127 | + let downloaded_size = Arc::new(AtomicUsize::new(0)); |
| 128 | + |
| 129 | + // Create overall progress bar |
| 130 | + let total_pb = m.add(ProgressBar::new(keys.len() as u64)); |
| 131 | + total_pb.set_style( |
| 132 | + ProgressStyle::with_template( |
| 133 | + "{spinner} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({percent}%) {msg}", |
| 134 | + ) |
| 135 | + .unwrap(), |
| 136 | + ); |
| 137 | + |
| 138 | + // Calculate files per thread |
| 139 | + let files_per_thread = keys.len().div_ceil(num_workers); |
| 140 | + |
| 141 | + // Create fixed progress bars for each thread |
| 142 | + let thread_pbs: Vec<_> = (0..num_workers) |
| 143 | + .map(|i| { |
| 144 | + let pb = m.add(ProgressBar::new(files_per_thread as u64)); |
| 145 | + pb.set_style( |
| 146 | + ProgressStyle::with_template("[{thread}] {spinner} [{elapsed_precise}] [{bar:40.yellow/blue}] {pos}/{len} ({percent}%) {msg}") |
| 147 | + .unwrap() |
| 148 | + ); |
| 149 | + pb.set_message(format!("Thread {}: Starting", i + 1)); |
| 150 | + pb |
| 151 | + }) |
| 152 | + .collect(); |
| 153 | + |
| 154 | + keys.par_iter().enumerate().for_each(|(i, key)| { |
| 155 | + let client = Arc::clone(&client); |
| 156 | + let bucket = bucket_name.clone(); |
| 157 | + let dir = local_dir.clone(); |
| 158 | + let key = key.clone(); |
| 159 | + let downloaded = Arc::clone(&downloaded); |
| 160 | + let failed = Arc::clone(&failed); |
| 161 | + let downloaded_size = Arc::clone(&downloaded_size); |
| 162 | + let total_pb = total_pb.clone(); |
| 163 | + let thread_num = i % num_workers; |
| 164 | + let thread_pb = thread_pbs[thread_num].clone(); |
| 165 | + |
| 166 | + let rt = Runtime::new().unwrap(); |
| 167 | + rt.block_on(async move { |
| 168 | + let local_path = dir.join(&key); |
| 169 | + if let Some(parent) = local_path.parent() { |
| 170 | + fs::create_dir_all(parent).expect("Failed to create parent directory"); |
| 171 | + } |
| 172 | + |
| 173 | + match download_object_with_retry(&client, &bucket, &key, max_retries).await { |
| 174 | + Ok(bytes) => { |
| 175 | + let mut file = File::create(&local_path).expect("Failed to create file"); |
| 176 | + file.write_all(&bytes).expect("Failed to write file"); |
| 177 | + downloaded.fetch_add(1, Ordering::SeqCst); |
| 178 | + downloaded_size.fetch_add(bytes.len(), Ordering::SeqCst); |
| 179 | + total_pb.inc(1); |
| 180 | + thread_pb.inc(1); |
| 181 | + thread_pb.set_message(format!( |
| 182 | + "Thread {}: Downloaded {}/{} files", |
| 183 | + thread_num + 1, |
| 184 | + downloaded.load(Ordering::SeqCst), |
| 185 | + files_per_thread |
| 186 | + )); |
| 187 | + } |
| 188 | + Err(_e) => { |
| 189 | + failed.fetch_add(1, Ordering::SeqCst); |
| 190 | + total_pb.inc(1); |
| 191 | + thread_pb.inc(1); |
| 192 | + thread_pb.set_message(format!( |
| 193 | + "Thread {}: Failed {}/{} files", |
| 194 | + thread_num + 1, |
| 195 | + failed.load(Ordering::SeqCst), |
| 196 | + files_per_thread |
| 197 | + )); |
| 198 | + } |
| 199 | + } |
| 200 | + }); |
| 201 | + }); |
| 202 | + |
| 203 | + // Clean up all progress bars |
| 204 | + for pb in thread_pbs { |
| 205 | + pb.finish_and_clear(); |
| 206 | + } |
| 207 | + total_pb.finish_with_message("Download complete"); |
| 208 | + println!( |
| 209 | + "✅ Total files downloaded: {}", |
| 210 | + downloaded.load(Ordering::SeqCst) |
| 211 | + ); |
| 212 | + println!("❌ Total files failed: {}", failed.load(Ordering::SeqCst)); |
| 213 | + println!( |
| 214 | + "📦 Total data downloaded: {}", |
| 215 | + format_size(downloaded_size.load(Ordering::SeqCst) as u64, BINARY) |
| 216 | + ); |
| 217 | +} |
| 218 | + |
| 219 | +async fn list_objects(client: &Client, bucket: &str) -> Vec<String> { |
| 220 | + let mut keys = Vec::new(); |
| 221 | + let mut continuation_token = None; |
| 222 | + |
| 223 | + loop { |
| 224 | + let mut req = client.list_objects_v2().bucket(bucket.to_string()); |
| 225 | + if let Some(token) = continuation_token { |
| 226 | + req = req.continuation_token(token); |
| 227 | + } |
| 228 | + |
| 229 | + match req.send().await { |
| 230 | + Ok(resp) => { |
| 231 | + if let Some(contents) = resp.contents { |
| 232 | + for obj in contents { |
| 233 | + if let Some(key) = obj.key { |
| 234 | + keys.push(key); |
| 235 | + } |
| 236 | + } |
| 237 | + } |
| 238 | + |
| 239 | + if resp.is_truncated.unwrap_or(false) { |
| 240 | + continuation_token = resp.next_continuation_token; |
| 241 | + } else { |
| 242 | + break; |
| 243 | + } |
| 244 | + } |
| 245 | + Err(_e) => { |
| 246 | + eprintln!("Failed to list objects: {:?}", _e); |
| 247 | + break; |
| 248 | + } |
| 249 | + } |
| 250 | + } |
| 251 | + |
| 252 | + keys |
| 253 | +} |
| 254 | + |
| 255 | +async fn download_object_with_retry( |
| 256 | + client: &Client, |
| 257 | + bucket: &str, |
| 258 | + key: &str, |
| 259 | + max_retries: u32, |
| 260 | +) -> Result<Vec<u8>> { |
| 261 | + let mut retry_count = 0; |
| 262 | + let mut last_error = None; |
| 263 | + |
| 264 | + while retry_count <= max_retries { |
| 265 | + match download_object(client, bucket, key).await { |
| 266 | + Ok(bytes) => return Ok(bytes), |
| 267 | + Err(e) => { |
| 268 | + last_error = Some(e); |
| 269 | + retry_count += 1; |
| 270 | + if retry_count <= max_retries { |
| 271 | + tokio::time::sleep(Duration::from_secs(2u64.pow(retry_count))).await; |
| 272 | + } |
| 273 | + } |
| 274 | + } |
| 275 | + } |
| 276 | + |
| 277 | + Err(last_error.unwrap_or_else(|| std::io::Error::other("Unknown error"))) |
| 278 | +} |
| 279 | + |
| 280 | +async fn download_object(client: &Client, bucket: &str, key: &str) -> Result<Vec<u8>> { |
| 281 | + let resp = client |
| 282 | + .get_object() |
| 283 | + .bucket(bucket) |
| 284 | + .key(key) |
| 285 | + .send() |
| 286 | + .await |
| 287 | + .map_err(std::io::Error::other)?; |
| 288 | + let data: ByteStream = resp.body; |
| 289 | + let bytes = data |
| 290 | + .collect() |
| 291 | + .await |
| 292 | + .map_err(std::io::Error::other)? |
| 293 | + .into_bytes() |
| 294 | + .to_vec(); |
| 295 | + Ok(bytes) |
| 296 | +} |
0 commit comments