Refactor shard download to use local paths #625
base: main
Conversation
Updated the download function to save files locally and ensure directories exist before downloading.
Walkthrough
Implements a safer, directory-aware download workflow in SharedShardedDataset.download_files: accepts local destination paths, derives S3 keys from basenames, ensures directories exist, streams objects without immediate load, downloads to temporary locations with progress, and atomically moves them to final paths. Adds shutil import and updates docstrings accordingly.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant C as Caller
    participant SSD as SharedShardedDataset
    participant S3 as S3 Storage
    participant TMP as Temp File
    participant FS as Local FS
    C->>SSD: download_files(local_tokens_path, local_ids_path)
    SSD->>FS: ensure parent dirs exist
    SSD->>SSD: derive s3 keys from basenames
    SSD->>S3: s3_get_object(key, load_data=false, progress=true)
    S3-->>SSD: temp file handle/path
    SSD->>TMP: validate temp paths
    SSD->>FS: shutil.move(tmp_tokens -> tokens_path)
    SSD->>FS: shutil.move(tmp_ids -> ids_path)
    SSD-->>C: return individual download results
```
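For orientation, the flow above boils down to something like the sketch below. This is a simplified standalone function, not the actual SharedShardedDataset method; the argument order for `s3_get_object` is taken from the hunks quoted later in this review and may differ in detail from the real signature in `comms.py`.

```python
import asyncio
import os
import shutil


async def download_files(comms, bucket, tokens_file: str, ids_file: str):
    """Sketch of the reworked flow: local destination paths in, shards on disk out."""
    # Make sure the destination directories exist before downloading.
    # (The "or '.'" guards against bare filenames with no directory part.)
    for path in (tokens_file, ids_file):
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)

    # S3 object keys are derived from the shard filenames.
    tokens_key = os.path.basename(tokens_file)
    ids_key = os.path.basename(ids_file)

    # load_data=False streams each object to a temporary file and returns its path.
    results = await asyncio.gather(
        comms.s3_get_object(tokens_key, bucket, load_data=False, show_progress=True),
        comms.s3_get_object(ids_key, bucket, load_data=False, show_progress=True),
    )

    # Move each downloaded temp file into its final location.
    for tmp_path, final_path in zip(results, (tokens_file, ids_file)):
        if tmp_path and os.path.exists(tmp_path):
            shutil.move(tmp_path, final_path)

    return results
```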
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning), ✅ Passed checks (2 passed)
Actionable comments posted: 3
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
src/tplr/sharded_dataset.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/tplr/sharded_dataset.py (1)
src/tplr/comms.py (1)
s3_get_object (477-625)
🔇 Additional comments (1)
src/tplr/sharded_dataset.py (1)
20-20: LGTM! Necessary import for file operations.
The `shutil` import is required for the `shutil.move` operations added in the `download_files` method.
```python
# Extract just the filenames for S3 object keys
tokens_filename = os.path.basename(tokens_file)
ids_filename = os.path.basename(ids_file)

# Ensure the local directory exists
os.makedirs(os.path.dirname(tokens_file), exist_ok=True)
```
🛠️ Refactor suggestion | 🟠 Major
🧩 Analysis chain
Design flaw: Workaround creates intermediate files in the current directory.
The approach of extracting basenames and passing them to s3_get_object while managing full paths separately creates a fragile two-step move operation:

- `s3_get_object` with `key="train_000001.npy"` (basename) downloads to a temp file, then moves it into the current working directory
- Lines 294-296 then perform a second move to the final destination
This workaround stems from s3_get_object conflating the S3 object key with the local destination path when load_data=False. The issues are:
- Intermediate files: Files land in the current directory before the second move, which can fail if the directory is not writable or on a different filesystem
- Efficiency: Two move operations instead of one direct temp→destination move
- Cleanup risk: If the second move fails, intermediate files remain in the current directory
Consider one of these solutions:

- Modify `s3_get_object` (in `comms.py`) to accept separate `s3_key` and `local_path` parameters when `load_data=False`
- Pass full paths to `s3_get_object` and modify it to derive S3 keys from basenames internally
- Download to memory (avoid `load_data=False`) and write to the desired location directly in this method (a rough sketch of this option follows below)
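A rough sketch of that last option, with the caveat that it assumes `load_data=True` returns the raw object bytes; if `s3_get_object` deserializes the payload instead, the write step would need the matching save call:

```python
import os


async def download_shard_in_memory(comms, bucket, s3_key: str, local_path: str):
    """Hypothetical variant: fetch the object in memory, then write it where it belongs."""
    os.makedirs(os.path.dirname(local_path) or ".", exist_ok=True)

    # Assumption: load_data=True yields the object's raw bytes.
    data = await comms.s3_get_object(s3_key, bucket, load_data=True)
    if data is None:
        raise RuntimeError(f"Failed to download {s3_key}")

    with open(local_path, "wb") as f:
        f.write(data)
    return local_path
```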
To verify the current directory behavior, run:
🏁 Script executed:

```bash
#!/bin/bash
# Verify where files land when s3_get_object is called with basenames
# Check s3_get_object implementation for load_data=False path
rg -A 10 'if not load_data:' src/tplr/comms.py
```

Length of output: 0

🏁 Script executed:

```bash
#!/bin/bash
rg -n 'def s3_get_object' -C5 .
```

Length of output: 5573

🏁 Script executed:

```bash
#!/bin/bash
rg -n -C50 'async def s3_get_object' src/tplr/comms.py
```

Length of output: 11226

🏁 Script executed:

```bash
rg -n 'if load_data' -C10 src/tplr/comms.py
```

Length of output: 924

🏁 Script executed:

```bash
sed -n '550,650p' src/tplr/comms.py
```

Length of output: 4133
Refactor s3_get_object to accept an explicit download path
In src/tplr/sharded_dataset.py (lines 268–274), passing only the basename as key with load_data=False makes s3_get_object first move the file into the current working directory and then your code moves it again—this two-step move is fragile, inefficient, and leaves orphaned files on failure. Change s3_get_object to take separate s3_key and local_path parameters (or accept a full destination path), so it writes directly to the final directory.
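One possible shape for that refactor, sketched as a free function; the helper names `download_to_temp` and `load_from_file` are placeholders for the existing streaming and deserialization logic in `comms.py`, not real methods:

```python
import os
import shutil


async def s3_get_object(comms, s3_key: str, bucket, *, local_path: str | None = None,
                        load_data: bool = True, show_progress: bool = False):
    """Hypothetical refactor: the S3 key and the local destination are separate arguments."""
    # Stands in for the existing code that streams the object to a temp file.
    tmp_path = await comms.download_to_temp(s3_key, bucket, show_progress=show_progress)

    if load_data:
        # Existing behaviour: deserialize and return the object.
        return comms.load_from_file(tmp_path)

    # New behaviour: one move, straight from the temp file to the caller's path.
    destination = local_path or os.path.basename(s3_key)
    os.makedirs(os.path.dirname(destination) or ".", exist_ok=True)
    shutil.move(tmp_path, destination)
    return destination
```

With something like this in place, download_files could pass the basename as `s3_key` and the full destination as `local_path`, and drop its second `shutil.move` entirely.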
```diff
 results = await asyncio.gather(
     self.comms.s3_get_object(
-        tokens_file,
+        tokens_filename,  # S3 object key (just filename)
         bucket,
         load_data=False,
         show_progress=True,
     ),
     self.comms.s3_get_object(
-        ids_file,
+        ids_filename,  # S3 object key (just filename)
         bucket,
         load_data=False,
         show_progress=True,
     ),
 )
```
Missing error handling for download failures.
The code doesn't handle cases where s3_get_object returns None or a status dictionary (e.g., {"__status": "TOO_EARLY"}, {"__status": "TOO_LATE"}). Lines 293-296 will fail when checking os.path.exists(results[0]) if results[0] is None or a dict.
Add error handling after the asyncio.gather call:
```diff
 results = await asyncio.gather(
     self.comms.s3_get_object(
         tokens_filename,  # S3 object key (just filename)
         bucket,
         load_data=False,
         show_progress=True,
     ),
     self.comms.s3_get_object(
         ids_filename,  # S3 object key (just filename)
         bucket,
         load_data=False,
         show_progress=True,
     ),
 )
+
+# Validate download results
+for i, (result, filename) in enumerate(zip(results, [tokens_file, ids_file])):
+    if result is None:
+        raise RuntimeError(f"Failed to download {os.path.basename(filename)}")
+    if isinstance(result, dict) and "__status" in result:
+        raise RuntimeError(f"Download skipped for {os.path.basename(filename)}: {result['__status']}")
```

Committable suggestion skipped: line range outside the PR's diff.
```python
# Move downloaded files to correct locations
# s3_get_object with load_data=False returns the temp file path
if results[0] and os.path.exists(results[0]):
    shutil.move(results[0], tokens_file)
if results[1] and os.path.exists(results[1]):
    shutil.move(results[1], ids_file)

return results
```
Improve robustness of file move operations.
The current implementation has several issues:
- Type safety: `results[0]` might be `None`, a dict, or a string path - the code only handles the string case
- Incomplete validation: the `os.path.exists` check happens after potential type errors
- Silent failures: If a file doesn't exist or a move fails, the code continues without error
- Misleading return: Returns the intermediate paths from `s3_get_object`, which no longer exist after the moves
Apply this diff to improve error handling and validation:
```diff
-# Move downloaded files to correct locations
-# s3_get_object with load_data=False returns the temp file path
-if results[0] and os.path.exists(results[0]):
-    shutil.move(results[0], tokens_file)
-if results[1] and os.path.exists(results[1]):
-    shutil.move(results[1], ids_file)
-
-return results
+# Move downloaded files to correct locations
+# s3_get_object with load_data=False returns the intermediate file path
+moved_files = []
+for result, final_path in zip(results, [tokens_file, ids_file]):
+    if not isinstance(result, str):
+        raise RuntimeError(f"Expected file path from download, got {type(result)}")
+    if not os.path.exists(result):
+        raise FileNotFoundError(f"Downloaded file not found: {result}")
+
+    try:
+        final_location = shutil.move(result, final_path)
+        moved_files.append(final_location)
+    except Exception as e:
+        raise RuntimeError(f"Failed to move {result} to {final_path}: {e}")
+
+return moved_files
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```suggestion
# Move downloaded files to correct locations
# s3_get_object with load_data=False returns the intermediate file path
moved_files = []
for result, final_path in zip(results, [tokens_file, ids_file]):
    if not isinstance(result, str):
        raise RuntimeError(f"Expected file path from download, got {type(result)}")
    if not os.path.exists(result):
        raise FileNotFoundError(f"Downloaded file not found: {result}")

    try:
        final_location = shutil.move(result, final_path)
        moved_files.append(final_location)
    except Exception as e:
        raise RuntimeError(f"Failed to move {result} to {final_path}: {e}")

return moved_files
```
🤖 Prompt for AI Agents
In src/tplr/sharded_dataset.py around lines 291 to 298, the file-move logic is
fragile: results entries can be None or non-string, existence is checked too
late, failures are silent, and the function returns intermediate temp paths that
no longer exist; update the code to validate each result is a non-empty string
before touching the filesystem, check os.path.exists on that validated path,
wrap shutil.move in a try/except to raise or log a clear error on failure, and
after successful moves replace the corresponding entries in the returned value
with the final destination paths (tokens_file and ids_file) so the caller
receives current, correct paths — if a move cannot be completed, raise an
exception rather than silently continuing.
Description
Related Issue(s)
Type of Change
Branch Naming
Commit Messages
Code Quality
Testing
Documentation
If this is a breaking change
Screenshots/Examples
Additional Notes
Summary by CodeRabbit
New Features
Bug Fixes
Documentation