|
38 | 38 |
|
39 | 39 | import torch |
40 | 40 | from botorch.exceptions.errors import UnsupportedError |
41 | | -from botorch.fit import FitGPyTorchMLL |
42 | 41 | from botorch.models import SingleTaskGP |
43 | 42 | from botorch.models.likelihoods.sparse_outlier_noise import ( |
44 | 43 | SparseOutlierGaussianLikelihood, |
@@ -157,6 +156,103 @@ def load_standard_model(self, standard_model: Model) -> Self: |
157 | 156 | self.load_state_dict(standard_model.state_dict()) |
158 | 157 | return self |
159 | 158 |
|
def custom_fit(
    self,
    mll: MarginalLogLikelihood,
    *,
    numbers_of_outliers: list[int] | None = None,
    fractions_of_outliers: list[float] | None = None,
    timeout_sec: float | None = None,
    relevance_pursuit_optimizer: Callable = backward_relevance_pursuit,
    reset_parameters: bool = True,
    reset_dense_parameters: bool = False,
    closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]] | None = None,
    optimizer: Callable | None = None,
    closure_kwargs: dict[str, Any] | None = None,
    optimizer_kwargs: Mapping[str, Any] | None = None,
) -> MarginalLogLikelihood:
    """Fits a RobustRelevancePursuitGP model using the given marginal likelihood.

    For details, see [Ament2024pursuit]_ or https://arxiv.org/abs/2410.24222.

    The routine operates on ``mll.model``: it temporarily replaces it with the
    result of ``mll.model.to_standard_model()``, runs the relevance pursuit
    optimizer, selects the model with the highest Bayesian model comparison
    (BMC) probability from the recorded trace, and loads that model back into
    the original ``mll.model``. The original model is always re-attached to
    ``mll``, even if the optimizer or the model comparison raises.

    Args:
        mll: The marginal likelihood to fit.
        numbers_of_outliers: An optional list of numbers of outliers to consider
            during relevance pursuit. By default, the algorithm falls back to a
            default list of fractions of outliers, see below.
        fractions_of_outliers: An optional list of fractions of outliers to
            consider if numbers_of_outliers is None. By default, the algorithm
            uses ``[0, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.75, 1.0]``.
        timeout_sec: An optional total time budget. If given, it is divided
            evenly by the number of entries in ``numbers_of_outliers`` and the
            result is passed on as ``optimizer_kwargs["timeout_sec"]``,
            overriding any value already present under that key.
        relevance_pursuit_optimizer: The relevance pursuit optimizer to use.
        reset_parameters: If True, reset sparse parameters after each iteration.
        reset_dense_parameters: If True, reset dense parameters after each
            iteration.
        closure: A closure to compute loss and gradients.
        optimizer: The numerical optimizer.
        closure_kwargs: Additional arguments to pass to the closure.
        optimizer_kwargs: Additional arguments to pass to fit_gpytorch_mll.
            The mapping is copied and never modified in place.

    Returns:
        The fitted marginal likelihood. Its model additionally carries the
        attributes ``bmc_support_sizes`` and ``bmc_probabilities`` and, if
        ``mll.model.cache_model_trace`` is truthy, ``model_trace``.

    Raises:
        UnsupportedError: If ``mll`` is an approximate marginal log likelihood.
    """
    if isinstance(mll, _ApproximateMarginalLogLikelihood):
        raise UnsupportedError(
            "Relevance Pursuit does not yet support approximate inference. "
        )

    # The dimension of the sparse module equals the number of training points.
    n = SparseOutlierNoise._from_model(mll.model).dim

    if numbers_of_outliers is None:
        if fractions_of_outliers is None:
            fractions_of_outliers = FRACTIONS_OF_OUTLIERS

        # list from which BMC chooses
        numbers_of_outliers = [int(p * n) for p in fractions_of_outliers]

    # Copy so that the caller's mapping is never mutated.
    optimizer_kwargs_: dict[str, Any] = (
        {} if optimizer_kwargs is None else dict(optimizer_kwargs)
    )
    if timeout_sec is not None:
        # Split the total budget evenly across the sparsity levels.
        optimizer_kwargs_["timeout_sec"] = timeout_sec / len(numbers_of_outliers)

    # Need to convert model to avoid recursion through fit_gpytorch_mll,
    # since relevance pursuit expects to call the base fit_gpytorch_mll.
    original_model = mll.model  # Robust Relevance Pursuit Model
    mll.model = original_model.to_standard_model()
    try:
        sparse_module = SparseOutlierNoise._from_model(mll.model)
        _, model_trace = relevance_pursuit_optimizer(
            sparse_module=sparse_module,
            mll=mll,
            sparsity_levels=numbers_of_outliers,
            reset_parameters=reset_parameters,
            reset_dense_parameters=reset_dense_parameters,
            record_model_trace=True,
            # These are the args of the canonical mll fit routine
            closure=closure,
            optimizer=optimizer,
            closure_kwargs=closure_kwargs,
            optimizer_kwargs=optimizer_kwargs_,
        )

        # Bayesian model comparison
        bmc_support_sizes, bmc_probabilities = get_posterior_over_support(
            SparseOutlierNoise,
            model_trace,
            prior_mean_of_support=original_model.prior_mean_of_support,
        )
    finally:
        # Always restore the original model pointer so that a failure above
        # does not leave the caller's mll attached to the temporary model.
        mll.model = original_model

    map_index = torch.argmax(bmc_probabilities)
    map_model = model_trace[map_index]  # choosing model with highest BMC score
    # overwrite the state of mll.model with that of the chosen model
    mll.model.load_standard_model(map_model)
    # Store the bmc results
    mll.model.bmc_support_sizes = bmc_support_sizes
    mll.model.bmc_probabilities = bmc_probabilities
    if mll.model.cache_model_trace:
        mll.model.model_trace = model_trace
    return mll
