Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions projects/fal/src/fal/toolkit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
FAL_PERSISTENT_DIR,
FAL_REPOSITORY_DIR,
clone_repository,
clone_repository_cached,
download_file,
download_model_weights,
)
Expand All @@ -31,6 +32,7 @@
"FAL_PERSISTENT_DIR",
"FAL_REPOSITORY_DIR",
"clone_repository",
"clone_repository_cached",
"download_file",
"download_model_weights",
]
170 changes: 170 additions & 0 deletions projects/fal/src/fal/toolkit/utils/download_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import errno
import hashlib
import os
import random
import shutil
import subprocess
import sys
Expand Down Expand Up @@ -518,6 +519,175 @@ def clone_repository(
return local_repo_path


# TODO(review): rename — "cached" describes the mechanism, not the contract;
# consider a name that reflects the archive-backed behavior.
def clone_repository_cached(
    https_url: str,
    *,
    commit_hash: str | None = None,
    target_path: str | Path | None = None,
    include_to_path: bool = False,
) -> Path:
    """Clone a Git repository over HTTPS, backed by a zip-archive cache.

    Unlike a plain clone, this function keeps a zip archive of each
    (URL, commit) pair under ``FAL_REPOSITORY_DIR`` and unpacks that archive
    on subsequent calls instead of re-cloning. The working copy itself is
    placed under ``/tmp`` (or at ``target_path`` if given); only the archive
    lives in the persistent repository directory.

    If ``commit_hash`` is omitted, the remote ``HEAD`` is resolved with
    ``git ls-remote`` and used (a warning is printed, since the result is
    not reproducible). If ``commit_hash`` is a mutable ref (branch or tag
    name), it is resolved to the underlying commit hash, with a warning.

    Args:
        https_url: The HTTPS URL of the Git repository to be cloned.
        commit_hash: The commit hash (or mutable ref, which is resolved to
            a hash) to checkout after cloning. If ``None``, remote ``HEAD``
            is used.
        target_path: The path where the working copy will be placed.
            If not provided, ``/tmp/<repo-stem>-<hash8>`` is used.
        include_to_path: If `True`, the cloned repository is added to the
            `sys.path`. Defaults to `False`.

    Returns:
        A Path object representing the full path to the cloned Git repository.

    Raises:
        ValueError: If the remote HEAD commit hash cannot be determined.
        subprocess.CalledProcessError: If an underlying git command fails.
    """

    # Working copies go under /tmp; only the zip archives are persisted.
    temp_dir = Path("/tmp")
    base_repo_dir = Path(FAL_REPOSITORY_DIR)
    if isinstance(target_path, str):
        target_path = Path(target_path)

    # Derive the repo name from the explicit target (if any), otherwise from
    # the URL (e.g. ".../isolate.git" -> "isolate").
    repo_name = target_path.stem if target_path else Path(https_url).stem

    if commit_hash is None:
        print(
            "Warning: No commit hash provided. Attempting to fetch the latest"
            " version of the repository from GitHub. This process may take time and"
            " could result in unexpected changes. Please specify a commit hash to"
            " ensure stability."
        )

        # Get the commit hash from the remote repository
        # ("git ls-remote <url> HEAD" prints "<hash>\tHEAD").
        commit_hash = subprocess.check_output(
            ["git", "ls-remote", https_url, "HEAD"],
            text=True,
            stderr=subprocess.STDOUT,
        ).split()[0]
        # NOTE(review): .split()[0] raises IndexError on empty output before
        # this check can run, so this ValueError looks unreachable — confirm.
        if not commit_hash:
            raise ValueError(
                "Failed to get the commit hash from the remote repository."
            )
    else:
        # Resolve a possibly mutable ref (branch/tag) to an immutable commit
        # hash; ls-remote prints nothing when the argument is already a hash.
        result = subprocess.check_output(
            ["git", "ls-remote", https_url, commit_hash],
            text=True,
            stderr=subprocess.STDOUT,
        )

        if result:
            # Non-empty output means the caller passed a ref name, not a hash.
            print(
                "Warning: The provided Git reference is mutable (e.g., a branch or "
                "tag). Please use an immutable commit hash to ensure reproducibility."
            )
            commit_hash = result.split()[0]

    # Suffix the working-copy name with the short hash so different commits of
    # the same repository get distinct directories.
    repo_name += f"-{commit_hash[:8]}"

    # Cache key: archive name is unique per (URL, full commit hash) pair.
    repo_hash = f"{_hash_url(https_url)}-{commit_hash}"
    archive_path = base_repo_dir / (repo_hash + ".zip")
    target_path = Path(target_path or temp_dir / repo_name)

    # Clean up the existing repository if it exists. The directory is first
    # renamed into a temporary sibling (an atomic operation) and then removed,
    # so concurrent readers never observe a half-deleted tree.
    if target_path.exists():
        with TemporaryDirectory(
            dir=target_path.parent, suffix=f"{target_path.name}.tmp.old"
        ) as tmp_dir:
            with suppress(FileNotFoundError):
                # repository might be already deleted by another worker
                os.rename(target_path, tmp_dir)
            # sometimes seeing FileNotFoundError even here on juicefs
            shutil.rmtree(tmp_dir, ignore_errors=True)

    if archive_path.exists():
        # Cache hit: unpack the stored archive instead of cloning.
        print("Repository cache found, unpacking...")

        # Copy the archive out of the (possibly slow/networked) persistent
        # directory before unpacking it locally.
        file_path = shutil.copyfile(archive_path, temp_dir / (repo_name + ".zip"))

        # Unpack into /tmp/<repo_name>, then drop the local archive copy.
        shutil.unpack_archive(file_path, temp_dir / repo_name)
        os.remove(file_path)

        # NOTE(review): when target_path.parent IS /tmp but target_path.name
        # differs from repo_name, no move happens and the working copy is left
        # at /tmp/<repo_name> instead of target_path — verify this case.
        if temp_dir.absolute() != target_path.parent.absolute():
            shutil.move(temp_dir / repo_name, target_path)

    else:
        # Cache miss: clone into a throwaway directory, archive it into the
        # persistent cache, then move the clone into place atomically.
        random_idx = random.randint(0, 9999999)
        with TemporaryDirectory(
            dir="/tmp",
            suffix=f"{repo_name}.tmp{random_idx}",
        ) as temp_repo_dir:
            try:
                print(f"Cloning the repository '{https_url}'.")

                # Full recursive clone (submodules included); the later
                # checkout may leave the tree in detached-HEAD state.
                clone_command = [
                    "git",
                    "clone",
                    "--recursive",
                    https_url,
                    temp_repo_dir,
                ]
                subprocess.check_call(clone_command)

                # commit_hash is always set by this point (resolved above).
                if commit_hash:
                    checkout_command = ["git", "checkout", commit_hash]
                    subprocess.check_call(checkout_command, cwd=temp_repo_dir)
                    # Re-sync submodules to the revisions pinned by the
                    # checked-out commit.
                    subprocess.check_call(
                        ["git", "submodule", "update", "--init", "--recursive"],
                        cwd=temp_repo_dir,
                    )

                repo_zip_name = repo_hash + ".zip"

                # NOTE(review): make_archive writes "<repo_name>.zip" into the
                # current working directory, not /tmp — confirm CWD is always
                # writable here; the archive is renamed into /tmp right after.
                file_name = shutil.make_archive(
                    repo_name, "zip", root_dir=temp_repo_dir
                )
                os.rename(file_name, temp_dir / repo_zip_name)

                # Publish the archive into the persistent cache directory
                # (created on demand); shutil.move works across filesystems.
                os.makedirs(archive_path.parent, exist_ok=True)
                shutil.move(temp_dir / repo_zip_name, archive_path)

                print(f"Repository is cached in {archive_path}")

                # NOTE: Atomically renaming the repository directory into place when the
                # clone and checkout are done. The two-step move+rename keeps
                # the final rename on the same filesystem as target_path.
                try:
                    shutil.move(
                        temp_repo_dir,
                        target_path.with_name(f"tmp_{random_idx}_" + target_path.name),
                    )
                    os.rename(
                        target_path.with_name(f"tmp_{random_idx}_" + target_path.name),
                        target_path,
                    )
                except OSError as error:
                    shutil.rmtree(temp_repo_dir, ignore_errors=True)

                    # someone beat us to it, assume it's good
                    if error.errno != errno.ENOTEMPTY:
                        raise
                    print(f"{target_path} already exists, skipping rename")

            except Exception as error:
                print(f"{error}\nFailed to clone repository '{https_url}' .")
                raise error

    if include_to_path:
        __add_local_path_to_sys_path(target_path)

    return target_path


def __add_local_path_to_sys_path(local_path: Path | str):
local_path_str = str(local_path)

Expand Down
142 changes: 99 additions & 43 deletions projects/fal/tests/integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from fal.toolkit import (
File,
clone_repository,
clone_repository_cached,
download_file,
download_model_weights,
)
Expand Down Expand Up @@ -347,24 +348,34 @@ def download_weights():
), "The model weights should be redownloaded with force=True"


