Skip to content

Commit 082eb57

Browse files
committed
style(gather_data): pre-commit
1 parent 4edb94a commit 082eb57

1 file changed

Lines changed: 47 additions & 54 deletions

File tree

carps/analysis/gather_data.py

Lines changed: 47 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -3,35 +3,31 @@
33
from __future__ import annotations
44

55
import os
6-
import json
6+
from functools import partial
7+
from multiprocessing import Pool
8+
from pathlib import Path
9+
710
import fire
8-
import pandas as pd
911
import numpy as np
10-
from pathlib import Path
11-
from multiprocessing import Pool
12-
from functools import partial
12+
import pandas as pd
13+
import yaml
1314
from omegaconf import OmegaConf
1415

1516
# Import existing carps utilities based on provided source logic
16-
from carps.analysis.gather_data_utils import (
17-
load_log,
18-
normalize_logs,
19-
convert_mixed_types_to_str,
20-
read_jsonl_content
21-
)
17+
from carps.analysis.gather_data_utils import convert_mixed_types_to_str, load_log, normalize_logs, read_jsonl_content
2218
from carps.utils.check_missing import generate_commands
2319
from carps.utils.loggingutils import get_logger, setup_logging
2420
from carps.utils.types import RunStatus
2521

2622
setup_logging()
2723
logger = get_logger(__file__)
2824

25+
2926
def get_run_info(config_path: Path, log_fn: str = "trial_logs.jsonl") -> dict:
30-
"""
31-
Combined worker function: Determines the execution status of a run and loads its log data.
27+
"""Combined worker function: Determines the execution status of a run and loads its log data.
3228
33-
This function serves as the core processing unit for a single experiment directory.
34-
It identifies whether a run is Completed, Truncated, or Missing based on the
29+
This function serves as the core processing unit for a single experiment directory.
30+
It identifies whether a run is Completed, Truncated, or Missing based on the
3531
expected number of trials in the config versus the actual trials in the logs.
3632
3733
Args:
@@ -49,11 +45,11 @@ def get_run_info(config_path: Path, log_fn: str = "trial_logs.jsonl") -> dict:
4945
rundir = config_path.parent.parent
5046
status = RunStatus.MISSING
5147
log_df = pd.DataFrame()
52-
48+
5349
# 1. Load Config
5450
try:
5551
cfg = OmegaConf.load(config_path)
56-
except Exception as e:
52+
except FileNotFoundError as e:
5753
logger.error(f"Could not load config at {config_path}: {e}")
5854
return {}
5955

@@ -63,26 +59,24 @@ def get_run_info(config_path: Path, log_fn: str = "trial_logs.jsonl") -> dict:
6359
# 2. Determine Status (Logic from check_missing.py)
6460
n_trials = cfg.task.optimization_resources.n_trials
6561
trial_logs_fn = rundir / log_fn
66-
62+
6763
if trial_logs_fn.is_file():
68-
try:
69-
# Check trial counts to determine if run finished
70-
trial_logs = read_jsonl_content(str(trial_logs_fn))
71-
if not trial_logs.empty and "n_trials" in trial_logs:
72-
n_trials_done = trial_logs["n_trials"].max()
73-
status = RunStatus.COMPLETED if n_trials_done >= n_trials else RunStatus.TRUNCATED
74-
75-
# 3. Load and Process Log Data (Logic from gather_data.py)
76-
log_df = load_log(rundir, log_fn=log_fn)
77-
except Exception as e:
78-
logger.warning(f"Error processing logs in {rundir}: {e}")
64+
# Check trial counts to determine if run finished
65+
trial_logs = read_jsonl_content(str(trial_logs_fn))
66+
if not trial_logs.empty and "n_trials" in trial_logs:
67+
n_trials_done = trial_logs["n_trials"].max()
68+
status = RunStatus.COMPLETED if n_trials_done >= n_trials else RunStatus.TRUNCATED
69+
70+
# 3. Load and Process Log Data (Logic from gather_data.py)
71+
log_df = load_log(rundir, log_fn=log_fn)
7972

8073
# 4. Extract Overrides for command generation
8174
try:
8275
hydra_cfg = OmegaConf.load(config_path.parent / "hydra.yaml")
8376
task_overrides = hydra_cfg.hydra.overrides.task
8477
hydra_overrides = hydra_cfg.hydra.overrides.hydra
85-
except Exception:
78+
except yaml.reader.ReaderError:
79+
logger.warning(f"Could not load overrides from {config_path.parent / 'hydra.yaml'}.")
8680
task_overrides = []
8781
hydra_overrides = []
8882

@@ -100,41 +94,41 @@ def get_run_info(config_path: Path, log_fn: str = "trial_logs.jsonl") -> dict:
10094
"status_info": status_info,
10195
"log_df": log_df,
10296
"cfg_fn": str(config_path),
103-
"cfg_str": OmegaConf.to_yaml(cfg).replace("\n", "\\n")
97+
"cfg_str": OmegaConf.to_yaml(cfg).replace("\n", "\\n"),
10498
}
10599

