Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
317 changes: 316 additions & 1 deletion projects/fal/src/fal/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@

import asyncio
import inspect
import multiprocessing
import os
import secrets
import shutil
import subprocess
import sys
import threading
import time
from collections import defaultdict
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import asynccontextmanager, suppress
Expand Down Expand Up @@ -84,6 +89,13 @@
BasicConfig = Dict[str, Any]
_UNSET = object()

DUMP_READY_PATH = "/tmp/criu_dump_ready"
RESTORE_READY_PATH = "/tmp/criu_restore_ready"
SNAPSHOT_ROOT_DIR = "/data/.fal/snapshots"
SNAPSHOT_IMAGES_DIR = f"{SNAPSHOT_ROOT_DIR}/images"
CRIU_BINARY = f"{SNAPSHOT_ROOT_DIR}/criu/criu/criu"
CRIU_PIDFILE_PATH = "/tmp/criu_pidfile"

SERVE_REQUIREMENTS = [
f"fastapi=={fastapi_version}",
f"pydantic=={pydantic_version}",
Expand Down Expand Up @@ -1208,6 +1220,8 @@ class RouteSignature(NamedTuple):

class BaseServable:
version: ClassVar[str] = "unknown"
snapshot: ClassVar[bool] = False
snapshot_key: ClassVar[str | None] = None

def collect_routes(self) -> dict[RouteSignature, Callable[..., Any]]:
raise NotImplementedError
Expand Down Expand Up @@ -1376,7 +1390,7 @@ def openapi(self) -> dict[str, Any]:
"Failed to generate OpenAPI metadata for function"
) from e

async def serve(self) -> None:
async def _serve_uvicorn(self) -> None:
from prometheus_client import Gauge
from starlette_exporter import handle_metrics

Expand Down Expand Up @@ -1415,6 +1429,307 @@ async def _serve() -> None:

await _serve()

async def serve(self) -> None:
if not self.snapshot:
await self._serve_uvicorn()
return

snapshot_key = self.snapshot_key or "default"
latest_dump_dir = _find_latest_snapshot_dir(snapshot_key)
if latest_dump_dir is not None:
await _restore_from_snapshot_dump(latest_dump_dir)
return

dummy_process = subprocess.Popen(["sleep", "1000000"])
ctx = _get_multiprocessing_context()
process = ctx.Process(target=_run_servable_in_subprocess, args=(self,))
process.start()
self._serve_process = process

await _wait_for_snapshot_dump_ready(
process,
dummy_process,
snapshot_key,
)

try:
while process.is_alive():
await asyncio.sleep(0.25)
except asyncio.CancelledError:
_terminate_process(process)
_terminate_subprocess(dummy_process)
raise
finally:
if process.is_alive():
_terminate_process(process)
_terminate_subprocess(dummy_process)

if process.exitcode not in (0, None):
raise FalServerlessException(
f"Uvicorn subprocess exited with code {process.exitcode}"
)


def _get_multiprocessing_context() -> multiprocessing.context.BaseContext:
start_methods = multiprocessing.get_all_start_methods()
if (
"fork" in start_methods
and threading.current_thread() is threading.main_thread()
):
return multiprocessing.get_context("fork")
return multiprocessing.get_context()


def _run_servable_in_subprocess(servable: BaseServable) -> None:
_install_snapshot_dependencies()
asyncio.run(servable._serve_uvicorn())


def _terminate_process(process: multiprocessing.Process) -> None:
process.terminate()
process.join(timeout=5)


def _terminate_subprocess(process: subprocess.Popen) -> None:
if process.poll() is None:
process.terminate()
try:
process.wait(timeout=5)
except subprocess.TimeoutExpired:
process.kill()
process.wait(timeout=5)


async def _wait_for_snapshot_dump_ready(
process: multiprocessing.Process,
dummy_process: subprocess.Popen,
snapshot_key: str,
) -> None:
while process.is_alive():
if os.path.exists(DUMP_READY_PATH):
print(
f"[snapshot] Detected {DUMP_READY_PATH}; creating {RESTORE_READY_PATH}",
flush=True,
)
temp_dump_dir = _build_snapshot_dir(snapshot_key) + ".tmp"
if os.path.exists(temp_dump_dir):
shutil.rmtree(temp_dump_dir)
os.makedirs(temp_dump_dir, exist_ok=True)
stdout_ino = os.fstat(1).st_ino
stderr_ino = os.fstat(2).st_ino
with open(
os.path.join(temp_dump_dir, "stdout_pipe_st_ino"), "w", encoding="utf-8"
) as stdout_file:
stdout_file.write(str(stdout_ino))
with open(
os.path.join(temp_dump_dir, "stderr_pipe_st_ino"), "w", encoding="utf-8"
) as stderr_file:
stderr_file.write(str(stderr_ino))
dump_start = time.monotonic()
subprocess.check_call(
[
CRIU_BINARY,
"dump",
"-t",
str(process.pid),
"-D",
temp_dump_dir,
"--libdir",
f"{SNAPSHOT_ROOT_DIR}/criu/plugins/cuda",
"--shell-job",
"--unprivileged",
"--tcp-close",
"--leave-running",
"--skip-in-flight",
],
env={
**os.environ,
"PATH": f"{SNAPSHOT_ROOT_DIR}:{os.environ.get('PATH', '')}",
},
)
dump_elapsed = time.monotonic() - dump_start
print(
f"[snapshot] Dump completed in {dump_elapsed:.2f}s.",
flush=True,
)
snapshot_dir = temp_dump_dir[:-4]
print(f"[snapshot] Saving snapshot to {snapshot_dir}", flush=True)
os.makedirs(os.path.dirname(snapshot_dir), exist_ok=True)
shutil.move(temp_dump_dir, snapshot_dir)
with open(RESTORE_READY_PATH, "w", encoding="utf-8"):
pass
_terminate_subprocess(dummy_process)
return
await asyncio.sleep(0.1)

