diff --git a/projects/fal/src/fal/toolkit/__init__.py b/projects/fal/src/fal/toolkit/__init__.py index 51f8bad4..3e603a96 100644 --- a/projects/fal/src/fal/toolkit/__init__.py +++ b/projects/fal/src/fal/toolkit/__init__.py @@ -9,6 +9,7 @@ FAL_PERSISTENT_DIR, FAL_REPOSITORY_DIR, clone_repository, + clone_repository_cached, download_file, download_model_weights, ) @@ -31,6 +32,7 @@ "FAL_PERSISTENT_DIR", "FAL_REPOSITORY_DIR", "clone_repository", + "clone_repository_cached", "download_file", "download_model_weights", ] diff --git a/projects/fal/src/fal/toolkit/utils/download_utils.py b/projects/fal/src/fal/toolkit/utils/download_utils.py index 09abc0cd..612cee1f 100644 --- a/projects/fal/src/fal/toolkit/utils/download_utils.py +++ b/projects/fal/src/fal/toolkit/utils/download_utils.py @@ -3,6 +3,7 @@ import errno import hashlib import os +import random import shutil import subprocess import sys @@ -518,6 +519,175 @@ def clone_repository( return local_repo_path +# TODO: rename +def clone_repository_cached( + https_url: str, + *, + commit_hash: str | None = None, + target_path: str | Path | None = None, + include_to_path: bool = False, +) -> Path: + """Clones a Git repository from the specified HTTPS URL into a local + directory. + + This function clones a Git repository from the specified HTTPS URL into a local + directory. It can also checkout a specific commit if the `commit_hash` is provided. + + If a custom `target_path` is not specified, a predefined directory is + used for the target directory, and the repository name is determined from the URL. + + Args: + https_url: The HTTPS URL of the Git repository to be cloned. + commit_hash: The commit hash to checkout after cloning. + target_path: The path where the repository will be saved. + If not provided, a predefined directory is used. + include_to_path: If `True`, the cloned repository is added to the `sys.path`. + Defaults to `False`. + + Returns: + A Path object representing the full path to the cloned Git repository. + """ + + temp_dir = Path("/tmp") + base_repo_dir = Path(FAL_REPOSITORY_DIR) + if isinstance(target_path, str): + target_path = Path(target_path) + + repo_name = target_path.stem if target_path else Path(https_url).stem + + if commit_hash is None: + print( + "Warning: No commit hash provided. Attempting to fetch the latest" + " version of the repository from GitHub. This process may take time and" + " could result in unexpected changes. Please specify a commit hash to" + " ensure stability." + ) + + # Get the commit hash from the remote repository + commit_hash = subprocess.check_output( + ["git", "ls-remote", https_url, "HEAD"], + text=True, + stderr=subprocess.STDOUT, + ).split()[0] + if not commit_hash: + raise ValueError( + "Failed to get the commit hash from the remote repository." + ) + else: + # Convert mutable hash to immutable hash + result = subprocess.check_output( + ["git", "ls-remote", https_url, commit_hash], + text=True, + stderr=subprocess.STDOUT, + ) + + if result: + # This is mutable hash case + print( + "Warning: The provided Git reference is mutable (e.g., a branch or " + "tag). Please use an immutable commit hash to ensure reproducibility." + ) + commit_hash = result.split()[0] + + repo_name += f"-{commit_hash[:8]}" + + repo_hash = f"{_hash_url(https_url)}-{commit_hash}" + archive_path = base_repo_dir / (repo_hash + ".zip") + target_path = Path(target_path or temp_dir / repo_name) + + # Clean up the existing repository if it exists + if target_path.exists(): + with TemporaryDirectory( + dir=target_path.parent, suffix=f"{target_path.name}.tmp.old" + ) as tmp_dir: + with suppress(FileNotFoundError): + # repository might be already deleted by another worker + os.rename(target_path, tmp_dir) + # sometimes seeing FileNotFoundError even here on juicefs + shutil.rmtree(tmp_dir, ignore_errors=True) + + if archive_path.exists(): + print("Repository cache found, unpacking...") + + # Copy the archive to the temp directory + file_path = shutil.copyfile(archive_path, temp_dir / (repo_name + ".zip")) + + # Unpack and clean + shutil.unpack_archive(file_path, temp_dir / repo_name) + os.remove(file_path) + + if temp_dir.absolute() != target_path.parent.absolute(): + shutil.move(temp_dir / repo_name, target_path) + + else: + random_idx = random.randint(0, 9999999) + with TemporaryDirectory( + dir="/tmp", + suffix=f"{repo_name}.tmp{random_idx}", + ) as temp_repo_dir: + try: + print(f"Cloning the repository '{https_url}'.") + + # Clone with disabling the logs and advices for detached HEAD state. + clone_command = [ + "git", + "clone", + "--recursive", + https_url, + temp_repo_dir, + ] + subprocess.check_call(clone_command) + + if commit_hash: + checkout_command = ["git", "checkout", commit_hash] + subprocess.check_call(checkout_command, cwd=temp_repo_dir) + subprocess.check_call( + ["git", "submodule", "update", "--init", "--recursive"], + cwd=temp_repo_dir, + ) + + repo_zip_name = repo_hash + ".zip" + + file_name = shutil.make_archive( + repo_name, "zip", root_dir=temp_repo_dir + ) + os.rename(file_name, temp_dir / repo_zip_name) + + # We know that file_path is empty + os.makedirs(archive_path.parent, exist_ok=True) + shutil.move(temp_dir / repo_zip_name, archive_path) + + print(f"Repository is cached in {archive_path}") + + # NOTE: Atomically renaming the repository directory into place when the + # clone and checkout are done. + try: + shutil.move( + temp_repo_dir, + target_path.with_name(f"tmp_{random_idx}_" + target_path.name), + ) + os.rename( + target_path.with_name(f"tmp_{random_idx}_" + target_path.name), + target_path, + ) + except OSError as error: + shutil.rmtree(temp_repo_dir, ignore_errors=True) + + # someone beat us to it, assume it's good + if error.errno != errno.ENOTEMPTY: + raise + print(f"{target_path} already exists, skipping rename") + + except Exception as error: + print(f"{error}\nFailed to clone repository '{https_url}' .") + raise error + + if include_to_path: + __add_local_path_to_sys_path(target_path) + + return target_path + + def __add_local_path_to_sys_path(local_path: Path | str): local_path_str = str(local_path) diff --git a/projects/fal/tests/integration_test.py b/projects/fal/tests/integration_test.py index 321d1932..d68ac10a 100644 --- a/projects/fal/tests/integration_test.py +++ b/projects/fal/tests/integration_test.py @@ -16,6 +16,7 @@ from fal.toolkit import ( File, clone_repository, + clone_repository_cached, download_file, download_model_weights, ) @@ -347,24 +348,34 @@ def download_weights(): ), "The model weights should be redownloaded with force=True" -def test_clone_repository(isolated_client, mock_fal_persistent_dirs): - from fal.toolkit.utils.download_utils import FAL_REPOSITORY_DIR - +@pytest.mark.parametrize( + "clone_fn", + [ + clone_repository, + clone_repository_cached, + ], +) +def test_clone_repository(isolated_client, mock_fal_persistent_dirs, clone_fn): # https://github.com/fal-ai/isolate/tree/64b0a89c8391bd2cb3ca23cdeae01779e11aee05 EXAMPLE_REPO_URL = "https://github.com/fal-ai/isolate.git" EXAMPLE_REPO_FIRST_COMMIT = "64b0a89c8391bd2cb3ca23cdeae01779e11aee05" EXAMPLE_REPO_SECOND_COMMIT = "34ecbca8cc7b64719d2a5c40dd3272f8d13bc1d2" - expected_path = FAL_REPOSITORY_DIR / "isolate" - first_expected_path = ( - FAL_REPOSITORY_DIR / f"isolate-{EXAMPLE_REPO_FIRST_COMMIT[:8]}" - ) - second_expected_path = ( - FAL_REPOSITORY_DIR / f"isolate-{EXAMPLE_REPO_SECOND_COMMIT[:8]}" - ) + from fal.toolkit.utils.download_utils import FAL_REPOSITORY_DIR + + # clone_repository uses FAL_REPOSITORY_DIR, clone_repository_cached uses /tmp + if clone_fn == clone_repository: + base_dir = FAL_REPOSITORY_DIR + else: + base_dir = Path("/tmp") + + expected_path = str(base_dir / "isolate") + first_expected_path = str(base_dir / f"isolate-{EXAMPLE_REPO_FIRST_COMMIT[:8]}") + second_expected_path = str(base_dir / f"isolate-{EXAMPLE_REPO_SECOND_COMMIT[:8]}") + import os @isolated_client() def clone_without_commit_hash(): - repo_path = clone_repository(EXAMPLE_REPO_URL) + repo_path = clone_fn(EXAMPLE_REPO_URL) return repo_path @@ -373,14 +384,10 @@ def clone_without_commit_hash(): @isolated_client() def clone_with_commit_hash(): - first_path = clone_repository( - EXAMPLE_REPO_URL, commit_hash=EXAMPLE_REPO_FIRST_COMMIT - ) + first_path = clone_fn(EXAMPLE_REPO_URL, commit_hash=EXAMPLE_REPO_FIRST_COMMIT) first_repo_hash = _git_rev_parse(first_path, "HEAD") - second_path = clone_repository( - EXAMPLE_REPO_URL, commit_hash=EXAMPLE_REPO_SECOND_COMMIT - ) + second_path = clone_fn(EXAMPLE_REPO_URL, commit_hash=EXAMPLE_REPO_SECOND_COMMIT) second_repo_hash = _git_rev_parse(second_path, "HEAD") @@ -409,17 +416,17 @@ def clone_with_commit_hash(): @isolated_client() def clone_with_force(): - first_path = clone_repository( + first_path = clone_fn( EXAMPLE_REPO_URL, commit_hash=EXAMPLE_REPO_FIRST_COMMIT, force=False ) first_repo_stat = first_path.stat() - second_path = clone_repository( + second_path = clone_fn( EXAMPLE_REPO_URL, commit_hash=EXAMPLE_REPO_FIRST_COMMIT, force=False ) second_repo_stat = second_path.stat() - third_path = clone_repository( + third_path = clone_fn( EXAMPLE_REPO_URL, commit_hash=EXAMPLE_REPO_FIRST_COMMIT, force=True ) third_repo_stat = third_path.stat() @@ -433,32 +440,81 @@ def clone_with_force(): third_repo_stat, ) - ( - first_path, - first_repo_stat, - second_path, - second_repo_stat, - third_path, - third_repo_stat, - ) = clone_with_force() + # Only test force functionality for clone_repository + if clone_fn == clone_repository: + ( + first_path, + first_repo_stat, + second_path, + second_repo_stat, + third_path, + third_repo_stat, + ) = clone_with_force() + + assert str(first_expected_path) == str( + first_path + ), "Path should be the target location" + assert str(first_expected_path) == str( + second_path + ), "Path should be the target location" + assert str(first_expected_path) == str( + third_path + ), "Path should be the target location" + + assert ( + first_repo_stat.st_mtime == second_repo_stat.st_mtime + ), "The repository should not be cloned again" + + assert ( + first_repo_stat.st_mtime < third_repo_stat.st_mtime + ), "The repository should be cloned again with force=True" - assert str(first_expected_path) == str( - first_path - ), "Path should be the target location" - assert str(first_expected_path) == str( - second_path - ), "Path should be the target location" - assert str(first_expected_path) == str( - third_path - ), "Path should be the target location" + @isolated_client() + def clone_without_commit_hash_multiple_times(): + import shutil - assert ( - first_repo_stat.st_mtime == second_repo_stat.st_mtime - ), "The repository should not be cloned again" + # Clean FAL_REPOSITORY_DIR + shutil.rmtree(FAL_REPOSITORY_DIR, ignore_errors=True) - assert ( - first_repo_stat.st_mtime < third_repo_stat.st_mtime - ), "The repository should be cloned again with force=True" + repo_path = clone_fn(EXAMPLE_REPO_URL) + + repo_dir_contents = os.listdir(FAL_REPOSITORY_DIR) + + first_n_archives = len(repo_dir_contents) + first_archive_stat = Path(FAL_REPOSITORY_DIR / repo_dir_contents[0]).stat() + + repo_path_2 = clone_fn(EXAMPLE_REPO_URL) + + repo_dir_contents_2 = os.listdir(FAL_REPOSITORY_DIR) + + second_n_archives = len(repo_dir_contents_2) + second_archive_stat = Path(FAL_REPOSITORY_DIR / repo_dir_contents_2[0]).stat() + + return ( + repo_path, + repo_path_2, + first_n_archives, + second_n_archives, + first_archive_stat, + second_archive_stat, + ) + + ( + repo_path, + repo_path_2, + first_n_archives, + second_n_archives, + first_archive_stat, + second_archive_stat, + ) = clone_without_commit_hash_multiple_times() + + if clone_fn == clone_repository_cached: + assert first_n_archives == 1, "Only one archive should be present" + assert second_n_archives == 1, "Only one archive should be present" + assert ( + first_archive_stat.st_mtime == second_archive_stat.st_mtime + ), "The archive should be the same" + assert repo_path == repo_path_2, "The repository should be the same" def fal_file_downloaded(file: File):