Skip to content

Commit 49898af

Browse files
committed
Refactor to registry submodule
1 parent 6175c15 commit 49898af

File tree

2 files changed

+289
-274
lines changed

2 files changed

+289
-274
lines changed

autoemulate/emulators/__init__.py

Lines changed: 3 additions & 274 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,4 @@
1-
from collections.abc import Callable
2-
from typing import overload
3-
4-
import torch
5-
6-
from .base import Emulator, GaussianProcessEmulator
1+
from .base import Emulator
72
from .ensemble import EnsembleMLP, EnsembleMLPDropout
83
from .gaussian_process.exact import (
94
GaussianProcessCorrelatedMatern32,
@@ -16,208 +11,10 @@
1611
from .polynomials import PolynomialRegression
1712
from .radial_basis_functions import RadialBasisFunctions
1813
from .random_forest import RandomForest
14+
from .registry import Registry, _default_registry, get_emulator_class, register
1915
from .svm import SupportVectorMachine
2016
from .transformed.base import TransformedEmulator
2117

22-
23-
class Registry:
24-
"""Registry for managing emulators.
25-
26-
The Registry class maintains collections of emulator classes organized by
27-
their properties (e.g., Gaussian Process emulators, PyTorch-based emulators).
28-
It provides methods to register new emulators and retrieve them by name.
29-
"""
30-
31-
def __init__(self):
32-
# Initialize the registry with default emulators, this is not updated
33-
self._default_emulators: list[type[Emulator]] = [
34-
GaussianProcessMatern32,
35-
GaussianProcessRBF,
36-
RadialBasisFunctions,
37-
PolynomialRegression,
38-
MLP,
39-
EnsembleMLP,
40-
]
41-
42-
self._non_pytorch_emulators: list[type[Emulator]] = [
43-
LightGBM,
44-
SupportVectorMachine,
45-
RandomForest,
46-
]
47-
48-
self._all_emulators: list[type[Emulator]] = [
49-
*self._default_emulators,
50-
*self._non_pytorch_emulators,
51-
GaussianProcessCorrelatedMatern32,
52-
GaussianProcessCorrelatedRBF,
53-
EnsembleMLPDropout,
54-
]
55-
56-
self._pytorch_emulators: list[type[Emulator]] = [
57-
emulator
58-
for emulator in self._all_emulators
59-
if emulator not in self._non_pytorch_emulators
60-
]
61-
self._gaussian_process_emulators: list[type[Emulator]] = [
62-
emulator
63-
for emulator in self._all_emulators
64-
if issubclass(emulator, GaussianProcessEmulator)
65-
]
66-
67-
self._emulator_registry = {
68-
em_cls.model_name().lower(): em_cls for em_cls in self._all_emulators
69-
}
70-
self._emulator_registry_short_name = {
71-
em_cls.short_name(): em_cls for em_cls in self._all_emulators
72-
}
73-
74-
def register_model(
75-
self, model_cls: type[Emulator], overwrite: bool = False
76-
) -> type[Emulator]:
77-
"""Register a new emulator model to the registry.
78-
79-
Parameters
80-
----------
81-
model_cls: type[Emulator]
82-
The emulator class to register.
83-
overwrite: bool
84-
If True, allows overwriting an existing model with the same name. If False,
85-
raises an error if a model with the same name already exists. Defaults to
86-
False.
87-
88-
Returns
89-
-------
90-
type[Emulator]
91-
The registered emulator class (unchanged).
92-
93-
Raises
94-
------
95-
ValueError
96-
If overwrite is False and a model with the same name already exists.
97-
98-
"""
99-
model_name = model_cls.model_name().lower()
100-
short_name = model_cls.short_name()
101-
102-
# Check if model already exists
103-
existing_cls_by_name = self._emulator_registry.get(model_name)
104-
existing_cls_by_short_name = self._emulator_registry_short_name.get(short_name)
105-
106-
if not overwrite:
107-
if existing_cls_by_name is not None:
108-
raise ValueError(
109-
f"Model with name '{model_name}' already exists. Set overwrite=True"
110-
f" to replace it."
111-
)
112-
if existing_cls_by_short_name is not None:
113-
raise ValueError(
114-
f"Model with short name '{short_name}' already exists. Set "
115-
f"overwrite=True to replace it."
116-
)
117-
118-
# If overwriting, remove the old model from all lists
119-
if (
120-
existing_cls_by_name is not None
121-
and existing_cls_by_name in self._all_emulators
122-
):
123-
self._all_emulators.remove(existing_cls_by_name)
124-
if existing_cls_by_name in self._gaussian_process_emulators:
125-
self._gaussian_process_emulators.remove(existing_cls_by_name)
126-
if existing_cls_by_name in self._pytorch_emulators:
127-
self._pytorch_emulators.remove(existing_cls_by_name)
128-
if existing_cls_by_name in self._non_pytorch_emulators:
129-
self._non_pytorch_emulators.remove(existing_cls_by_name)
130-
131-
# Add to all_emulators if not already present
132-
if model_cls not in self._all_emulators:
133-
self._all_emulators.append(model_cls)
134-
135-
# Update registries
136-
self._emulator_registry[model_name] = model_cls
137-
self._emulator_registry_short_name[short_name] = model_cls
138-
139-
# Add the gaussian process emulator list if a GaussianProcessEmulator subclass
140-
if (
141-
issubclass(model_cls, GaussianProcessEmulator)
142-
and model_cls not in self._gaussian_process_emulators
143-
):
144-
self._gaussian_process_emulators.append(model_cls)
145-
146-
# Check if it's a PyTorch emulator (subclass of torch.nn.Module) and it's not in
147-
# PyTorch list
148-
if (
149-
issubclass(model_cls, torch.nn.Module)
150-
and model_cls not in self._pytorch_emulators
151-
):
152-
self._pytorch_emulators.append(model_cls)
153-
# Check if not a PyTorch emulator and it's not in non-PyTorch list)
154-
if (
155-
not issubclass(model_cls, torch.nn.Module)
156-
and model_cls not in self._non_pytorch_emulators
157-
):
158-
self._non_pytorch_emulators.append(model_cls)
159-
160-
return model_cls
161-
162-
@property
163-
def gaussian_process_emulators(self) -> list[type[Emulator]]:
164-
"""Return the list of Gaussian Process emulators."""
165-
return self._gaussian_process_emulators
166-
167-
@property
168-
def pytorch_emulators(self) -> list[type[Emulator]]:
169-
"""Return the list of PyTorch-based emulators."""
170-
return self._pytorch_emulators
171-
172-
@property
173-
def all_emulators(self) -> list[type[Emulator]]:
174-
"""Return the list of all registered emulators."""
175-
return self._all_emulators
176-
177-
@property
178-
def non_pytorch_emulators(self) -> list[type[Emulator]]:
179-
"""Return the list of non-PyTorch emulators."""
180-
return self._non_pytorch_emulators
181-
182-
@property
183-
def default_emulators(self) -> list[type[Emulator]]:
184-
"""Return the list of default emulators."""
185-
return self._default_emulators
186-
187-
def get_emulator_class(self, name: str) -> type[Emulator]:
188-
"""Get the emulator class by name or short name.
189-
190-
Parameters
191-
----------
192-
name: str
193-
Either the name or short name of the emulator class.
194-
195-
Returns
196-
-------
197-
type[Emulator]
198-
The emulator class if found.
199-
200-
Raises
201-
------
202-
ValueError
203-
If the emulator name is not found.
204-
"""
205-
emulator_cls = self._emulator_registry.get(
206-
name.lower()
207-
) or self._emulator_registry_short_name.get(name.lower())
208-
209-
if emulator_cls is None:
210-
raise ValueError(
211-
f"Unknown emulator name: {name}."
212-
f"Available: {list(self._emulator_registry.keys())}"
213-
)
214-
215-
return emulator_cls
216-
217-
218-
# Create a default registry instance
219-
_default_registry = Registry()
220-
22118
# Module-level constants for backward compatibility and simplified public access
22219
DEFAULT_EMULATORS = _default_registry._default_emulators
22320
NON_PYTORCH_EMULATORS = _default_registry._non_pytorch_emulators
@@ -227,77 +24,9 @@ def get_emulator_class(self, name: str) -> type[Emulator]:
22724
EMULATOR_REGISTRY = _default_registry._emulator_registry
22825
EMULATOR_REGISTRY_SHORT_NAME = _default_registry._emulator_registry_short_name
22926

230-
231-
def get_emulator_class(name: str) -> type[Emulator]:
232-
"""Get the emulator class by name or short name using the default registry.
233-
234-
Parameters
235-
----------
236-
name: str
237-
The name or short name of the emulator class.
238-
239-
Returns
240-
-------
241-
type[Emulator]
242-
The emulator class if found.
243-
"""
244-
return _default_registry.get_emulator_class(name)
245-
246-
247-
# Overload signatures for type checking
248-
@overload
249-
def register(model_cls: type[Emulator]) -> type[Emulator]: ...
250-
251-
252-
@overload
253-
def register(model_cls: type[Emulator], *, overwrite: bool) -> type[Emulator]: ...
254-
255-
256-
@overload
257-
def register(*, overwrite: bool) -> Callable[[type[Emulator]], type[Emulator]]: ...
258-
259-
260-
# Actual implementation
261-
def register(
262-
model_cls: type[Emulator] | None = None, *, overwrite: bool = False
263-
) -> type[Emulator] | Callable[[type[Emulator]], type[Emulator]]:
264-
"""Register a new emulator model to the default registry.
265-
266-
Can be used as a function, a decorator without arguments, or a decorator with
267-
arguments.
268-
269-
Parameters
270-
----------
271-
model_cls: type[Emulator] | None
272-
The emulator class to register. If None, returns a decorator function.
273-
overwrite: bool
274-
If True, allows overwriting an existing model with the same name. If False,
275-
raises an error if a model with the same name already exists. Defaults to False.
276-
277-
Returns
278-
-------
279-
type[Emulator] | Callable[[type[Emulator]], type[Emulator]]
280-
The registered emulator class (unchanged) or a decorator function.
281-
282-
Raises
283-
------
284-
ValueError
285-
If overwrite is False and a model with the same name already exists.
286-
287-
"""
288-
289-
def decorator(cls: type[Emulator]) -> type[Emulator]:
290-
return _default_registry.register_model(cls, overwrite=overwrite)
291-
292-
if model_cls is None:
293-
# Called as @register(overwrite=...) or @register()
294-
return decorator
295-
# Called as @register or register(MyClass)
296-
return decorator(model_cls)
297-
298-
29927
__all__ = [
30028
"MLP",
29+
"Emulator",
30130
"EnsembleMLP",
30231
"EnsembleMLPDropout",
30332
"GaussianProcessCorrelatedMatern32",

0 commit comments

Comments
 (0)