if not os.path.exists(DUMP_READY_PATH):
print(
"[snapshot] Subprocess exited before dump ready signal.",
flush=True,
)
_terminate_subprocess(dummy_process)
raise FalServerlessException(
"Uvicorn subprocess exited before snapshot was ready"
)


async def _restore_from_snapshot_dump(dump_dir: str) -> None:
_install_snapshot_dependencies()
env = {
**os.environ,
"PATH": f"{SNAPSHOT_ROOT_DIR}:{os.environ.get('PATH', '')}",
}
if os.path.exists(CRIU_PIDFILE_PATH):
os.remove(CRIU_PIDFILE_PATH)
with open(RESTORE_READY_PATH, "w", encoding="utf-8"):
pass
print(
f"[snapshot] Restoring from existing snapshot at {dump_dir}.",
flush=True,
)
with open(
os.path.join(dump_dir, "stdout_pipe_st_ino"), encoding="utf-8"
) as stdout_file:
stdout_ino = stdout_file.read().strip()
with open(
os.path.join(dump_dir, "stderr_pipe_st_ino"), encoding="utf-8"
) as stderr_file:
stderr_ino = stderr_file.read().strip()
restore_start = time.monotonic()
try:
subprocess.check_call(
[
CRIU_BINARY,
"restore",
"-D",
dump_dir,
"--libdir",
f"{SNAPSHOT_ROOT_DIR}/criu/plugins/cuda",
"--shell-job",
"--unprivileged",
"--tcp-close",
"--leave-running",
"--restore-sibling",
"--restore-detached",
"--inherit-fd",
f"fd[1]:pipe:[{stdout_ino}]",
"--inherit-fd",
f"fd[2]:pipe:[{stderr_ino}]",
"--pidfile",
CRIU_PIDFILE_PATH,
],
env=env,
)
except subprocess.CalledProcessError:
restore_elapsed = time.monotonic() - restore_start
print(
f"[snapshot] Restore failed after {restore_elapsed:.2f}s.",
flush=True,
)
raise

restore_elapsed = time.monotonic() - restore_start
print(
f"[snapshot] Restore completed in {restore_elapsed:.2f}s.",
flush=True,
)
if not os.path.exists(CRIU_PIDFILE_PATH):
raise FalServerlessException(f"CRIU pidfile not found at {CRIU_PIDFILE_PATH}")
with open(CRIU_PIDFILE_PATH, encoding="utf-8") as pidfile:
restored_pid = int(pidfile.read().strip() or 0)
if restored_pid <= 0:
raise FalServerlessException("CRIU pidfile did not contain a valid PID")
await _wait_for_restored_pid(restored_pid)


async def _wait_for_restored_pid(pid: int) -> None:
while True:
if not _pid_is_alive(pid):
raise FalServerlessException(f"Restored process {pid} exited unexpectedly")
await asyncio.sleep(0.5)


def _pid_is_alive(pid: int) -> bool:
try:
os.kill(pid, 0)
except ProcessLookupError:
return False
except PermissionError:
return True
else:
return True


def _build_snapshot_dir(snapshot_key: str) -> str:
timestamp = time.strftime("%Y%m%d-%H%M%S")
rand_key = secrets.token_hex(4)
return os.path.join(SNAPSHOT_IMAGES_DIR, snapshot_key, f"{timestamp}-{rand_key}")


def _find_latest_snapshot_dir(snapshot_key: str) -> str | None:
base_dir = os.path.join(SNAPSHOT_IMAGES_DIR, snapshot_key)
if not os.path.isdir(base_dir):
return None
candidates = [
entry
for entry in os.listdir(base_dir)
if os.path.isdir(os.path.join(base_dir, entry)) and not entry.endswith(".tmp")
]
if not candidates:
return None
latest = sorted(candidates)[-1]
return os.path.join(base_dir, latest)


def _install_snapshot_dependencies() -> None:
subprocess.check_call(["apt-get", "update"])
subprocess.check_call(
[
"apt-get",
"install",
"-y",
"wget",
"libprotobuf-dev",
"libprotobuf-c-dev",
"protobuf-c-compiler",
"protobuf-compiler",
"python3-protobuf",
"libnl-3-dev",
"libnet-dev",
"libcap-dev",
]
)
snapshot_dir = SNAPSHOT_ROOT_DIR
os.makedirs(snapshot_dir, exist_ok=True)
cuda_checkpoint_path = os.path.join(snapshot_dir, "cuda-checkpoint")
if not os.path.exists(cuda_checkpoint_path):
subprocess.check_call(
[
"wget",
"https://github.com/NVIDIA/cuda-checkpoint/raw/refs/heads/main/bin/x86_64_Linux/cuda-checkpoint",
"-O",
cuda_checkpoint_path,
]
)
os.chmod(cuda_checkpoint_path, 0o755)
criu_dir = os.path.join(snapshot_dir, "criu")
if not os.path.exists(criu_dir):
temp_criu_dir = f"{criu_dir}.tmp-{secrets.token_hex(4)}"
subprocess.check_call(
[
"git",
"clone",
"--depth",
"1",
"--branch",
"ruslan/cuda-seccomp",
"https://github.com/efiop/criu",
temp_criu_dir,
]
)
subprocess.check_call(["make"], cwd=temp_criu_dir)
shutil.move(temp_criu_dir, criu_dir)


class ServeWrapper(BaseServable):
_func: Callable
Expand Down
Loading
Loading