Skip to content

Commit 040c1f2

Browse files
authored
[iris] Compress checkpoints with zstd, prune old ones, improve restart UX (#4143)
Checkpoint sqlite3 files are now compressed with zstandard (level 3) before upload to remote storage. Downloads prefer .zst but fall back to uncompressed files for backward compatibility with checkpoints written before this change. Checkpoints older than 3 days are pruned best-effort after each write. The controller restart command gains --skip-checkpoint and --checkpoint-timeout flags (the timeout defaults to 5 minutes; it was previously hardcoded to 60 seconds), with progress feedback printed before the RPC call.
1 parent b140fd2 commit 040c1f2

5 files changed

Lines changed: 250 additions & 101 deletions

File tree

lib/iris/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ dependencies = [
2525
"uvicorn[standard]>=0.23.0",
2626
"duckdb>=1.0.0",
2727
"pyarrow>=19.0.0",
28+
"zstandard>=0.22.0",
2829
]
2930

3031
[project.optional-dependencies]

lib/iris/src/iris/cli/cluster.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -609,8 +609,17 @@ def controller_checkpoint(ctx, stop: bool):
609609

610610

611611
@controller.command("restart")
612+
@click.option(
613+
"--skip-checkpoint",
614+
is_flag=True,
615+
default=False,
616+
help="Skip the pre-restart checkpoint (use if checkpoint is timing out).",
617+
)
618+
@click.option(
619+
"--checkpoint-timeout", type=int, default=300, show_default=True, help="Checkpoint RPC timeout in seconds."
620+
)
612621
@click.pass_context
613-
def controller_restart(ctx):
622+
def controller_restart(ctx, skip_checkpoint: bool, checkpoint_timeout: int):
614623
"""Restart controller with state preservation (remote platforms only).
615624
616625
Takes a checkpoint, builds fresh images, stops the controller, and starts
@@ -653,15 +662,22 @@ def controller_restart(ctx):
653662
return
654663

655664
# Checkpoint
656-
client = cluster_connect.ControllerServiceClientSync(controller_url)
657-
try:
658-
resp = client.begin_checkpoint(cluster_pb2.Controller.BeginCheckpointRequest(), timeout_ms=60_000)
659-
except Exception as e:
660-
click.echo(f"Checkpoint failed: {e}", err=True)
661-
raise SystemExit(1) from e
662-
finally:
663-
client.close()
664-
click.echo(f"Checkpoint: {resp.checkpoint_path} ({resp.job_count} jobs, {resp.worker_count} workers)")
665+
if skip_checkpoint:
666+
click.echo("Skipping pre-restart checkpoint.")
667+
else:
668+
click.echo(f"Taking checkpoint (timeout {checkpoint_timeout}s)...")
669+
client = cluster_connect.ControllerServiceClientSync(controller_url)
670+
try:
671+
resp = client.begin_checkpoint(
672+
cluster_pb2.Controller.BeginCheckpointRequest(),
673+
timeout_ms=checkpoint_timeout * 1000,
674+
)
675+
except Exception as e:
676+
click.echo(f"Checkpoint failed: {e}", err=True)
677+
raise SystemExit(1) from e
678+
finally:
679+
client.close()
680+
click.echo(f"Checkpoint: {resp.checkpoint_path} ({resp.job_count} jobs, {resp.worker_count} workers)")
665681

666682
# Build fresh images so the new controller VM gets the latest code
667683
_pin_latest_images(config)

lib/iris/src/iris/cluster/controller/checkpoint.py

Lines changed: 153 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,15 @@
77
to remote storage and restoring the DB file from a remote checkpoint.
88
99
Checkpoint layout (remote):
10-
{remote_state_dir}/controller-state/{epoch_ms}/controller.sqlite3
11-
{remote_state_dir}/controller-state/{epoch_ms}/auth.sqlite3
10+
{remote_state_dir}/controller-state/{epoch_ms}/controller.sqlite3.zst
11+
{remote_state_dir}/controller-state/{epoch_ms}/auth.sqlite3.zst
12+
13+
Files are compressed with zstandard (level 3) before upload. On download,
14+
compressed (.zst) files are preferred; uncompressed files are accepted as
15+
a fallback for checkpoints written before compression was added.
1216
1317
Restore locates the most recent timestamped directory by listing, or
14-
uses an explicit checkpoint directory path. The "latest" alias convention
15-
has been removed.
18+
uses an explicit checkpoint directory path.
1619
1720
Autoscaler/scaling-group reconciliation lives in autoscaler.py and
1821
scaling_group.py respectively.
@@ -28,12 +31,35 @@
2831
from pathlib import Path
2932

3033
import fsspec.core
34+
import zstandard
3135

3236
from iris.cluster.controller.db import JOBS, TASKS, WORKERS, ControllerDB
33-
from iris.time_utils import Timestamp
37+
from iris.time_utils import Duration, Timestamp
3438

3539
logger = logging.getLogger(__name__)
3640

41+
ZSTD_LEVEL = 3
42+
DEFAULT_PRUNE_AGE = Duration.from_hours(3 * 24) # 3 days
43+
44+
45+
# ---------------------------------------------------------------------------
46+
# Compression helpers
47+
# ---------------------------------------------------------------------------
48+
49+
50+
def _compress_zstd(src: Path, dst: Path) -> None:
51+
"""Compress *src* to *dst* using zstandard at ``ZSTD_LEVEL``."""
52+
cctx = zstandard.ZstdCompressor(level=ZSTD_LEVEL)
53+
with open(src, "rb") as f_in, open(dst, "wb") as f_out:
54+
cctx.copy_stream(f_in, f_out)
55+
56+
57+
def _decompress_zstd(src: Path, dst: Path) -> None:
58+
"""Decompress a zstd-compressed *src* to *dst*."""
59+
dctx = zstandard.ZstdDecompressor()
60+
with open(src, "rb") as f_in, open(dst, "wb") as f_out:
61+
dctx.copy_stream(f_in, f_out)
62+
3763

3864
# ---------------------------------------------------------------------------
3965
# Checkpoint result
@@ -77,45 +103,52 @@ def write_checkpoint(
77103
db: ControllerDB,
78104
remote_state_dir: str,
79105
) -> tuple[str, CheckpointResult]:
80-
"""Write a timestamped SQLite checkpoint to a remote directory.
106+
"""Write a timestamped, zstd-compressed SQLite checkpoint to remote storage.
81107
82108
Layout:
83-
{remote_state_dir}/controller-state/{epoch_ms}/controller.sqlite3
84-
{remote_state_dir}/controller-state/{epoch_ms}/auth.sqlite3
109+
{remote_state_dir}/controller-state/{epoch_ms}/controller.sqlite3.zst
110+
{remote_state_dir}/controller-state/{epoch_ms}/auth.sqlite3.zst
85111
112+
Old checkpoints (> 3 days) are pruned best-effort after the write.
86113
Returns the remote directory path and a summary of checkpoint contents.
87114
"""
88115
created_at = Timestamp.now()
89116
prefix = remote_state_dir.rstrip("/") + "/controller-state"
90117
checkpoint_dir = f"{prefix}/{created_at.epoch_ms()}"
91118

92-
# Backup main DB
93-
main_remote = f"{checkpoint_dir}/{ControllerDB.DB_FILENAME}"
119+
# Backup main DB (compressed)
120+
main_remote = f"{checkpoint_dir}/{ControllerDB.DB_FILENAME}.zst"
94121
tmp_dir = db.db_path.parent
95122
tmp_dir.mkdir(parents=True, exist_ok=True)
96123
fd, tmp_name = tempfile.mkstemp(suffix=".sqlite3", dir=tmp_dir)
97124
os.close(fd)
98125
tmp_path = Path(tmp_name)
126+
tmp_zst = tmp_path.with_suffix(".sqlite3.zst")
99127
try:
100128
db.backup_to(tmp_path)
101-
_fsspec_copy(str(tmp_path), main_remote)
129+
_compress_zstd(tmp_path, tmp_zst)
130+
_fsspec_copy(str(tmp_zst), main_remote)
102131
logger.info("checkpoint main DB uploaded to %s", main_remote)
103132
finally:
104133
tmp_path.unlink(missing_ok=True)
134+
tmp_zst.unlink(missing_ok=True)
105135

106-
# Backup auth DB
136+
# Backup auth DB (compressed)
107137
auth_path = db.auth_db_path
108138
if auth_path.exists():
109-
auth_remote = f"{checkpoint_dir}/{ControllerDB.AUTH_DB_FILENAME}"
139+
auth_remote = f"{checkpoint_dir}/{ControllerDB.AUTH_DB_FILENAME}.zst"
110140
fd2, tmp_name2 = tempfile.mkstemp(suffix=".sqlite3", dir=tmp_dir)
111141
os.close(fd2)
112142
tmp_path2 = Path(tmp_name2)
143+
tmp_zst2 = tmp_path2.with_suffix(".sqlite3.zst")
113144
try:
114145
_backup_sqlite_file(auth_path, tmp_path2)
115-
_fsspec_copy(str(tmp_path2), auth_remote)
146+
_compress_zstd(tmp_path2, tmp_zst2)
147+
_fsspec_copy(str(tmp_zst2), auth_remote)
116148
logger.info("checkpoint auth DB uploaded to %s", auth_remote)
117149
finally:
118150
tmp_path2.unlink(missing_ok=True)
151+
tmp_zst2.unlink(missing_ok=True)
119152

120153
with db.snapshot() as snapshot:
121154
job_count = snapshot.count(JOBS)
@@ -127,31 +160,54 @@ def write_checkpoint(
127160
task_count=task_count,
128161
worker_count=worker_count,
129162
)
163+
164+
# Best-effort pruning of old checkpoints
165+
try:
166+
pruned = prune_old_checkpoints(remote_state_dir)
167+
if pruned:
168+
logger.info("Pruned %d old checkpoint(s)", pruned)
169+
except Exception:
170+
logger.warning("Failed to prune old checkpoints", exc_info=True)
171+
130172
return checkpoint_dir, result
131173

132174

133-
def _find_latest_checkpoint_dir(remote_state_dir: str) -> str | None:
134-
"""Find the most recent timestamped checkpoint directory.
175+
def _reconstruct_uri(remote_state_dir: str, fs_path: str) -> str:
176+
"""Reconstruct a full URI from a remote_state_dir (for its scheme) and an fs_path."""
177+
scheme = remote_state_dir.split("://", 1)[0] if "://" in remote_state_dir else "file"
178+
return f"{scheme}://{fs_path.rstrip('/')}"
135179

136-
Lists {remote_state_dir}/controller-state/ for subdirectories with
137-
numeric names (epoch_ms), returns the path to the newest one.
180+
181+
def _list_checkpoint_entries(remote_state_dir: str) -> list[str] | None:
182+
"""List immediate children of {remote_state_dir}/controller-state/.
183+
184+
Returns None if the directory does not exist.
138185
"""
139186
prefix = remote_state_dir.rstrip("/") + "/controller-state"
140187
fs, fs_path = fsspec.core.url_to_fs(prefix)
141188

142189
if not fs.exists(fs_path):
143190
return None
144191

145-
# List immediate children — each is a timestamp directory
146192
try:
147-
entries = fs.ls(fs_path, detail=False)
193+
return fs.ls(fs_path, detail=False)
148194
except FileNotFoundError:
149195
return None
150196

197+
198+
def _find_latest_checkpoint_dir(remote_state_dir: str) -> str | None:
199+
"""Find the most recent timestamped checkpoint directory.
200+
201+
Lists {remote_state_dir}/controller-state/ for subdirectories with
202+
numeric names (epoch_ms), returns the path to the newest one.
203+
"""
204+
entries = _list_checkpoint_entries(remote_state_dir)
205+
if entries is None:
206+
return None
207+
151208
# Filter to numeric directory names (epoch_ms timestamps)
152209
timestamp_dirs: list[tuple[int, str]] = []
153210
for entry in entries:
154-
# entry may be "bucket/path/controller-state/1234567890" or similar
155211
basename = entry.rstrip("/").rsplit("/", 1)[-1]
156212
if basename.isdigit():
157213
timestamp_dirs.append((int(basename), entry))
@@ -162,9 +218,64 @@ def _find_latest_checkpoint_dir(remote_state_dir: str) -> str | None:
162218
# Return the most recent (highest timestamp)
163219
timestamp_dirs.sort(reverse=True)
164220
_, latest_path = timestamp_dirs[0]
165-
# Reconstruct as a proper URI using the original scheme
166-
scheme = remote_state_dir.split("://", 1)[0] if "://" in remote_state_dir else "file"
167-
return f"{scheme}://{latest_path.rstrip('/')}"
221+
return _reconstruct_uri(remote_state_dir, latest_path)
222+
223+
224+
def _pick_remote(zst_path: str, plain_path: str) -> tuple[str | None, bool]:
225+
"""Return (remote_path, is_compressed) preferring the .zst variant."""
226+
fs, fs_path = fsspec.core.url_to_fs(zst_path)
227+
if fs.exists(fs_path):
228+
return zst_path, True
229+
fs2, fs_path2 = fsspec.core.url_to_fs(plain_path)
230+
if fs2.exists(fs_path2):
231+
return plain_path, False
232+
return None, False
233+
234+
235+
def _download_one(remote: str, local: Path, *, compressed: bool) -> None:
236+
"""Download a single file, decompressing if needed. Uses atomic rename."""
237+
if compressed:
238+
tmp_zst = local.with_suffix(".download.zst.tmp")
239+
_fsspec_copy(remote, str(tmp_zst))
240+
tmp_plain = local.with_suffix(".download.tmp")
241+
try:
242+
_decompress_zstd(tmp_zst, tmp_plain)
243+
finally:
244+
tmp_zst.unlink(missing_ok=True)
245+
tmp_plain.rename(local)
246+
else:
247+
tmp_path = local.with_suffix(".download.tmp")
248+
_fsspec_copy(remote, str(tmp_path))
249+
tmp_path.rename(local)
250+
251+
252+
def prune_old_checkpoints(
253+
remote_state_dir: str,
254+
max_age: Duration = DEFAULT_PRUNE_AGE,
255+
) -> int:
256+
"""Delete checkpoint directories older than *max_age*.
257+
258+
Returns the number of directories pruned.
259+
"""
260+
entries = _list_checkpoint_entries(remote_state_dir)
261+
if entries is None:
262+
return 0
263+
264+
cutoff_ms = Timestamp.now().add_ms(-max_age.to_ms()).epoch_ms()
265+
fs, _ = fsspec.core.url_to_fs(remote_state_dir)
266+
pruned = 0
267+
for entry in entries:
268+
basename = entry.rstrip("/").rsplit("/", 1)[-1]
269+
if not basename.isdigit():
270+
continue
271+
if int(basename) < cutoff_ms:
272+
try:
273+
fs.rm(entry, recursive=True)
274+
logger.info("Pruned old checkpoint: %s", entry)
275+
pruned += 1
276+
except Exception:
277+
logger.warning("Failed to prune checkpoint: %s", entry, exc_info=True)
278+
return pruned
168279

169280

170281
def download_checkpoint_to_local(
@@ -174,9 +285,12 @@ def download_checkpoint_to_local(
174285
) -> bool:
175286
"""Download a remote checkpoint directory to a local db_dir.
176287
177-
Looks for controller.sqlite3 and auth.sqlite3 in the checkpoint
178-
directory. If ``checkpoint_dir`` is not provided, finds the most
179-
recent timestamped checkpoint under ``remote_state_dir/controller-state/``.
288+
Looks for controller.sqlite3(.zst) and auth.sqlite3(.zst) in the
289+
checkpoint directory. Compressed files are preferred; uncompressed
290+
files are accepted as a fallback.
291+
292+
If ``checkpoint_dir`` is not provided, finds the most recent
293+
timestamped checkpoint under ``remote_state_dir/controller-state/``.
180294
181295
Returns True if a checkpoint was downloaded, False if none found.
182296
"""
@@ -189,30 +303,28 @@ def download_checkpoint_to_local(
189303
return False
190304
source_dir = found
191305

192-
# Check that the main DB exists in the source directory
193-
main_source = f"{source_dir}/{ControllerDB.DB_FILENAME}"
194-
fs, fs_path = fsspec.core.url_to_fs(main_source)
195-
if not fs.exists(fs_path):
196-
logger.info("No remote checkpoint at %s, starting fresh", main_source)
306+
# Prefer compressed (.zst), fall back to uncompressed for old checkpoints
307+
main_zst = f"{source_dir}/{ControllerDB.DB_FILENAME}.zst"
308+
main_plain = f"{source_dir}/{ControllerDB.DB_FILENAME}"
309+
main_source, compressed = _pick_remote(main_zst, main_plain)
310+
if main_source is None:
311+
logger.info("No remote checkpoint at %s, starting fresh", source_dir)
197312
return False
198313

199314
local_db_dir.mkdir(parents=True, exist_ok=True)
200315

201316
# Download main DB
202317
local_main = local_db_dir / ControllerDB.DB_FILENAME
203-
tmp_path = local_main.with_suffix(".download.tmp")
204-
_fsspec_copy(main_source, str(tmp_path))
205-
tmp_path.rename(local_main)
318+
_download_one(main_source, local_main, compressed=compressed)
206319
logger.info("Downloaded checkpoint from %s to %s", main_source, local_main)
207320

208321
# Download auth DB if available
209-
auth_source = f"{source_dir}/{ControllerDB.AUTH_DB_FILENAME}"
210-
auth_fs, auth_fs_path = fsspec.core.url_to_fs(auth_source)
211-
if auth_fs.exists(auth_fs_path):
322+
auth_zst = f"{source_dir}/{ControllerDB.AUTH_DB_FILENAME}.zst"
323+
auth_plain = f"{source_dir}/{ControllerDB.AUTH_DB_FILENAME}"
324+
auth_source, auth_compressed = _pick_remote(auth_zst, auth_plain)
325+
if auth_source is not None:
212326
local_auth = local_db_dir / ControllerDB.AUTH_DB_FILENAME
213-
auth_tmp = local_auth.with_suffix(".download.tmp")
214-
_fsspec_copy(auth_source, str(auth_tmp))
215-
auth_tmp.rename(local_auth)
327+
_download_one(auth_source, local_auth, compressed=auth_compressed)
216328
logger.info("Downloaded auth checkpoint from %s to %s", auth_source, local_auth)
217329

218330
return True

0 commit comments

Comments (0)