Skip to content

Commit 940c097

Browse files
committed
Remove torch and numpy
1 parent 869d9e9 commit 940c097

8 files changed

Lines changed: 91 additions & 455 deletions

File tree

LICENSE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
MIT License
22

3-
Copyright (c) 2024 Marvin Sextro
3+
Copyright (c) 2025 Marvin Sextro
44

55
Permission is hereby granted, free of charge, to any person obtaining a copy
66
of this software and associated documentation files (the "Software"), to deal

example/configs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from ml_project_template.runs import Job, Run, SlurmParams, SweepJob
44
from ml_project_template.wandb import WandBRun
55

6-
RunConfig = builds(Run, seed=42, wandb=None, job=None)
6+
RunConfig = builds(Run, seed=None, wandb=None, job=None)
77

88
SlurmParamsConfig = builds(
99
SlurmParams,
@@ -18,6 +18,6 @@
1818

1919
JobConfig = builds(Job, slurm_params=SlurmParamsConfig)
2020

21-
SweepConfig = builds(SweepJob, num_workers=2, parameters={"cfg.seed": [42, 1337]}, builds_bases=(JobConfig,))
21+
SweepConfig = builds(SweepJob, num_workers=2, parameters={"foo": [42, 1337]}, builds_bases=(JobConfig,))
2222

2323
WandBConfig = builds(WandBRun, group=None, mode="online")

ml_project_template/config.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,16 @@
77
from hydra_zen.third_party.pydantic import pydantic_parser
88
from omegaconf import DictConfig, OmegaConf
99

10-
from ml_project_template.utils import ConfigKeys, get_hydra_output_dir, logger, seed_everything
10+
from ml_project_template.utils import ConfigKeys, get_output_dir, logger
1111
from ml_project_template.wandb import WandBRun
1212

1313

14-
def pre_call(root_config: DictConfig, log_debug: bool = False) -> None:
14+
def pre_call(root_config: DictConfig, seed_fn: Callable[[int], None] | None = None, log_debug: bool = False) -> None:
1515
"""Logs the config, sets the seed and initializes a WandB run before config instantiation.
1616
1717
Args:
1818
root_config: Unresolved config.
19+
seed_fn: Function to use for seeding the run.
1920
log_debug: Whether to log the config, seed and output path.
2021
"""
2122
if log_debug:
@@ -27,7 +28,10 @@ def pre_call(root_config: DictConfig, log_debug: bool = False) -> None:
2728
return
2829

2930
if (seed := config.get(ConfigKeys.SEED)) is not None:
30-
seed_everything(seed)
31+
if seed_fn is None:
32+
raise ValueError("No seeding function was set for the given seed.")
33+
34+
seed_fn(seed)
3135
logger.debug(f"Set seed to {seed}.")
3236
else:
3337
logger.warning("No seed was configured! Run may not be reproducible.")
@@ -37,7 +41,7 @@ def pre_call(root_config: DictConfig, log_debug: bool = False) -> None:
3741
else:
3842
logger.debug(f"Running config:\n{to_yaml(root_config)}")
3943

40-
output_path = get_hydra_output_dir()
44+
output_path = get_output_dir()
4145
logger.debug(f"Saving outputs in {output_path}")
4246

4347
logger.setLevel(logging.INFO)
@@ -48,17 +52,18 @@ def pre_call(root_config: DictConfig, log_debug: bool = False) -> None:
4852
wandb.save(output_path / ".hydra/*", base_path=output_path, policy="now")
4953

5054

51-
def run(main_function: Callable, log_debug: bool = True) -> None:
55+
def run(main_function: Callable, seed_fn: Callable[[int], None] | None = None, log_debug: bool = True) -> None:
5256
"""Configure and run a given function using hydra-zen.
5357
5458
Args:
5559
main_function: Function to configure and run.
60+
seed_fn: Function to use for seeding the run.
5661
log_debug: Whether to log debug information from the `pre_call` function.
5762
"""
5863
store.add_to_hydra_store()
5964
zen(
6065
main_function,
61-
pre_call=partial(pre_call, log_debug=log_debug),
66+
pre_call=partial(pre_call, seed_fn=seed_fn, log_debug=log_debug),
6267
resolve_pre_call=False,
6368
instantiation_wrapper=pydantic_parser,
6469
).hydra_main(

ml_project_template/runs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from submitit import AutoExecutor
1010
from submitit.helpers import CommandFunction
1111

12-
from ml_project_template.utils import ConfigKeys, get_hydra_output_dir, logger
12+
from ml_project_template.utils import ConfigKeys, get_output_dir, logger
1313
from ml_project_template.wandb import WandBConfig, WandBRun
1414

1515

@@ -80,7 +80,7 @@ def run(self) -> None:
8080

8181
function = CommandFunction(command)
8282
executor = AutoExecutor(
83-
folder=get_hydra_output_dir(),
83+
folder=get_output_dir(),
8484
cluster=self.cluster,
8585
slurm_python=self.python_command,
8686
)
@@ -134,7 +134,7 @@ def run(self) -> None:
134134
metric = {"goal": self.metric_goal, "name": self.metric_name}
135135
program, args = sys.argv[0], self.filter_args(sys.argv[1:])
136136

137-
folder_path = get_hydra_output_dir()
137+
folder_path = get_output_dir()
138138
dummy_sweep_id = "sweep_started_" + Path(folder_path).parts[-2] + "_" + Path(folder_path).parts[-1]
139139
hydra_run_dir = "./outputs/sweeps/" + dummy_sweep_id + "/${now:%H-%M-%S-%f}"
140140

@@ -179,6 +179,6 @@ def run(self) -> None:
179179
class Run:
180180
"""Configures a basic run."""
181181

182-
seed: int
182+
seed: int | None = None
183183
wandb: WandBRun | None = None
184184
job: Job | None = None

ml_project_template/utils.py

Lines changed: 11 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
from pathlib import Path
55
from typing import Final
66

7-
import numpy as np
8-
import torch
97
from hydra.core.hydra_config import HydraConfig
108

119
logger = logging.getLogger()
@@ -21,37 +19,25 @@ class ConfigKeys:
2119
STORE: Final[str] = "store"
2220

2321

24-
def seed_everything(seed: int) -> None:
25-
"""Seeds all random number generators.
22+
def basic_seed_fn(seed: int) -> None:
23+
"""Seeds random number generators.
2624
2725
Args:
2826
seed: Random seed.
2927
"""
3028
random.seed(seed)
31-
np.random.seed(seed)
32-
torch.manual_seed(seed)
33-
torch.cuda.manual_seed_all(seed)
3429
os.environ["PYTHONHASHSEED"] = str(seed)
3530

3631

37-
def get_device() -> str:
38-
"""Returns the available device for torch.
32+
def get_output_dir() -> Path:
33+
"""Get the current output directory.
3934
4035
Returns:
41-
The GPU or the MPS device when available and the CPU device as a fallback.
36+
Output path of the current run.
4237
"""
43-
if torch.cuda.is_available():
44-
return "cuda"
45-
elif torch.backends.mps.is_available():
46-
return "mps"
47-
else:
48-
return "cpu"
49-
50-
51-
def get_hydra_output_dir() -> Path:
52-
"""Return the hydra output directory.
53-
54-
Returns:
55-
Path to the hydra output directory.
56-
"""
57-
return Path(HydraConfig.get().runtime.output_dir)
38+
try:
39+
output_dir = Path(HydraConfig.get().runtime.output_dir)
40+
except ValueError:
41+
output_dir = Path("/tmp/outputs")
42+
output_dir.mkdir(exist_ok=True, parents=True)
43+
return output_dir

ml_project_template/wandb.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Self
44

55
import wandb
6-
from dotenv import load_dotenv
6+
from dotenv import find_dotenv, load_dotenv
77
from wandb.wandb_run import Run
88

99
from ml_project_template.utils import logger
@@ -25,7 +25,7 @@ def from_env(cls) -> Self | None:
2525
Populated `WandBConfig` or None if environment variables could not be found.
2626
"""
2727
config = None
28-
load_dotenv()
28+
load_dotenv(find_dotenv(usecwd=True))
2929

3030
try:
3131
config = cls(**{field.name: os.environ[field.name] for field in fields(cls)})

pyproject.toml

Lines changed: 59 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -4,51 +4,31 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "ml_project_template"
7-
version = "0.0.12"
7+
version = "0.0.14"
88
requires-python = ">=3.12"
99
dependencies = [
10-
"torch>=2.5.1",
1110
"wandb>=0.18.5",
1211
"hydra-zen>=0.14.0",
1312
"submitit>=1.5.1",
1413
"python-dotenv>=1.0.1",
15-
"numpy>=1.26.4",
1614
"pydantic>=2.10.3",
1715
"pytest>=8.3.5",
1816
]
1917

2018
[dependency-groups]
21-
dev = [
22-
"pre-commit>=4.0.1",
23-
"ruff>=0.8.0",
24-
]
19+
dev = ["pre-commit>=4.0.1", "ruff>=0.8.0"]
2520

2621
[tool.setuptools.packages.find]
2722
include = ["ml_project_template"]
2823
namespaces = true
2924

30-
[tool.uv.sources]
31-
torch = [
32-
{ index = "pytorch-cu124", marker = "platform_system != 'Darwin'" },
33-
]
34-
35-
[[tool.uv.index]]
36-
name = "pytorch-cu124"
37-
url = "https://download.pytorch.org/whl/cu124"
38-
explicit = true
39-
4025
[tool.pytest.ini_options]
4126
pythonpath = ["."]
4227

4328
[tool.mypy]
4429
explicit_package_bases = true
4530
disable_error_code = ["import-untyped"]
4631

47-
[[tool.mypy.overrides]]
48-
module = "cloudpathlib.*"
49-
ignore_errors = true
50-
follow_imports = "skip"
51-
5232
[tool.ruff]
5333
line-length = 119
5434
indent-width = 4
@@ -59,67 +39,63 @@ convention = "google"
5939
[tool.ruff.lint]
6040
select = ["ALL"]
6141
ignore = [
62-
"ANN002",
63-
"ANN003",
64-
"ANN401",
65-
"D413",
66-
"COM812",
67-
"D100",
68-
"D104",
69-
"D107",
70-
"D205",
71-
"PD901",
72-
"D400",
73-
"D401",
74-
"D415",
75-
"FA",
76-
"SLF",
77-
"INP",
78-
"TRY003",
79-
"TRY201",
80-
"EM",
81-
"FBT",
82-
"RET",
83-
"C406",
84-
"E741",
85-
"PLR2004",
86-
"RUF009",
87-
"RUF012",
88-
"BLE001",
89-
"S603",
90-
"S607",
91-
"S506",
92-
"FIX002",
93-
"NPY002",
94-
"G004",
95-
"S311",
96-
"PIE790",
97-
"TRY400",
98-
"S108",
99-
"W191",
100-
"E111",
101-
"E114",
102-
"E117",
103-
"D206",
104-
"D300",
105-
"Q000",
106-
"Q001",
107-
"Q002",
108-
"Q003",
109-
"COM812",
110-
"COM819",
111-
"D203",
112-
"D213",
113-
"N806",
114-
"N803",
115-
"E712",
116-
"PLR0913",
117-
"TC001"
42+
"ANN002",
43+
"ANN003",
44+
"ANN401",
45+
"D413",
46+
"COM812",
47+
"D100",
48+
"D104",
49+
"D107",
50+
"D205",
51+
"PD901",
52+
"D400",
53+
"D401",
54+
"D415",
55+
"FA",
56+
"SLF",
57+
"INP",
58+
"TRY003",
59+
"TRY201",
60+
"EM",
61+
"FBT",
62+
"RET",
63+
"C406",
64+
"E741",
65+
"PLR2004",
66+
"RUF009",
67+
"RUF012",
68+
"BLE001",
69+
"S603",
70+
"S607",
71+
"S506",
72+
"FIX002",
73+
"NPY002",
74+
"G004",
75+
"S311",
76+
"PIE790",
77+
"TRY400",
78+
"S108",
79+
"W191",
80+
"E111",
81+
"E114",
82+
"E117",
83+
"D206",
84+
"D300",
85+
"Q000",
86+
"Q001",
87+
"Q002",
88+
"Q003",
89+
"COM812",
90+
"COM819",
91+
"D203",
92+
"D213",
93+
"N806",
94+
"N803",
95+
"E712",
96+
"PLR0913",
97+
"TC001",
11898
]
11999

120100
[tool.ruff.lint.per-file-ignores]
121-
"**/tests/**/*.py" = [
122-
"S101",
123-
"ARG",
124-
"FBT",
125-
]
101+
"**/tests/**/*.py" = ["S101", "ARG", "FBT"]

0 commit comments

Comments
 (0)