Skip to content

Commit 86d7319

Browse files
authored
Merge pull request #900 from alan-turing-institute/879-registry
Add emulator registry (#879)
2 parents 067c3b4 + 49898af commit 86d7319

File tree

6 files changed

+495
-96
lines changed

6 files changed

+495
-96
lines changed

autoemulate/core/compare.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,7 @@
3737
TransformedEmulatorParams,
3838
)
3939
from autoemulate.data.utils import ConversionMixin, set_random_seed
40-
from autoemulate.emulators import (
41-
ALL_EMULATORS,
42-
DEFAULT_EMULATORS,
43-
PYTORCH_EMULATORS,
44-
get_emulator_class,
45-
)
40+
from autoemulate.emulators import _default_registry, get_emulator_class
4641
from autoemulate.emulators.base import Emulator, ProbabilisticEmulator
4742
from autoemulate.emulators.transformed.base import TransformedEmulator
4843
from autoemulate.transforms.base import AutoEmulateTransform
@@ -202,22 +197,26 @@ def __init__(
202197
@staticmethod
203198
def all_emulators() -> list[type[Emulator]]:
204199
"""Return a list of all available emulators."""
205-
return ALL_EMULATORS
200+
return _default_registry.all_emulators
206201

207202
@staticmethod
208203
def default_emulators() -> list[type[Emulator]]:
209204
"""Return a list of default emulators used by AutoEmulate."""
210-
return DEFAULT_EMULATORS
205+
return _default_registry.default_emulators
211206

212207
@staticmethod
213208
def pytorch_emulators() -> list[type[Emulator]]:
214209
"""Return a list of all available PyTorch emulators."""
215-
return PYTORCH_EMULATORS
210+
return _default_registry.pytorch_emulators
216211

217212
@staticmethod
218213
def probablistic_emulators() -> list[type[Emulator]]:
219214
"""Return a list of all available probabilistic emulators."""
220-
return [emulator for emulator in ALL_EMULATORS if emulator.supports_uq]
215+
return [
216+
emulator
217+
for emulator in _default_registry.all_emulators
218+
if emulator.supports_uq
219+
]
221220

222221
@staticmethod
223222
def list_emulators(default_only: bool = True) -> pd.DataFrame:
@@ -243,7 +242,11 @@ def list_emulators(default_only: bool = True) -> pd.DataFrame:
243242
- 'Uncertainty_Quantification',
244243
- 'Automatic_Differentiation`
245244
"""
246-
emulator_set = DEFAULT_EMULATORS if default_only else ALL_EMULATORS
245+
emulator_set = (
246+
_default_registry.default_emulators
247+
if default_only
248+
else _default_registry.all_emulators
249+
)
247250
return pd.DataFrame(
248251
{
249252
"Emulator": [emulator.model_name() for emulator in emulator_set],

autoemulate/emulators/__init__.py

Lines changed: 14 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .base import Emulator, GaussianProcessEmulator
1+
from .base import Emulator
22
from .ensemble import EnsembleMLP, EnsembleMLPDropout
33
from .gaussian_process.exact import (
44
GaussianProcessCorrelatedMatern32,
@@ -11,74 +11,22 @@
1111
from .polynomials import PolynomialRegression
1212
from .radial_basis_functions import RadialBasisFunctions
1313
from .random_forest import RandomForest
14+
from .registry import Registry, _default_registry, get_emulator_class, register
1415
from .svm import SupportVectorMachine
1516
from .transformed.base import TransformedEmulator
1617

17-
DEFAULT_EMULATORS: list[type[Emulator]] = [
18-
GaussianProcessMatern32,
19-
GaussianProcessRBF,
20-
RadialBasisFunctions,
21-
PolynomialRegression,
22-
MLP,
23-
EnsembleMLP,
24-
]
25-
26-
# listing non pytorch emulators as we do not expect this list to grow
27-
NON_PYTORCH_EMULATORS: list[type[Emulator]] = [
28-
LightGBM,
29-
SupportVectorMachine,
30-
RandomForest,
31-
]
32-
33-
ALL_EMULATORS: list[type[Emulator]] = [
34-
*DEFAULT_EMULATORS,
35-
*NON_PYTORCH_EMULATORS,
36-
GaussianProcessCorrelatedMatern32,
37-
GaussianProcessCorrelatedRBF,
38-
EnsembleMLPDropout,
39-
]
40-
41-
PYTORCH_EMULATORS: list[type[Emulator]] = [
42-
emulator for emulator in ALL_EMULATORS if emulator not in NON_PYTORCH_EMULATORS
43-
]
44-
GAUSSIAN_PROCESS_EMULATORS: list[type[Emulator]] = [
45-
emulator
46-
for emulator in ALL_EMULATORS
47-
if issubclass(emulator, GaussianProcessEmulator)
48-
]
49-
50-
EMULATOR_REGISTRY = {em_cls.model_name().lower(): em_cls for em_cls in ALL_EMULATORS}
51-
EMULATOR_REGISTRY_SHORT_NAME = {em_cls.short_name(): em_cls for em_cls in ALL_EMULATORS}
52-
53-
54-
def get_emulator_class(name: str) -> type[Emulator]:
55-
"""
56-
Get the emulator class by name.
57-
58-
Parameters
59-
----------
60-
name: str
61-
The name of the emulator class.
62-
63-
Returns
64-
-------
65-
type[Emulator] | None
66-
The emulator class if found, None otherwise.
67-
"""
68-
emulator_cls = EMULATOR_REGISTRY.get(
69-
name.lower()
70-
) or EMULATOR_REGISTRY_SHORT_NAME.get(name.lower())
71-
72-
if emulator_cls is None:
73-
raise ValueError(
74-
f"Unknown emulator name: {name}.Available: {list(EMULATOR_REGISTRY.keys())}"
75-
)
76-
77-
return emulator_cls
78-
18+
# Module-level constants for backward compatibility and simplified public access
19+
DEFAULT_EMULATORS = _default_registry._default_emulators
20+
NON_PYTORCH_EMULATORS = _default_registry._non_pytorch_emulators
21+
ALL_EMULATORS = _default_registry._all_emulators
22+
PYTORCH_EMULATORS = _default_registry._pytorch_emulators
23+
GAUSSIAN_PROCESS_EMULATORS = _default_registry._gaussian_process_emulators
24+
EMULATOR_REGISTRY = _default_registry._emulator_registry
25+
EMULATOR_REGISTRY_SHORT_NAME = _default_registry._emulator_registry_short_name
7926

8027
__all__ = [
8128
"MLP",
29+
"Emulator",
8230
"EnsembleMLP",
8331
"EnsembleMLPDropout",
8432
"GaussianProcessCorrelatedMatern32",
@@ -89,6 +37,9 @@ def get_emulator_class(name: str) -> type[Emulator]:
8937
"PolynomialRegression",
9038
"RadialBasisFunctions",
9139
"RandomForest",
40+
"Registry",
9241
"SupportVectorMachine",
9342
"TransformedEmulator",
43+
"get_emulator_class",
44+
"register",
9445
]

autoemulate/emulators/gaussian_process/exact.py

Lines changed: 52 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -476,18 +476,13 @@ def forward(self, x):
476476
return GaussianProcessLike(mean_x, covar_x)
477477

478478

479-
# GP registry to raise exception if duplicate created
480-
GP_REGISTRY = {
481-
"GaussianProcess": GaussianProcess,
482-
"GaussianProcessCorrelated": GaussianProcessCorrelated,
483-
}
484-
485-
486479
def create_gp_subclass(
487480
name: str,
488481
gp_base_class: type[GaussianProcess],
489482
covar_module_fn: CovarModuleFn,
490483
mean_module_fn: MeanModuleFn = constant_mean,
484+
auto_register: bool = True,
485+
overwrite: bool = True,
491486
**fixed_kwargs,
492487
) -> type[GaussianProcess]:
493488
"""
@@ -496,6 +491,9 @@ def create_gp_subclass(
496491
This function creates a subclass of GaussianProcess where certain parameters
497492
are fixed to specific values, reducing the parameter space for tuning.
498493
494+
The created subclass is automatically registered with the main emulator Registry
495+
(unless auto_register=False), making it discoverable by AutoEmulate.
496+
499497
Parameters
500498
----------
501499
name : str
@@ -506,6 +504,13 @@ def create_gp_subclass(
506504
Covariance module function to use in the subclass.
507505
mean_module_fn : MeanModuleFn
508506
Mean module function to use in the subclass. Defaults to `constant_mean`.
507+
auto_register : bool
508+
Whether to automatically register the created subclass with the main emulator
509+
Registry. Defaults to True.
510+
overwrite : bool
511+
Whether to allow overwriting an existing class with the same name in the
512+
main Registry. Useful for interactive development in notebooks. Defaults to
513+
True.
509514
**fixed_kwargs
510515
Keyword arguments to fix in the subclass. These parameters will be
511516
set to the provided values and excluded from hyperparameter tuning.
@@ -516,17 +521,21 @@ def create_gp_subclass(
516521
A new subclass of GaussianProcess with the specified parameters fixed.
517522
The returned class can be pickled and used like any other GP emulator.
518523
524+
Raises
525+
------
526+
ValueError
527+
If `name` matches `model_name()` or `short_name()` of an already registered
528+
emulator in the main Registry and `overwrite=False`.
529+
519530
Notes
520531
-----
521-
Fixed parameters are automatically excluded from `get_tune_params()` to
522-
prevent them from being included in hyperparameter optimization.
532+
- Fixed parameters are automatically excluded from `get_tune_params()` to prevent
533+
them from being included in hyperparameter optimization.
534+
- Pickling: The created subclass is registered in the caller's module namespace,
535+
ensuring it can be pickled and unpickled correctly even when created in downstream
536+
code that uses autoemulate as a dependency.
537+
- If auto_register=True (default), the class is also added to the main Registry.
523538
"""
524-
if name in GP_REGISTRY:
525-
raise ValueError(
526-
f"A GP class named '{name}' already exists. "
527-
f"Use a unique name or delete the existing class from GP_REGISTRY."
528-
)
529-
530539
standardize_x = fixed_kwargs.get("standardize_x", False)
531540
standardize_y = fixed_kwargs.get("standardize_y", True)
532541
fixed_mean_params = fixed_kwargs.get("fixed_mean_params", False)
@@ -614,30 +623,50 @@ def get_tune_params():
614623
model training and are not fixed.
615624
"""
616625

617-
# Set the provided name for the class
626+
# Determine the caller's module for proper pickling support.
627+
# When called from autoemulate itself, use __name__.
628+
# When called from user code, use the caller's module
629+
caller_frame = sys._getframe(1)
630+
caller_module_name = caller_frame.f_globals.get("__name__", __name__)
631+
632+
# Set the class name and module
618633
GaussianProcessSubclass.__name__ = name
619634
GaussianProcessSubclass.__qualname__ = name
620-
GaussianProcessSubclass.__module__ = __name__
635+
GaussianProcessSubclass.__module__ = caller_module_name
636+
637+
# Register class in the caller's module globals for pickling
638+
# This ensures the class can be pickled/unpickled correctly
639+
caller_frame.f_globals[name] = GaussianProcessSubclass
640+
641+
# Also register in the caller's module if it's a real module (not __main__)
642+
if caller_module_name in sys.modules and caller_module_name != "__main__":
643+
setattr(sys.modules[caller_module_name], name, GaussianProcessSubclass)
644+
645+
# Automatically register with the main emulator Registry if requested
646+
if auto_register:
647+
# Lazy import to avoid circular dependency with __init__.py
648+
from autoemulate.emulators import register # noqa: PLC0415
621649

622-
# Register class in the module's globals so can be pickled
623-
setattr(sys.modules[__name__], name, GaussianProcessSubclass)
624-
# Register subclass
625-
GP_REGISTRY[name] = GaussianProcessSubclass
650+
register(GaussianProcessSubclass, overwrite=overwrite)
626651

627652
return GaussianProcessSubclass
628653

629654

655+
# Built-in GP subclasses - auto_register=False as already registered in Registry init:
656+
# autoemulate/emulators/__init__.py
630657
GaussianProcessRBF = create_gp_subclass(
631658
"GaussianProcessRBF",
632659
GaussianProcess,
633660
covar_module_fn=rbf_kernel,
634661
mean_module_fn=constant_mean,
662+
auto_register=False,
635663
)
636664
GaussianProcessMatern32 = create_gp_subclass(
637665
"GaussianProcessMatern32",
638666
GaussianProcess,
639667
covar_module_fn=matern_3_2_kernel,
640668
mean_module_fn=constant_mean,
669+
auto_register=False,
641670
)
642671

643672
# correlated GP kernels
@@ -646,10 +675,12 @@ def get_tune_params():
646675
GaussianProcessCorrelated,
647676
covar_module_fn=rbf_kernel,
648677
mean_module_fn=constant_mean,
678+
auto_register=False,
649679
)
650680
GaussianProcessCorrelatedMatern32 = create_gp_subclass(
651681
"GaussianProcessCorrelatedMatern32",
652682
GaussianProcessCorrelated,
653683
covar_module_fn=matern_3_2_kernel,
654684
mean_module_fn=constant_mean,
685+
auto_register=False,
655686
)

0 commit comments

Comments
 (0)