@@ -101,6 +101,7 @@ class PSOSampler(optunahub.samplers.SimpleBaseSampler):
101101
102102 def __init__ (
103103 self ,
104+ * ,
104105 search_space : Optional [Dict [str , BaseDistribution ]] = None ,
105106 n_particles : int = 10 ,
106107 inertia : float = 0.5 ,
@@ -116,6 +117,7 @@ def __init__(
116117 if cognitive < 0.0 or social < 0.0 :
117118 raise ValueError ("cognitive and social must be >= 0.0." )
118119
120+ self .search_space = search_space
119121 self .n_particles : int = n_particles
120122 self .inertia : float = inertia
121123 self .cognitive : float = cognitive
@@ -128,7 +130,6 @@ def __init__(
128130 self .dim : int = 0
129131 # Numeric-only names used for PSO vectorization.
130132 self .param_names : List [str ] = [] # numeric param names
131- self .categorical_param_names : List [str ] = []
132133 self ._numeric_dists : Dict [str , BaseDistribution ] = {}
133134 self .lower_bound : np .ndarray = np .array ([], dtype = float )
134135 self .upper_bound : np .ndarray = np .array ([], dtype = float )
@@ -151,22 +152,18 @@ def _lazy_init(self, search_space: Dict[str, BaseDistribution]) -> None:
151152 """Initialize internal state based on the current search space (numeric-only for PSO)."""
152153 # Split numeric vs. categorical distributions.
153154 self .param_names = []
154- self .categorical_param_names = []
155- self ._numeric_dists = {}
156155
157- for name , dist in search_space .items ():
156+ self ._numeric_dists = {
157+ name : dist
158+ for name , dist in search_space .items ()
158159 if isinstance (
159160 dist ,
160161 (optuna .distributions .FloatDistribution , optuna .distributions .IntDistribution ),
161- ):
162- self .param_names .append (name ) # numeric params used by PSO
163- self ._numeric_dists [name ] = dist
164- elif isinstance (dist , optuna .distributions .CategoricalDistribution ):
165- self .categorical_param_names .append (name )
166- else :
167- # Unknown distribution types are ignored by PSO and will be sampled independently.
168- self .categorical_param_names .append (name )
162+ )
163+ and not dist .single ()
164+ }
169165
166+ self .param_names = sorted (self ._numeric_dists .keys ())
170167 self .dim = len (self .param_names )
171168
172169 if self .dim > 0 :
@@ -194,6 +191,28 @@ def _lazy_init(self, search_space: Dict[str, BaseDistribution]) -> None:
194191
195192 self ._initialized = True
196193
194+ def infer_relative_search_space (
195+ self , study : Study , _ : FrozenTrial
196+ ) -> Dict [str , BaseDistribution ]:
197+ if self .search_space is not None :
198+ return self .search_space
199+
200+ inferred = self ._intersection_search_space .calculate (study )
201+
202+ numeric = {
203+ n : d
204+ for n , d in inferred .items ()
205+ if not d .single ()
206+ and isinstance (
207+ d , (optuna .distributions .FloatDistribution , optuna .distributions .IntDistribution )
208+ )
209+ }
210+
211+ if numeric :
212+ self .search_space = numeric
213+
214+ return numeric
215+
197216 def sample_relative (
198217 self ,
199218 study : Study ,
@@ -210,15 +229,7 @@ def sample_relative(
210229 if len (search_space ) == 0 :
211230 return {}
212231
213- # Re-init if the count of numeric params changed.
214- numeric_count = sum (
215- isinstance (
216- dist ,
217- (optuna .distributions .FloatDistribution , optuna .distributions .IntDistribution ),
218- )
219- for dist in search_space .values ()
220- )
221- if not self ._initialized or self .dim != numeric_count :
232+ if not self ._initialized :
222233 self ._lazy_init (search_space )
223234
224235 # Serve next precomputed numeric candidate if available.
0 commit comments