Skip to content

Commit 9dfb91c

Browse files
LucaLuca
authored andcommitted
Fixed Issue #99 where resuming a run previously run in parallel (n_workers>1) caused the history to be disordered.
This fix adds the bracket_id to the history dataframe and sorts the history by (bracket_id, fidelity, config_id) to emulate a serial execution history without disordering allowing for the tell(resume=True) function to work as intended. Additionally, updated deprecated usage of config.get_dictionary() to dict(config).
1 parent 2d39900 commit 9dfb91c

1 file changed

Lines changed: 28 additions & 7 deletions

File tree

src/dehb/optimizers/dehb.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -833,7 +833,7 @@ def _save_incumbent(self):
833833
res = {}
834834
if self.use_configspace:
835835
config = self.vector_to_configspace(self.inc_config)
836-
res["config"] = config.get_dictionary()
836+
res["config"] = dict(config)
837837
else:
838838
res["config"] = self.inc_config.tolist()
839839
res["score"] = self.inc_score
@@ -850,8 +850,11 @@ def _save_history(self, name="history.parquet.gzip"):
850850
return
851851
try:
852852
history_path = self.output_path / name
853-
history_df = pd.DataFrame(self.history, columns=["config_id", "config", "fitness",
854-
"cost", "fidelity", "info"])
853+
# Persist bracket_id to reconstruct serial replay order later
854+
history_df = pd.DataFrame(
855+
self.history,
856+
columns=["bracket_id", "config_id", "config", "fitness", "cost", "fidelity", "info"],
857+
)
855858
# Check if the 'info' column is empty or contains only None values
856859
if history_df["info"].apply(lambda x: (isinstance(x, dict) and len(x) == 0)).all():
857860
# Drop the 'info' column
@@ -936,12 +939,20 @@ def _load_checkpoint(self, run_dir: str):
936939
history_path = run_dir / "history.parquet.gzip"
937940
history = pd.read_parquet(history_path)
938941

939-
# Replay history
942+
# Sort history to emulate serial execution order if bracket_id available
943+
if "bracket_id" in history.columns:
944+
history = history.sort_values(by=["bracket_id", "fidelity", "config_id"]).reset_index(drop=True)
945+
else:
946+
# Fallback ordering for older checkpoints
947+
history = history.sort_values(by=["fidelity", "config_id"]).reset_index(drop=True)
948+
949+
# Replay history in the chosen order
940950
for _, row in history.iterrows():
941951
job_info = {
942952
"fidelity": row["fidelity"],
943953
"config_id": row["config_id"],
944954
"config": np.array(row["config"]),
955+
**({"bracket_id": int(row["bracket_id"]) } if "bracket_id" in history.columns else {}),
945956
}
946957
result = {
947958
"fitness": row["fitness"],
@@ -993,6 +1004,8 @@ def tell(self, job_info: dict, result: dict, replay: bool=False) -> None:
9931004
job_info_container["fidelity"] = job_info["fidelity"]
9941005
job_info_container["config"] = job_info["config"]
9951006
job_info_container["config_id"] = job_info["config_id"]
1007+
if "bracket_id" in job_info:
1008+
job_info_container["bracket_id"] = job_info["bracket_id"]
9961009

9971010
# Update entry in ConfigRepository
9981011
self.config_repository.configs[job_info["config_id"]].config = job_info["config"]
@@ -1036,8 +1049,16 @@ def tell(self, job_info: dict, result: dict, replay: bool=False) -> None:
10361049
inc_changed = True
10371050
# book-keeping
10381051
self._update_trackers(
1039-
traj=self.inc_score, runtime=cost, history=(
1040-
config_id, config.tolist(), float(fitness), float(cost), float(fidelity), info,
1052+
traj=self.inc_score,
1053+
runtime=cost,
1054+
history=(
1055+
bracket_id,
1056+
config_id,
1057+
config.tolist(),
1058+
float(fitness),
1059+
float(cost),
1060+
float(fidelity),
1061+
info,
10411062
),
10421063
)
10431064

@@ -1167,7 +1188,7 @@ def run(self, fevals=None, brackets=None, total_cost=None, single_node_with_gpus
11671188
self.logger.info("Incumbent config: ")
11681189
if self.use_configspace:
11691190
config = self.vector_to_configspace(self.inc_config)
1170-
for k, v in config.get_dictionary().items():
1191+
for k, v in dict(config).items():
11711192
self.logger.info(f"{k}: {v}")
11721193
else:
11731194
self.logger.info(f"{self.inc_config}")

0 commit comments

Comments
 (0)