100+
106101
def gather_and_check(
107102
rundir: str | list[str],
108103
log_fn: str = "trial_logs.jsonl",
109104
n_processes: int | None = None,
110-
outdir: str | Path | None = None
105+
outdir: str | Path | None = None,
111106
) -> None:
112-
"""
113-
Scans directories to gather performance logs and check for missing/truncated runs.
107+
"""Scans directories to gather performance logs and check for missing/truncated runs.
114108
115-
This is the main entry point. It performs a parallel scan of the provided directories,
116-
generates a status report (`runstatus.csv`), creates shell scripts to restart failed
117-
runs (`runcommands_*.sh`), and aggregates all valid trial data into consolidated
109+
This is the main entry point. It performs a parallel scan of the provided directories,
110+
generates a status report (`runstatus.csv`), creates shell scripts to restart failed
111+
runs (`runcommands_*.sh`), and aggregates all valid trial data into consolidated
118112
CSV and Parquet files.
119113
120114
Args:
121115
rundir (str | list[str]): One or more directories to scan for results.
122116
log_fn (str): The filename of the trial logs. Defaults to "trial_logs.jsonl".
123-
n_processes (int | None): Number of CPU processes for parallel processing.
117+
n_processes (int | None): Number of CPU processes for parallel processing.
124118
Defaults to None (uses all available cores).
125-
outdir (str | Path | None): Directory where output files will be saved.
119+
outdir (str | Path | None): Directory where output files will be saved.
126120
If None, uses the common path of input rundirs.
127121
128122
Returns:
129123
None: Outputs files directly to the file system (logs.csv, runstatus.csv, etc.).
130124
"""
131125
if isinstance(rundir, str):
132126
rundir = [rundir]
133-
127+
134128
all_status_data = []
135129
all_log_dfs = []
136130
config_mappings = []
137-
131+
138132
for r in rundir:
139133
logger.info(f"Scanning {r} for experiment configs...")
140134
# Find every experiment directory via its hydra config
@@ -146,7 +140,8 @@ def gather_and_check(
146140
results = pool.map(worker, config_paths)
147141

148142
for res in results:
149-
if not res: continue
143+
if not res:
144+
continue
150145
all_status_data.append(res["status_info"])
151146
if not res["log_df"].empty:
152147
# Store log and track config for cfg_str/cfg_fn mapping
@@ -155,15 +150,12 @@ def gather_and_check(
155150

156151
# --- PART 1: Handle Status and Run-Commands ---
157152
status_df = pd.DataFrame(all_status_data).dropna()
158-
if outdir is None:
159-
outdir = Path(os.path.commonpath(rundir))
160-
else:
161-
outdir = Path(outdir)
153+
outdir = Path(os.path.commonpath(rundir)) if outdir is None else Path(outdir)
162154
outdir.mkdir(parents=True, exist_ok=True)
163-
155+
164156
status_df.to_csv(outdir / "runstatus.csv", index=False)
165157
logger.info(f"Saved run status to {outdir / 'runstatus.csv'}")
166-
158+
167159
# Generate shell scripts to fix non-completed runs
168160
generate_commands(status_df, RunStatus.MISSING, str(outdir))
169161
generate_commands(status_df, RunStatus.TRUNCATED, str(outdir))
@@ -172,15 +164,15 @@ def gather_and_check(
172164
if all_log_dfs:
173165
logger.info("Consolidating and normalizing logs...")
174166
df = pd.concat(all_log_dfs).reset_index(drop=True)
175-
167+
176168
# Create metadata mapping between experiments and their config strings
177169
df_cfg = pd.DataFrame(config_mappings).drop_duplicates()
178170
df_cfg["experiment_id"] = np.arange(len(df_cfg))
179-
171+
180172
# Assign experiment_id back to main log dataframe
181-
mapping = dict(zip(df_cfg["cfg_fn"], df_cfg["experiment_id"]))
173+
mapping = dict(zip(df_cfg["cfg_fn"], df_cfg["experiment_id"], strict=False))
182174
df["experiment_id"] = df["cfg_fn"].map(mapping)
183-
175+
184176
# Apply normalization and cleanup
185177
df = normalize_logs(df)
186178
df = convert_mixed_types_to_str(df)
@@ -191,12 +183,13 @@ def gather_and_check(
191183
df.to_parquet(outdir / "logs.parquet", index=False)
192184
df_cfg.to_csv(outdir / "logs_cfg.csv", index=False)
193185
df_cfg.to_parquet(outdir / "logs_cfg.parquet", index=False)
194-
186+
195187
logger.info(f"Gathered logs for {len(all_log_dfs)} runs into {outdir}")
196188
else:
197189
logger.warning("No log data found to gather.")
198190

199191
logger.info("Done! 😊")
200192

193+
201194
if __name__ == "__main__":
202-
fire.Fire(gather_and_check)
195+
fire.Fire(gather_and_check)

0 commit comments

Comments
 (0)