def test_clone_repository(isolated_client, mock_fal_persistent_dirs):
from fal.toolkit.utils.download_utils import FAL_REPOSITORY_DIR

@pytest.mark.parametrize(
"clone_fn",
[
clone_repository,
clone_repository_cached,
],
)
def test_clone_repository(isolated_client, mock_fal_persistent_dirs, clone_fn):
# https://github.com/fal-ai/isolate/tree/64b0a89c8391bd2cb3ca23cdeae01779e11aee05
EXAMPLE_REPO_URL = "https://github.com/fal-ai/isolate.git"
EXAMPLE_REPO_FIRST_COMMIT = "64b0a89c8391bd2cb3ca23cdeae01779e11aee05"
EXAMPLE_REPO_SECOND_COMMIT = "34ecbca8cc7b64719d2a5c40dd3272f8d13bc1d2"
expected_path = FAL_REPOSITORY_DIR / "isolate"
first_expected_path = (
FAL_REPOSITORY_DIR / f"isolate-{EXAMPLE_REPO_FIRST_COMMIT[:8]}"
)
second_expected_path = (
FAL_REPOSITORY_DIR / f"isolate-{EXAMPLE_REPO_SECOND_COMMIT[:8]}"
)
from fal.toolkit.utils.download_utils import FAL_REPOSITORY_DIR

