Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ celerybeat.pid
.venv
env/
venv/
venv*/
ENV/
env.bak/
venv.bak/
Expand Down
18 changes: 0 additions & 18 deletions pydentification/experiment/defaults/report.py

This file was deleted.

11 changes: 0 additions & 11 deletions pydentification/experiment/defaults/save.py

This file was deleted.

12 changes: 0 additions & 12 deletions pydentification/experiment/defaults/train.py

This file was deleted.

71 changes: 71 additions & 0 deletions pydentification/experiment/dumper/code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import os
import shutil
import uuid
from pathlib import Path

# File suffixes that qualify a file for inclusion in the code snapshot.
PYTHON_EXTENSIONS = frozenset(
    (".py", ".json", ".txt", ".md", ".yaml", ".yml", ".toml", ".ini"),
)
# Name prefixes of path components that exclude a directory from the snapshot.
DEFAULT_FORBIDDEN_PREFIX = frozenset(
    ("venv", ".ipynb_checkpoints", "__pycache__", ".git", ".pytest_cache"),
)


def _load_gitignore() -> set[str]:
    """Return the meaningful lines of `.gitignore` in the current working directory.

    Lines that are empty, consist only of whitespace, or start with `#` are dropped;
    every other line is kept verbatim. An empty set is returned when the file is missing.
    """
    ignore_file = Path(".gitignore")
    if not ignore_file.exists():
        return set()

    with ignore_file.open("r") as handle:
        raw_lines = handle.read().splitlines()

    return {entry for entry in raw_lines if entry and not entry.isspace() and not entry.startswith("#")}


def _skip_subdir(current: Path, archive_path: Path, forbidden_paths: frozenset[str]) -> bool:
    """Decide whether the directory `current` must be left out of the code snapshot.

    :param current: directory being visited while walking the source tree
    :param archive_path: temporary directory in which the snapshot is being assembled
    :param forbidden_paths: name prefixes (defaults and .gitignore entries) that exclude a directory
    :return: True when `current` should be skipped, False otherwise
    """
    # Compare resolved Path objects: the staging directory may be given relative to the
    # working directory while walked directories are absolute, and a str never equals a Path.
    archive_path = archive_path.resolve()
    resolved_current = current.resolve()

    # prevent copying the temp directory, where the archive with source code is built
    if resolved_current == archive_path:
        return True
    # prevent copying anything nested inside the temp directory
    if archive_path in resolved_current.parents:
        return True
    # prevent copying the forbidden paths from defaults and .gitignore; every component of
    # `current` is checked, so a prefix match on any ancestor excludes the directory too
    return any(part.startswith(prefix) for prefix in forbidden_paths for part in current.parts)


def save_code_snapshot(name: str, source_dir: str | Path):
    """Save only text-based files in a ZIP archive, excluding binary data files.

    Files below `source_dir` whose suffix is listed in `PYTHON_EXTENSIONS` are copied into a
    temporary staging directory, keeping their relative layout, and the staging directory is
    then zipped. Directories rejected by `_skip_subdir` are not descended into.

    :param name: suffix of the archive; it is written as `source_code_{name}.zip` relative to
        the current working directory
    :param source_dir: root directory of the source tree to snapshot
    :return: path of the created ZIP archive
    """
    import tempfile  # local import: only this function needs a staging directory

    source_dir = Path(source_dir).resolve()  # ensure absolute path
    snapshot_filename = f"source_code_{name}"
    forbidden = DEFAULT_FORBIDDEN_PREFIX | _load_gitignore()

    # Stage in the system temporary location instead of the working directory, so the staging
    # area is not part of the tree being walked (no copying the snapshot into itself), and let
    # the context manager remove it even when copying or archiving raises.
    with tempfile.TemporaryDirectory(prefix="temp_code_snapshot_") as staging:
        temp_dir = Path(staging).resolve()

        for root, dirs, files in os.walk(source_dir):
            root_path = Path(root)
            if _skip_subdir(root_path, temp_dir, forbidden):
                dirs.clear()  # prevent descending into this directory
                continue  # skip to the next directory

            for file in files:
                file_path = root_path / file
                if file_path.suffix not in PYTHON_EXTENSIONS:
                    continue

                dest_path = temp_dir / file_path.relative_to(source_dir)
                dest_path.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy2(file_path, dest_path)

        # archive the staged tree before the staging directory is removed
        archive = shutil.make_archive(snapshot_filename, format="zip", root_dir=temp_dir)

    return Path(archive)
45 changes: 45 additions & 0 deletions pydentification/experiment/dumper/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import json
from pathlib import Path
from typing import Literal

import lightning.pytorch as pl
import torch
import wandb
from safetensors.torch import save_model


def save_torch(path: Path, model: torch.nn.Module, method: Literal["pt", "safetensors"] = "safetensors"):
    """
    Write `model` to `path` using the selected serialization method.

    :param path: destination file
    :param model: PyTorch module to store
    :param method: "safetensors" delegates to `save_model`, "pt" stores the state dict with `torch.save`
    :raises ValueError: when `method` is neither "pt" nor "safetensors"
    """
    if method not in ("safetensors", "pt"):
        raise ValueError(f"Unknown method: {method}!")

    if method == "pt":
        # only the state dict is written, not the module object itself
        state = model.state_dict()
        torch.save(state, path)
        return

    save_model(model, path)


def save_json(path: Path, data: dict):
    """
    Serialize `data` as JSON into the file at `path`, replacing any existing content.

    :param path: destination file
    :param data: dictionary to serialize
    """
    with path.open(mode="w") as handle:
        json.dump(data, fp=handle)  # type: ignore


def save_fn(
    name: str,
    model: pl.LightningModule,
    method: Literal["pt", "safetensors"] = "safetensors",
    save_hparams: bool = False,
):
    """
    Store the trained model under `models/{name}` and pass that directory to `wandb.save`.

    :param name: name of the parent directory with the model and settings
    :param model: Lightning module; its `module` attribute is the network written to disk
    :param method: method of saving the model, either "pt" or "safetensors"
    :param save_hparams: whether to save hyperparameters in a JSON file
    """
    output_dir = Path(f"models/{name}")
    output_dir.mkdir(parents=True, exist_ok=True)

    # save only the model held by the Lightning module
    weights_file = output_dir / f"trained-model.{method}"
    save_torch(weights_file, model=model.module, method=method)

    if save_hparams:
        hparams_file = output_dir / "hparams.json"
        # fall back to an empty dict when no hyperparameters are present
        save_json(hparams_file, model.hparams or {})

    wandb.save(output_dir)
Empty file.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ torch==2.4.0
plotly==5.24.1
wandb==0.18.3
h5py==3.12.1
safetensors==0.5.2
10 changes: 0 additions & 10 deletions tests/test_data/utils.py

This file was deleted.