| 255 | + |
160 | 256 |
|
161 | 257 | class RobustRelevancePursuitSingleTaskGP(SingleTaskGP, RobustRelevancePursuitMixin): |
162 | 258 | def __init__( |
@@ -252,127 +348,3 @@ def to_standard_model(self) -> Model: |
252 | 348 | if not is_training: |
253 | 349 | model.eval() |
254 | 350 | return model |
255 | | - |
256 | | - |
@FitGPyTorchMLL.register(
    MarginalLogLikelihood,
    SparseOutlierGaussianLikelihood,
    RobustRelevancePursuitMixin,
)
def _fit_rrp(
    mll: MarginalLogLikelihood,
    _: type[SparseOutlierGaussianLikelihood],
    __: type[RobustRelevancePursuitMixin],
    *,
    numbers_of_outliers: list[int] | None = None,
    fractions_of_outliers: list[float] | None = None,
    timeout_sec: float | None = None,
    relevance_pursuit_optimizer: Callable = backward_relevance_pursuit,
    reset_parameters: bool = True,
    reset_dense_parameters: bool = False,
    # fit_gpytorch_mll kwargs
    closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]] | None = None,
    optimizer: Callable | None = None,
    closure_kwargs: dict[str, Any] | None = None,
    optimizer_kwargs: Mapping[str, Any] | None = None,
) -> MarginalLogLikelihood:
    """Fits a RobustRelevancePursuitGP model using the given marginal likelihood.

    For details, see [Ament2024pursuit]_ or https://arxiv.org/abs/2410.24222.

    While the relevance pursuit optimizer runs, ``mll.model`` is temporarily
    replaced by ``mll.model.to_standard_model()``. The original model is always
    re-attached to ``mll``, even if the optimizer or the subsequent Bayesian
    model comparison raises.

    Args:
        mll: The marginal likelihood to fit.
        _: A likelihood, only directly used for dispatching.
        __: A model, only directly used for dispatching.
        numbers_of_outliers: An optional list of numbers of outliers to consider during
            relevance pursuit. By default, the algorithm falls back to a default list
            of fractions of outliers, see below.
        fractions_of_outliers: An optional list of fractions of outliers to consider if
            numbers_of_outliers is None. By default, the algorithm uses
            ``[0, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.75, 1.0]``.
        timeout_sec: An optional total time budget. If given, it is divided evenly
            by the number of entries in ``numbers_of_outliers`` and the result is
            passed on as ``optimizer_kwargs["timeout_sec"]``, overriding any value
            already present under that key.
        relevance_pursuit_optimizer: The relevance pursuit optimizer to use. By default,
            uses ``backward_relevance_pursuit``, which is generally the most powerful
            algorithm for challenging problems with a wide range of outliers. The
            ``forward_relevance_pursuit`` algorithm can be efficient when the number of
            outliers is relatively small.
        reset_parameters: If True, we will reset the sparse parameters of the model
            after each iteration of the relevance pursuit algorithm.
        reset_dense_parameters: If True, we will reset the dense parameters of the model
            after each iteration of the relevance pursuit algorithm.
        closure: A closure to use to compute the loss and the gradients, see docstring
            of ``fit_gpytorch_mll`` for details.
        optimizer: The numerical optimizer, see docstring of ``fit_gpytorch_mll``.
        closure_kwargs: Additional arguments to pass to the ``closure`` function.
        optimizer_kwargs: Additional arguments to pass to ``fit_gpytorch_mll``.
            The mapping is copied and never modified in place.

    Returns:
        The fitted marginal likelihood. Its model additionally carries the
        attributes ``bmc_support_sizes`` and ``bmc_probabilities`` and, if
        ``mll.model.cache_model_trace`` is truthy, ``model_trace``.
    """
    # The dimension of the sparse module equals the number of training points.
    n = SparseOutlierNoise._from_model(mll.model).dim

    if numbers_of_outliers is None:
        if fractions_of_outliers is None:
            fractions_of_outliers = FRACTIONS_OF_OUTLIERS

        # list from which BMC chooses
        numbers_of_outliers = [int(p * n) for p in fractions_of_outliers]

    # Copy so that the caller's mapping is never mutated.
    optimizer_kwargs_: dict[str, Any] = (
        {} if optimizer_kwargs is None else dict(optimizer_kwargs)
    )
    if timeout_sec is not None:
        # Split the total budget evenly across the sparsity levels.
        optimizer_kwargs_["timeout_sec"] = timeout_sec / len(numbers_of_outliers)

    # Need to convert model to avoid recursion through fit_gpytorch_mll dispatch, since
    # relevance pursuit expects to call the base fit_gpytorch_mll.
    original_model = mll.model  # Robust Relevance Pursuit Model
    mll.model = original_model.to_standard_model()
    try:
        sparse_module = SparseOutlierNoise._from_model(mll.model)
        _, model_trace = relevance_pursuit_optimizer(
            sparse_module=sparse_module,
            mll=mll,
            sparsity_levels=numbers_of_outliers,
            reset_parameters=reset_parameters,
            reset_dense_parameters=reset_dense_parameters,
            record_model_trace=True,
            # These are the args of the canonical mll fit routine
            closure=closure,
            optimizer=optimizer,
            closure_kwargs=closure_kwargs,
            optimizer_kwargs=optimizer_kwargs_,
        )

        # Bayesian model comparison
        bmc_support_sizes, bmc_probabilities = get_posterior_over_support(
            SparseOutlierNoise,
            model_trace,
            prior_mean_of_support=original_model.prior_mean_of_support,
        )
    finally:
        # Always restore the original model pointer so that a failure above
        # does not leave the caller's mll attached to the temporary model.
        mll.model = original_model

    map_index = torch.argmax(bmc_probabilities)
    map_model = model_trace[map_index]  # choosing model with highest BMC score
    # overwrite the state of mll.model with that of the chosen model
    mll.model.load_standard_model(map_model)
    # Store the bmc results
    mll.model.bmc_support_sizes = bmc_support_sizes
    mll.model.bmc_probabilities = bmc_probabilities
    if mll.model.cache_model_trace:
        mll.model.model_trace = model_trace
    return mll
363 | | - |
364 | | - |
@FitGPyTorchMLL.register(
    _ApproximateMarginalLogLikelihood,
    SparseOutlierGaussianLikelihood,
    RobustRelevancePursuitMixin,
)
def _fit_rrp_approximate_mll(
    mll: _ApproximateMarginalLogLikelihood,
    _: type[SparseOutlierGaussianLikelihood],
    __: type[RobustRelevancePursuitMixin],
    **kwargs: Any,
) -> None:
    """Reject fitting relevance pursuit models with an approximate MLL.

    Registered with ``FitGPyTorchMLL`` for the combination of an approximate
    marginal log likelihood, a ``SparseOutlierGaussianLikelihood`` and a
    ``RobustRelevancePursuitMixin`` model. It never returns normally.

    Args:
        mll: The approximate marginal log likelihood; not used.
        _: A likelihood type, only used for dispatching.
        __: A model type, only used for dispatching.
        **kwargs: Accepted and ignored.

    Raises:
        UnsupportedError: Always.
    """
    message = "Relevance Pursuit does not yet support approximate inference. "
    raise UnsupportedError(message)
0 commit comments