Skip to content

Add HilbertGridSearchOptimizer for space-filling curve traversal #70

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions src/gradient_free_optimizers/optimizers/grid/grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ..base_optimizer import BaseOptimizer
from .diagonal_grid_search import DiagonalGridSearchOptimizer
from .orthogonal_grid_search import OrthogonalGridSearchOptimizer
from .hilbert_grid_search import HilbertGridSearchOptimizer


class GridSearchOptimizer(BaseOptimizer):
Expand Down Expand Up @@ -58,14 +59,23 @@ def __init__(
nth_process=nth_process,
step_size=step_size,
)
elif direction == "hilbert":
self.grid_search_opt = HilbertGridSearchOptimizer(
search_space=search_space,
initialize=initialize,
constraints=constraints,
random_state=random_state,
rand_rest_p=rand_rest_p,
nth_process=nth_process,
step_size=step_size,
)
else:
msg = ""
raise Exception(msg)
raise ValueError("Invalid direction. Choose 'orthogonal', 'diagonal', or 'hilbert'.")

@BaseOptimizer.track_new_pos
def iterate(self):
return self.grid_search_opt.iterate()

@BaseOptimizer.track_new_score
def evaluate(self, score_new):
self.grid_search_opt.evaluate(score_new)
self.grid_search_opt.evaluate(score_new)
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# gradient_free_optimizers/hilbert_grid_search.py
# Author: Simon Blanke
# Email: [email protected]
# License: MIT License

import numpy as np
from numpy_hilbert_curve import decode
from ..base_optimizer import BaseOptimizer


class HilbertGridSearchOptimizer(BaseOptimizer):
def __init__(
self,
search_space,
initialize={"grid": 4, "random": 2, "vertices": 4},
constraints=[],
random_state=None,
rand_rest_p=0,
nth_process=None,
step_size=1,
):
super().__init__(
search_space=search_space,
initialize=initialize,
constraints=constraints,
random_state=random_state,
rand_rest_p=rand_rest_p,
nth_process=nth_process,
)
self.step_size = step_size
self.Z = 0 # Current Hilbert integer
self.valid_count = 0 # Counter for valid points

def hilbert_move(self):
while True:
# Decode the current Hilbert integer to get an nD point
point = decode(np.array([self.Z]), self.conv.n_dim, self.conv.n_dim)[0]
self.Z += 1
# Check if the point is within the grid bounds
if all(point[i] < self.conv.dim_sizes[i] for i in range(self.conv.n_dim)):
self.valid_count += 1
# Take every step_size-th valid point
if self.valid_count % self.step_size == 1:
return np.array(point)
# Continue if point is out of bounds

@BaseOptimizer.track_new_pos
def iterate(self):
pos_new = self.hilbert_move()
pos_new = self.conv2pos(pos_new)
return pos_new

@BaseOptimizer.track_new_score
def evaluate(self, score_new):
BaseOptimizer.evaluate(self, score_new)