Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions projects/fal/src/fal/toolkit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
FAL_PERSISTENT_DIR,
FAL_REPOSITORY_DIR,
clone_repository,
clone_repository_cached,
download_file,
download_model_weights,
)
Expand All @@ -31,6 +32,7 @@
"FAL_PERSISTENT_DIR",
"FAL_REPOSITORY_DIR",
"clone_repository",
"clone_repository_cached",
"download_file",
"download_model_weights",
]
138 changes: 133 additions & 5 deletions projects/fal/src/fal/toolkit/utils/download_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import errno
import hashlib
import os
import random
import shutil
import subprocess
import sys
Expand Down Expand Up @@ -400,13 +401,10 @@ def clone_repository(
) -> Path:
"""Clones a Git repository from the specified HTTPS URL into a local
directory.

This function clones a Git repository from the specified HTTPS URL into a local
directory. It can also checkout a specific commit if the `commit_hash` is provided.

If a custom `target_dir` or `repo_name` is not specified, a predefined directory is
used for the target directory, and the repository name is determined from the URL.

Args:
https_url: The HTTPS URL of the Git repository to be cloned.
commit_hash: The commit hash to checkout after cloning.
Expand All @@ -418,7 +416,6 @@ def clone_repository(
and its commit hash matches the provided commit hash. Defaults to `False`.
include_to_path: If `True`, the cloned repository is added to the `sys.path`.
Defaults to `False`.

Returns:
A Path object representing the full path to the cloned Git repository.
"""
Expand Down Expand Up @@ -505,10 +502,141 @@ def clone_repository(

if include_to_path:
__add_local_path_to_sys_path(local_repo_path)

return local_repo_path


# TODO: rename
def clone_repository_cached(
https_url: str,
*,
commit_hash: str | None = None,
target_path: str | Path | None = None,
include_to_path: bool = False,
) -> Path:
"""Clones a Git repository from the specified HTTPS URL into a local
directory.

This function clones a Git repository from the specified HTTPS URL into a local
directory. It can also checkout a specific commit if the `commit_hash` is provided.

If a custom `target_path` is not specified, a predefined directory is
used for the target directory, and the repository name is determined from the URL.

Args:
https_url: The HTTPS URL of the Git repository to be cloned.
commit_hash: The commit hash to checkout after cloning.
target_path: The path where the repository will be saved.
If not provided, a predefined directory is used.
include_to_path: If `True`, the cloned repository is added to the `sys.path`.
Defaults to `False`.

Returns:
A Path object representing the full path to the cloned Git repository.
"""

temp_dir = Path("/tmp")
base_repo_dir = Path(FAL_REPOSITORY_DIR)
if isinstance(target_path, str):
target_path = Path(target_path)

repo_name = target_path.stem if target_path else Path(https_url).stem
if commit_hash:
repo_name += f"-{commit_hash[:8]}"

commit_hash = commit_hash or "main"
repo_hash = f"{_hash_url(https_url)}-{commit_hash}"
archive_path = base_repo_dir / (repo_hash + ".zip")
target_path = Path(target_path or temp_dir / repo_name)

# Clean up the existing repository if it exists
if target_path.exists():
with TemporaryDirectory(
dir=target_path.parent, suffix=f"{target_path.name}.tmp.old"
) as tmp_dir:
with suppress(FileNotFoundError):
# repository might be already deleted by another worker
os.rename(target_path, tmp_dir)
# sometimes seeing FileNotFoundError even here on juicefs
shutil.rmtree(tmp_dir, ignore_errors=True)

if archive_path.exists():
print("Cached repository found, unpacking...")

# Copy the archive to the temp directory
file_path = shutil.copyfile(archive_path, temp_dir / (repo_name + ".zip"))

# Unpack and clean
shutil.unpack_archive(file_path, temp_dir / repo_name)
os.remove(file_path)

if temp_dir.absolute() != target_path.parent.absolute():
shutil.move(temp_dir / repo_name, target_path)

else:
# NOTE: using the target_dir to be able to avoid potential copies across temp fs
# and target fs, and also to be able to atomically rename repo_name dir into
# place when we are done setting it up.
# os.makedirs(target_dir, exist_ok=True) # type: ignore[arg-type]
with TemporaryDirectory(
dir="/tmp",
suffix=f"{repo_name}.tmp{random.randint(0, 1000000)}",
) as temp_repo_dir:
try:
print(f"Cloning the repository '{https_url}'.")

# Clone with disabling the logs and advices for detached HEAD state.
clone_command = [
"git",
"clone",
"--recursive",
https_url,
temp_repo_dir,
]
subprocess.check_call(clone_command)

if commit_hash:
checkout_command = ["git", "checkout", commit_hash]
subprocess.check_call(checkout_command, cwd=temp_repo_dir)
subprocess.check_call(
["git", "submodule", "update", "--init", "--recursive"],
cwd=temp_repo_dir,
)

repo_zip_name = repo_hash + ".zip"

file_name = shutil.make_archive(
repo_name, "zip", root_dir=temp_repo_dir
)
os.rename(file_name, temp_dir / repo_zip_name)

# We know that file_path is empty
os.makedirs(archive_path.parent, exist_ok=True)
shutil.move(temp_dir / repo_zip_name, archive_path)

print(f"Repository is cached in {archive_path}")

# NOTE: Atomically renaming the repository directory into place when the
# clone and checkout are done.
try:
shutil.move(temp_repo_dir, target_path)
except OSError as error:
shutil.rmtree(temp_dir, ignore_errors=True)

# someone beat us to it, assume it's good
if error.errno != errno.ENOTEMPTY:
raise
print(f"{target_path} already exists, skipping rename")

except Exception as error:
print(f"{error}\nFailed to clone repository '{https_url}' .")
raise error

if include_to_path:
__add_local_path_to_sys_path(target_path)

return target_path


def __add_local_path_to_sys_path(local_path: Path | str):
local_path_str = str(local_path)

Expand Down
101 changes: 55 additions & 46 deletions projects/fal/tests/integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from fal.toolkit import (
File,
clone_repository,
clone_repository_cached,
download_file,
download_model_weights,
)
Expand Down Expand Up @@ -347,24 +348,34 @@ def download_weights():
), "The model weights should be redownloaded with force=True"