# clone_repository uses FAL_REPOSITORY_DIR, clone_repository_cached uses /tmp
if clone_fn == clone_repository:
base_dir = FAL_REPOSITORY_DIR
else:
base_dir = Path("/tmp")

expected_path = str(base_dir / "isolate")
first_expected_path = str(base_dir / f"isolate-{EXAMPLE_REPO_FIRST_COMMIT[:8]}")
second_expected_path = str(base_dir / f"isolate-{EXAMPLE_REPO_SECOND_COMMIT[:8]}")
import os

@isolated_client()
def clone_without_commit_hash():
repo_path = clone_repository(EXAMPLE_REPO_URL)
repo_path = clone_fn(EXAMPLE_REPO_URL)

return repo_path

Expand All @@ -373,14 +384,10 @@ def clone_without_commit_hash():

@isolated_client()
def clone_with_commit_hash():
first_path = clone_repository(
EXAMPLE_REPO_URL, commit_hash=EXAMPLE_REPO_FIRST_COMMIT
)
first_path = clone_fn(EXAMPLE_REPO_URL, commit_hash=EXAMPLE_REPO_FIRST_COMMIT)
first_repo_hash = _git_rev_parse(first_path, "HEAD")

second_path = clone_repository(
EXAMPLE_REPO_URL, commit_hash=EXAMPLE_REPO_SECOND_COMMIT
)
second_path = clone_fn(EXAMPLE_REPO_URL, commit_hash=EXAMPLE_REPO_SECOND_COMMIT)

