Skip to content

Commit 665f499

Browse files
authored
Merge pull request #59 from cyber-physical-systems-group/stubs/path-casting
Stubs/path casting
2 parents cfedc54 + 0ab5fe5 commit 665f499

File tree

5 files changed

+47
-16
lines changed

5 files changed

+47
-16
lines changed

pydentification/experiment/storage/code.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import uuid
44
from pathlib import Path
55

6+
from pydentification.stubs import cast_to_path
7+
68

79
def _load_gitignore() -> set[str]:
810
"""Load .gitignore from default name and root directory as set"""
@@ -31,6 +33,7 @@ def _skip_subdir(current: Path, archive_path: Path, forbidden_paths: set[str]) -
3133
return False
3234

3335

36+
@cast_to_path
3437
def save_code_snapshot(
3538
name: str,
3639
source_dir: str | Path,
@@ -49,12 +52,6 @@ def save_code_snapshot(
4952
:param accept_suffix: set of suffixes to include in the archive
5053
:param use_gitignore: whether to use .gitignore file in the source directory for filter_prefix
5154
"""
52-
if isinstance(source_dir, str):
53-
source_dir = Path(source_dir)
54-
55-
if isinstance(target_dir, str):
56-
target_dir = Path(target_dir)
57-
5855
source_dir = Path(source_dir).resolve() # ensure absolute path
5956
snapshot_path = target_dir / name
6057
temp_dir = target_dir / str(uuid.uuid4()) # create temp dir with unique name for copying files

pydentification/experiment/storage/compose.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from pathlib import Path
88
from typing import Any, Callable
99

10+
from pydentification.stubs import cast_to_path
11+
1012

1113
class ReplaceSourceCode:
1214
"""
@@ -55,6 +57,7 @@ def _load_model_and_parameters(path: str | Path, name: str, parameters: dict[str
5557
return model_fn(**parameters)
5658

5759

60+
@cast_to_path
5861
def compose_model(
5962
path: str | Path,
6063
name: str = "model_fn",
@@ -71,9 +74,6 @@ def compose_model(
7174
:param source: filesystem Path to the ZIP file with source code
7275
if None imports are attempted from the current working directory.
7376
"""
74-
if isinstance(source, str):
75-
source = Path(source)
76-
7777
if parameters is not None:
7878
with open(parameters, "r") as f:
7979
parameters = json.load(f)

pydentification/experiment/storage/models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import torch
77
from safetensors.torch import save_model
88

9+
from pydentification.stubs import cast_to_path
10+
911

1012
def save_torch(path: Path, model: torch.nn.Module, method: Literal["pt", "safetensors"] = "safetensors"):
1113
if method == "safetensors":
@@ -21,6 +23,7 @@ def save_json(path: Path, data: dict):
2123
json.dump(data, f) # type: ignore
2224

2325

26+
@cast_to_path
2427
def save_lightning(
2528
path: str | Path,
2629
model: pl.LightningModule,
@@ -33,10 +36,7 @@ def save_lightning(
3336
:param method: method of saving the model, either "pt" or "safetensors"
3437
:param save_hparams: whether to save hyperparameters in a JSON file
3538
"""
36-
if isinstance(path, str):
37-
path = Path(path)
3839
path.mkdir(parents=True, exist_ok=True)
39-
4040
save_torch(path / f"trained-model.{method}", model=model.module, method=method) # save only the model
4141

4242
if save_hparams:

pydentification/experiment/storage/sync.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
import wandb
44

5+
from pydentification.stubs import cast_to_path
56

7+
8+
@cast_to_path
69
def save_to_wandb(path: str | Path):
710
"""Save all files from directory to W&B"""
8-
if isinstance(path, str):
9-
path = Path(path)
10-
1111
for file in path.rglob("*"):
1212
wandb.save(str(file)) # save file one by one

pydentification/stubs.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,37 @@
1-
from typing import Callable
1+
from functools import wraps
2+
from pathlib import Path
3+
from typing import Any, Callable, Union, get_type_hints
24

35
Print = Callable[[str], None]
6+
7+
8+
def cast_to_path(func: Callable):
9+
"""
10+
Decorator to cast string arguments to Path in the function signature.
11+
Uses type hints to determine which arguments should be cast.
12+
13+
:note: do not use with functions that have missing type hints
14+
"""
15+
16+
def cast(arg: Any, hint: Any) -> Any:
17+
if hint == Union[str, Path]:
18+
return Path(arg)
19+
return arg
20+
21+
@wraps(func)
22+
def wrapper(*args, **kwargs):
23+
hints = get_type_hints(func)
24+
new_args = []
25+
new_kwargs = {}
26+
27+
if len(args) + len(kwargs) > len(hints):
28+
raise TypeError("Some arguments are missing type hints!")
29+
30+
for arg, (name, hint) in zip(args, hints.items()):
31+
new_args.append(cast(arg, hint))
32+
for name, arg in kwargs.items():
33+
new_kwargs[name] = cast(arg, hint=hints.get(name))
34+
35+
return func(*new_args, **new_kwargs)
36+
37+
return wrapper

0 commit comments

Comments
 (0)