1- from collections .abc import Callable
2- from typing import overload
3-
4- import torch
5-
6- from .base import Emulator , GaussianProcessEmulator
1+ from .base import Emulator
72from .ensemble import EnsembleMLP , EnsembleMLPDropout
83from .gaussian_process .exact import (
94 GaussianProcessCorrelatedMatern32 ,
1611from .polynomials import PolynomialRegression
1712from .radial_basis_functions import RadialBasisFunctions
1813from .random_forest import RandomForest
14+ from .registry import Registry , _default_registry , get_emulator_class , register
1915from .svm import SupportVectorMachine
2016from .transformed .base import TransformedEmulator
2117
22-
23- class Registry :
24- """Registry for managing emulators.
25-
26- The Registry class maintains collections of emulator classes organized by
27- their properties (e.g., Gaussian Process emulators, PyTorch-based emulators).
28- It provides methods to register new emulators and retrieve them by name.
29- """
30-
31- def __init__ (self ):
32- # Initialize the registry with default emulators, this is not updated
33- self ._default_emulators : list [type [Emulator ]] = [
34- GaussianProcessMatern32 ,
35- GaussianProcessRBF ,
36- RadialBasisFunctions ,
37- PolynomialRegression ,
38- MLP ,
39- EnsembleMLP ,
40- ]
41-
42- self ._non_pytorch_emulators : list [type [Emulator ]] = [
43- LightGBM ,
44- SupportVectorMachine ,
45- RandomForest ,
46- ]
47-
48- self ._all_emulators : list [type [Emulator ]] = [
49- * self ._default_emulators ,
50- * self ._non_pytorch_emulators ,
51- GaussianProcessCorrelatedMatern32 ,
52- GaussianProcessCorrelatedRBF ,
53- EnsembleMLPDropout ,
54- ]
55-
56- self ._pytorch_emulators : list [type [Emulator ]] = [
57- emulator
58- for emulator in self ._all_emulators
59- if emulator not in self ._non_pytorch_emulators
60- ]
61- self ._gaussian_process_emulators : list [type [Emulator ]] = [
62- emulator
63- for emulator in self ._all_emulators
64- if issubclass (emulator , GaussianProcessEmulator )
65- ]
66-
67- self ._emulator_registry = {
68- em_cls .model_name ().lower (): em_cls for em_cls in self ._all_emulators
69- }
70- self ._emulator_registry_short_name = {
71- em_cls .short_name (): em_cls for em_cls in self ._all_emulators
72- }
73-
74- def register_model (
75- self , model_cls : type [Emulator ], overwrite : bool = False
76- ) -> type [Emulator ]:
77- """Register a new emulator model to the registry.
78-
79- Parameters
80- ----------
81- model_cls: type[Emulator]
82- The emulator class to register.
83- overwrite: bool
84- If True, allows overwriting an existing model with the same name. If False,
85- raises an error if a model with the same name already exists. Defaults to
86- False.
87-
88- Returns
89- -------
90- type[Emulator]
91- The registered emulator class (unchanged).
92-
93- Raises
94- ------
95- ValueError
96- If overwrite is False and a model with the same name already exists.
97-
98- """
99- model_name = model_cls .model_name ().lower ()
100- short_name = model_cls .short_name ()
101-
102- # Check if model already exists
103- existing_cls_by_name = self ._emulator_registry .get (model_name )
104- existing_cls_by_short_name = self ._emulator_registry_short_name .get (short_name )
105-
106- if not overwrite :
107- if existing_cls_by_name is not None :
108- raise ValueError (
109- f"Model with name '{ model_name } ' already exists. Set overwrite=True"
110- f" to replace it."
111- )
112- if existing_cls_by_short_name is not None :
113- raise ValueError (
114- f"Model with short name '{ short_name } ' already exists. Set "
115- f"overwrite=True to replace it."
116- )
117-
118- # If overwriting, remove the old model from all lists
119- if (
120- existing_cls_by_name is not None
121- and existing_cls_by_name in self ._all_emulators
122- ):
123- self ._all_emulators .remove (existing_cls_by_name )
124- if existing_cls_by_name in self ._gaussian_process_emulators :
125- self ._gaussian_process_emulators .remove (existing_cls_by_name )
126- if existing_cls_by_name in self ._pytorch_emulators :
127- self ._pytorch_emulators .remove (existing_cls_by_name )
128- if existing_cls_by_name in self ._non_pytorch_emulators :
129- self ._non_pytorch_emulators .remove (existing_cls_by_name )
130-
131- # Add to all_emulators if not already present
132- if model_cls not in self ._all_emulators :
133- self ._all_emulators .append (model_cls )
134-
135- # Update registries
136- self ._emulator_registry [model_name ] = model_cls
137- self ._emulator_registry_short_name [short_name ] = model_cls
138-
139- # Add the gaussian process emulator list if a GaussianProcessEmulator subclass
140- if (
141- issubclass (model_cls , GaussianProcessEmulator )
142- and model_cls not in self ._gaussian_process_emulators
143- ):
144- self ._gaussian_process_emulators .append (model_cls )
145-
146- # Check if it's a PyTorch emulator (subclass of torch.nn.Module) and it's not in
147- # PyTorch list
148- if (
149- issubclass (model_cls , torch .nn .Module )
150- and model_cls not in self ._pytorch_emulators
151- ):
152- self ._pytorch_emulators .append (model_cls )
153- # Check if not a PyTorch emulator and it's not in non-PyTorch list)
154- if (
155- not issubclass (model_cls , torch .nn .Module )
156- and model_cls not in self ._non_pytorch_emulators
157- ):
158- self ._non_pytorch_emulators .append (model_cls )
159-
160- return model_cls
161-
162- @property
163- def gaussian_process_emulators (self ) -> list [type [Emulator ]]:
164- """Return the list of Gaussian Process emulators."""
165- return self ._gaussian_process_emulators
166-
167- @property
168- def pytorch_emulators (self ) -> list [type [Emulator ]]:
169- """Return the list of PyTorch-based emulators."""
170- return self ._pytorch_emulators
171-
172- @property
173- def all_emulators (self ) -> list [type [Emulator ]]:
174- """Return the list of all registered emulators."""
175- return self ._all_emulators
176-
177- @property
178- def non_pytorch_emulators (self ) -> list [type [Emulator ]]:
179- """Return the list of non-PyTorch emulators."""
180- return self ._non_pytorch_emulators
181-
182- @property
183- def default_emulators (self ) -> list [type [Emulator ]]:
184- """Return the list of default emulators."""
185- return self ._default_emulators
186-
187- def get_emulator_class (self , name : str ) -> type [Emulator ]:
188- """Get the emulator class by name or short name.
189-
190- Parameters
191- ----------
192- name: str
193- Either the name or short name of the emulator class.
194-
195- Returns
196- -------
197- type[Emulator]
198- The emulator class if found.
199-
200- Raises
201- ------
202- ValueError
203- If the emulator name is not found.
204- """
205- emulator_cls = self ._emulator_registry .get (
206- name .lower ()
207- ) or self ._emulator_registry_short_name .get (name .lower ())
208-
209- if emulator_cls is None :
210- raise ValueError (
211- f"Unknown emulator name: { name } ."
212- f"Available: { list (self ._emulator_registry .keys ())} "
213- )
214-
215- return emulator_cls
216-
217-
218- # Create a default registry instance
219- _default_registry = Registry ()
220-
22118# Module-level constants for backward compatibility and simplified public access
22219DEFAULT_EMULATORS = _default_registry ._default_emulators
22320NON_PYTORCH_EMULATORS = _default_registry ._non_pytorch_emulators
@@ -227,77 +24,9 @@ def get_emulator_class(self, name: str) -> type[Emulator]:
22724EMULATOR_REGISTRY = _default_registry ._emulator_registry
22825EMULATOR_REGISTRY_SHORT_NAME = _default_registry ._emulator_registry_short_name
22926
230-
231- def get_emulator_class (name : str ) -> type [Emulator ]:
232- """Get the emulator class by name or short name using the default registry.
233-
234- Parameters
235- ----------
236- name: str
237- The name or short name of the emulator class.
238-
239- Returns
240- -------
241- type[Emulator]
242- The emulator class if found.
243- """
244- return _default_registry .get_emulator_class (name )
245-
246-
247- # Overload signatures for type checking
248- @overload
249- def register (model_cls : type [Emulator ]) -> type [Emulator ]: ...
250-
251-
252- @overload
253- def register (model_cls : type [Emulator ], * , overwrite : bool ) -> type [Emulator ]: ...
254-
255-
256- @overload
257- def register (* , overwrite : bool ) -> Callable [[type [Emulator ]], type [Emulator ]]: ...
258-
259-
260- # Actual implementation
261- def register (
262- model_cls : type [Emulator ] | None = None , * , overwrite : bool = False
263- ) -> type [Emulator ] | Callable [[type [Emulator ]], type [Emulator ]]:
264- """Register a new emulator model to the default registry.
265-
266- Can be used as a function, a decorator without arguments, or a decorator with
267- arguments.
268-
269- Parameters
270- ----------
271- model_cls: type[Emulator] | None
272- The emulator class to register. If None, returns a decorator function.
273- overwrite: bool
274- If True, allows overwriting an existing model with the same name. If False,
275- raises an error if a model with the same name already exists. Defaults to False.
276-
277- Returns
278- -------
279- type[Emulator] | Callable[[type[Emulator]], type[Emulator]]
280- The registered emulator class (unchanged) or a decorator function.
281-
282- Raises
283- ------
284- ValueError
285- If overwrite is False and a model with the same name already exists.
286-
287- """
288-
289- def decorator (cls : type [Emulator ]) -> type [Emulator ]:
290- return _default_registry .register_model (cls , overwrite = overwrite )
291-
292- if model_cls is None :
293- # Called as @register(overwrite=...) or @register()
294- return decorator
295- # Called as @register or register(MyClass)
296- return decorator (model_cls )
297-
298-
29927__all__ = [
30028 "MLP" ,
29+ "Emulator" ,
30130 "EnsembleMLP" ,
30231 "EnsembleMLPDropout" ,
30332 "GaussianProcessCorrelatedMatern32" ,
0 commit comments