second_repo_hash = _git_rev_parse(second_path, "HEAD")

Expand Down Expand Up @@ -409,17 +416,17 @@ def clone_with_commit_hash():

@isolated_client()
def clone_with_force():
first_path = clone_repository(
first_path = clone_fn(
EXAMPLE_REPO_URL, commit_hash=EXAMPLE_REPO_FIRST_COMMIT, force=False
)
first_repo_stat = first_path.stat()

second_path = clone_repository(
second_path = clone_fn(
EXAMPLE_REPO_URL, commit_hash=EXAMPLE_REPO_FIRST_COMMIT, force=False
)
second_repo_stat = second_path.stat()

third_path = clone_repository(
third_path = clone_fn(
EXAMPLE_REPO_URL, commit_hash=EXAMPLE_REPO_FIRST_COMMIT, force=True
)
third_repo_stat = third_path.stat()
Expand All @@ -433,32 +440,81 @@ def clone_with_force():
third_repo_stat,
)

(
first_path,
first_repo_stat,
second_path,
second_repo_stat,
third_path,
third_repo_stat,
) = clone_with_force()
# Only test force functionality for clone_repository
if clone_fn == clone_repository:
(
first_path,
first_repo_stat,
second_path,
second_repo_stat,
third_path,
third_repo_stat,
) = clone_with_force()

assert str(first_expected_path) == str(
first_path
), "Path should be the target location"
assert str(first_expected_path) == str(
second_path
), "Path should be the target location"
assert str(first_expected_path) == str(
third_path
), "Path should be the target location"

assert (
first_repo_stat.st_mtime == second_repo_stat.st_mtime
), "The repository should not be cloned again"

assert (
first_repo_stat.st_mtime < third_repo_stat.st_mtime
), "The repository should be cloned again with force=True"

assert str(first_expected_path) == str(
first_path
), "Path should be the target location"
assert str(first_expected_path) == str(
second_path
), "Path should be the target location"
assert str(first_expected_path) == str(
third_path
), "Path should be the target location"
@isolated_client()
def clone_without_commit_hash_multiple_times():
import shutil

assert (
first_repo_stat.st_mtime == second_repo_stat.st_mtime
), "The repository should not be cloned again"
# Clean FAL_REPOSITORY_DIR
shutil.rmtree(FAL_REPOSITORY_DIR, ignore_errors=True)

assert (
first_repo_stat.st_mtime < third_repo_stat.st_mtime
), "The repository should be cloned again with force=True"
repo_path = clone_fn(EXAMPLE_REPO_URL)

repo_dir_contents = os.listdir(FAL_REPOSITORY_DIR)

first_n_archives = len(repo_dir_contents)
first_archive_stat = Path(FAL_REPOSITORY_DIR / repo_dir_contents[0]).stat()

repo_path_2 = clone_fn(EXAMPLE_REPO_URL)

repo_dir_contents_2 = os.listdir(FAL_REPOSITORY_DIR)

second_n_archives = len(repo_dir_contents_2)
second_archive_stat = Path(FAL_REPOSITORY_DIR / repo_dir_contents_2[0]).stat()

return (
repo_path,
repo_path_2,
first_n_archives,
second_n_archives,
first_archive_stat,
second_archive_stat,
)

(
repo_path,
repo_path_2,
first_n_archives,
second_n_archives,
first_archive_stat,
second_archive_stat,
) = clone_without_commit_hash_multiple_times()

if clone_fn == clone_repository_cached:
assert first_n_archives == 1, "Only one archive should be present"
assert second_n_archives == 1, "Only one archive should be present"
assert (
first_archive_stat.st_mtime == second_archive_stat.st_mtime
), "The archive should be the same"
assert repo_path == repo_path_2, "The repository should be the same"


def fal_file_downloaded(file: File):
Expand Down
Loading