Skip to content

Commit 3f5805c

Browse files
committed
Refactor check_environment in sanity.py: add fn_pdc parameter for custom dataset preprocessing logic
1 parent 9c6f7bb commit 3f5805c

1 file changed

Lines changed: 5 additions & 6 deletions

File tree

mle/sanity.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any
1+
from typing import Any, Callable
22

33
import torch
44
from erbium.api import get_all_gpu_info
@@ -10,15 +10,14 @@
1010
from mle.vars import ExpConfig
1111

1212

13-
def check_environment(config: ExpConfig) -> dict[str, Any]:
13+
def check_environment(config: ExpConfig, *, fn_pdc: Callable[[ExpConfig], str] = check_preprocessed_dataset) -> dict[
14+
str, Any]:
1415
try:
1516
gpus = get_all_gpu_info()
1617
except NVMLError:
1718
gpus = {}
18-
return {
19-
"dataset": check_dataset(config), "preprocessed_dataset": check_preprocessed_dataset(config), "gpus": gpus,
20-
"cuda": torch.version.cuda
21-
}
19+
return {"dataset": check_dataset(config), "preprocessed_dataset": fn_pdc(config), "gpus": gpus,
20+
"cuda": torch.version.cuda}
2221

2322

2423
def print_environment_check_results(results: dict[str, Any], *, console: Console = Console()) -> None:

0 commit comments

Comments
 (0)