Skip to content

Commit 814130f

Browse files
authored
Merge pull request #14 from sangshuduo/feat/sangshuduo/s3-bucket-downloader
fix(s3_bucket_downloader): add redownload mode for missing files
2 parents 4601a3d + b457c88 commit 814130f

File tree

1 file changed

+78
-52
lines changed

1 file changed

+78
-52
lines changed

s3_bucket_downloader/src/main.rs

+78-52
Original file line number | Diff line number | Diff line change
@@ -59,6 +59,10 @@ struct Args {
5959
/// File containing list of files to download (one per line)
6060
#[arg(short, long)]
6161
file_list: Option<String>,
62+
63+
/// Redownload mode: check existing files and only download missing ones
64+
#[arg(short = 'd', long, default_value_t = false)]
65+
redownload: bool,
6266
}
6367

6468
#[tokio::main]
@@ -69,6 +73,7 @@ async fn main() {
6973
let local_dir = PathBuf::from(args.output);
7074
let num_workers = args.workers;
7175
let max_retries = args.retries;
76+
let redownload = args.redownload;
7277

7378
if !local_dir.exists() {
7479
fs::create_dir_all(&local_dir).expect("Failed to create output directory");
@@ -111,11 +116,34 @@ async fn main() {
111116
keys
112117
};
113118

114-
println!(
115-
"Found {} files. Starting downloads with {} threads...",
116-
keys.len(),
117-
num_workers
118-
);
119+
// Filter out existing files in redownload mode
120+
let keys_to_download = if redownload {
121+
let mut missing_keys = Vec::new();
122+
let mut existing_count = 0;
123+
for key in &keys {
124+
let local_path = local_dir.join(key);
125+
if local_path.exists() {
126+
existing_count += 1;
127+
} else {
128+
missing_keys.push(key.clone());
129+
}
130+
}
131+
println!(
132+
"Found {} existing files, {} files to download",
133+
existing_count,
134+
missing_keys.len()
135+
);
136+
missing_keys
137+
} else {
138+
keys
139+
};
140+
141+
if keys_to_download.is_empty() {
142+
println!("No files to download.");
143+
return;
144+
}
145+
146+
println!("Starting downloads with {} threads...", num_workers);
119147
rayon::ThreadPoolBuilder::new()
120148
.num_threads(num_workers)
121149
.build_global()
@@ -127,7 +155,7 @@ async fn main() {
127155
let downloaded_size = Arc::new(AtomicUsize::new(0));
128156

129157
// Create overall progress bar
130-
let total_pb = m.add(ProgressBar::new(keys.len() as u64));
158+
let total_pb = m.add(ProgressBar::new(keys_to_download.len() as u64));
131159
total_pb.set_style(
132160
ProgressStyle::with_template(
133161
"{spinner} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({percent}%) {msg}",
@@ -136,7 +164,7 @@ async fn main() {
136164
);
137165

138166
// Calculate files per thread
139-
let files_per_thread = keys.len().div_ceil(num_workers);
167+
let files_per_thread = keys_to_download.len().div_ceil(num_workers);
140168

141169
// Create fixed progress bars for each thread
142170
let thread_pbs: Vec<_> = (0..num_workers)
@@ -151,54 +179,52 @@ async fn main() {
151179
})
152180
.collect();
153181

154-
keys.par_iter().enumerate().for_each(|(i, key)| {
155-
let client = Arc::clone(&client);
156-
let bucket = bucket_name.clone();
157-
let dir = local_dir.clone();
158-
let key = key.clone();
159-
let downloaded = Arc::clone(&downloaded);
160-
let failed = Arc::clone(&failed);
161-
let downloaded_size = Arc::clone(&downloaded_size);
162-
let total_pb = total_pb.clone();
163-
let thread_num = i % num_workers;
164-
let thread_pb = thread_pbs[thread_num].clone();
165-
166-
let rt = Runtime::new().unwrap();
167-
rt.block_on(async move {
168-
let local_path = dir.join(&key);
169-
if let Some(parent) = local_path.parent() {
170-
fs::create_dir_all(parent).expect("Failed to create parent directory");
171-
}
172-
173-
match download_object_with_retry(&client, &bucket, &key, max_retries).await {
174-
Ok(bytes) => {
175-
let mut file = File::create(&local_path).expect("Failed to create file");
176-
file.write_all(&bytes).expect("Failed to write file");
177-
downloaded.fetch_add(1, Ordering::SeqCst);
178-
downloaded_size.fetch_add(bytes.len(), Ordering::SeqCst);
179-
total_pb.inc(1);
180-
thread_pb.inc(1);
181-
thread_pb.set_message(format!(
182-
"Thread {}: Downloaded {}/{} files",
183-
thread_num + 1,
184-
downloaded.load(Ordering::SeqCst),
185-
files_per_thread
186-
));
182+
keys_to_download
183+
.par_iter()
184+
.enumerate()
185+
.for_each(|(i, key)| {
186+
let client = Arc::clone(&client);
187+
let bucket = bucket_name.clone();
188+
let dir = local_dir.clone();
189+
let key = key.clone();
190+
let downloaded = Arc::clone(&downloaded);
191+
let failed = Arc::clone(&failed);
192+
let downloaded_size = Arc::clone(&downloaded_size);
193+
let total_pb = total_pb.clone();
194+
let thread_num = i % num_workers;
195+
let thread_pb = thread_pbs[thread_num].clone();
196+
197+
let rt = Runtime::new().unwrap();
198+
rt.block_on(async move {
199+
let local_path = dir.join(&key);
200+
if let Some(parent) = local_path.parent() {
201+
fs::create_dir_all(parent).expect("Failed to create parent directory");
187202
}
188-
Err(_e) => {
189-
failed.fetch_add(1, Ordering::SeqCst);
190-
total_pb.inc(1);
191-
thread_pb.inc(1);
192-
thread_pb.set_message(format!(
193-
"Thread {}: Failed {}/{} files",
194-
thread_num + 1,
195-
failed.load(Ordering::SeqCst),
196-
files_per_thread
197-
));
203+
204+
match download_object_with_retry(&client, &bucket, &key, max_retries).await {
205+
Ok(bytes) => {
206+
let mut file = File::create(&local_path).expect("Failed to create file");
207+
file.write_all(&bytes).expect("Failed to write file");
208+
downloaded.fetch_add(1, Ordering::SeqCst);
209+
downloaded_size.fetch_add(bytes.len(), Ordering::SeqCst);
210+
total_pb.inc(1);
211+
thread_pb.inc(1);
212+
thread_pb.set_message(format!("by Thread {}", thread_num + 1));
213+
}
214+
Err(_e) => {
215+
failed.fetch_add(1, Ordering::SeqCst);
216+
total_pb.inc(1);
217+
thread_pb.inc(1);
218+
thread_pb.set_message(format!(
219+
"Thread {}: Failed {}/{} files",
220+
thread_num + 1,
221+
failed.load(Ordering::SeqCst),
222+
files_per_thread
223+
));
224+
}
198225
}
199-
}
226+
});
200227
});
201-
});
202228

203229
// Clean up all progress bars
204230
for pb in thread_pbs {

0 commit comments

Comments
 (0)