GridSearchOracle: persist _ordered_ids and _populate_next across reloads (#1055)#1056
GridSearchOracle: persist _ordered_ids and _populate_next across reloads (#1055)#1056SAY-5 wants to merge 3 commits into
Conversation
…ads (keras-team#1055) Signed-off-by: SAY-5 <say.apm35@gmail.com>
There was a problem hiding this comment.
Code Review
This pull request implements state persistence for GridSearch by adding get_state and set_state methods to the oracle, ensuring that search progress is correctly maintained across process restarts. It includes regression tests for state round-trips and backward compatibility with legacy checkpoints. Feedback was provided regarding the persistence of the LinkedList state, noting that directly saving the internal memory list may lead to corrupted trial ordering; a traversal-based approach was suggested to correctly capture the logical sequence of trial IDs.
| # `GridSearch(..., overwrite=False)` after a kernel restart) keeps | ||
| # working without `KeyError` from `_ordered_ids.next()` (#1055). | ||
| state = super().get_state() | ||
| state["ordered_ids"] = list(self._ordered_ids._memory) |
There was a problem hiding this comment.
The _ordered_ids._memory list stores trial IDs in the order they were inserted into the LinkedList, which does not necessarily correspond to the logical sequence of the grid traversal. This happens when new hyperparameters are discovered or when using conditional search spaces, causing trials to be inserted in the middle of the list. Persisting _memory directly and then rebuilding the list by assuming its order is the sequence will result in a corrupted linked list upon search resumption. Instead, the linked list should be traversed using its next pointers to capture the correct logical sequence of trial IDs.
| state["ordered_ids"] = list(self._ordered_ids._memory) | |
| ordered_ids = [] | |
| current_index = self._ordered_ids._next_index[None] | |
| while current_index is not None: | |
| ordered_ids.append(self._ordered_ids._memory[current_index]) | |
| current_index = self._ordered_ids._next_index[current_index] | |
| state["ordered_ids"] = ordered_ids |
Closes #1055.
GridSearchOracle.__init__builds two in-memory bookkeeping fields not present inOracle:_ordered_ids: aLinkedListof trial ids in hp-combo order_populate_next: a queue of trial ids ready to produce their next combinationNeither is in
Oracle.get_state/set_stateandGridSearchOracledoesn't override them, so after a process restart withoverwrite=Falsethey reload empty.start_orderrehydrates fine, the firstend_trialpushes a trial id onto_populate_next, and the nextpopulate_spacecalls_ordered_ids.next(old_trial_id)against an empty_data_to_index→KeyError: '<trial_id>'atgridsearch.py:80,197.Override
get_state/set_stateonGridSearchOracleto persist both fields and rebuild theLinkedListfrom them on reload. Older state files (saved before this change) lazily rebuild fromstart_orderso existing checkpoints keep working.Regression tests in
gridsearch_test.pycover:KeyError)start_order