def test_clone_repository(isolated_client, mock_fal_persistent_dirs):
from fal.toolkit.utils.download_utils import FAL_REPOSITORY_DIR

@pytest.mark.parametrize(
"clone_fn",
[
clone_repository,
clone_repository_cached,
],
)
def test_clone_repository(isolated_client, mock_fal_persistent_dirs, clone_fn):
# https://github.com/fal-ai/isolate/tree/64b0a89c8391bd2cb3ca23cdeae01779e11aee05
EXAMPLE_REPO_URL = "https://github.com/fal-ai/isolate.git"
EXAMPLE_REPO_FIRST_COMMIT = "64b0a89c8391bd2cb3ca23cdeae01779e11aee05"
EXAMPLE_REPO_SECOND_COMMIT = "34ecbca8cc7b64719d2a5c40dd3272f8d13bc1d2"
expected_path = FAL_REPOSITORY_DIR / "isolate"
first_expected_path = (
FAL_REPOSITORY_DIR / f"isolate-{EXAMPLE_REPO_FIRST_COMMIT[:8]}"
)
second_expected_path = (
FAL_REPOSITORY_DIR / f"isolate-{EXAMPLE_REPO_SECOND_COMMIT[:8]}"
)

# clone_repository uses FAL_REPOSITORY_DIR, clone_repository_cached uses /tmp
if clone_fn == clone_repository:
from fal.toolkit.utils.download_utils import FAL_REPOSITORY_DIR

base_dir = FAL_REPOSITORY_DIR
else:
base_dir = Path("/tmp")

expected_path = str(base_dir / "isolate")
first_expected_path = str(base_dir / f"isolate-{EXAMPLE_REPO_FIRST_COMMIT[:8]}")
second_expected_path = str(base_dir / f"isolate-{EXAMPLE_REPO_SECOND_COMMIT[:8]}")

@isolated_client()
def clone_without_commit_hash():
repo_path = clone_repository(EXAMPLE_REPO_URL)
repo_path = clone_fn(EXAMPLE_REPO_URL)

return repo_path

Expand All @@ -373,14 +384,10 @@ def clone_without_commit_hash():

@isolated_client()
def clone_with_commit_hash():
first_path = clone_repository(
EXAMPLE_REPO_URL, commit_hash=EXAMPLE_REPO_FIRST_COMMIT
)
first_path = clone_fn(EXAMPLE_REPO_URL, commit_hash=EXAMPLE_REPO_FIRST_COMMIT)
first_repo_hash = _get_git_revision_hash(first_path)

second_path = clone_repository(
EXAMPLE_REPO_URL, commit_hash=EXAMPLE_REPO_SECOND_COMMIT
)
second_path = clone_fn(EXAMPLE_REPO_URL, commit_hash=EXAMPLE_REPO_SECOND_COMMIT)

second_repo_hash = _get_git_revision_hash(second_path)

Expand Down Expand Up @@ -409,17 +416,17 @@ def clone_with_commit_hash():

@isolated_client()
def clone_with_force():
first_path = clone_repository(
first_path = clone_fn(
EXAMPLE_REPO_URL, commit_hash=EXAMPLE_REPO_FIRST_COMMIT, force=False
)
first_repo_stat = first_path.stat()

second_path = clone_repository(
second_path = clone_fn(
EXAMPLE_REPO_URL, commit_hash=EXAMPLE_REPO_FIRST_COMMIT, force=False
)
second_repo_stat = second_path.stat()

third_path = clone_repository(
third_path = clone_fn(
EXAMPLE_REPO_URL, commit_hash=EXAMPLE_REPO_FIRST_COMMIT, force=True
)
third_repo_stat = third_path.stat()
Expand All @@ -433,32 +440,34 @@ def clone_with_force():
third_repo_stat,
)

(
first_path,
first_repo_stat,
second_path,
second_repo_stat,
third_path,
third_repo_stat,
) = clone_with_force()

assert str(first_expected_path) == str(
first_path
), "Path should be the target location"
assert str(first_expected_path) == str(
second_path
), "Path should be the target location"
assert str(first_expected_path) == str(
third_path
), "Path should be the target location"

assert (
first_repo_stat.st_mtime == second_repo_stat.st_mtime
), "The repository should not be cloned again"

assert (
first_repo_stat.st_mtime < third_repo_stat.st_mtime
), "The repository should be cloned again with force=True"
# Only test force functionality for clone_repository
if clone_fn == clone_repository:
(
first_path,
first_repo_stat,
second_path,
second_repo_stat,
third_path,
third_repo_stat,
) = clone_with_force()

assert str(first_expected_path) == str(
first_path
), "Path should be the target location"
assert str(first_expected_path) == str(
second_path
), "Path should be the target location"
assert str(first_expected_path) == str(
third_path
), "Path should be the target location"

assert (
first_repo_stat.st_mtime == second_repo_stat.st_mtime
), "The repository should not be cloned again"

assert (
first_repo_stat.st_mtime < third_repo_stat.st_mtime
), "The repository should be cloned again with force=True"


def fal_file_downloaded(file: File):
Expand Down
Loading