@@ -476,18 +476,13 @@ def forward(self, x):
476476 return GaussianProcessLike (mean_x , covar_x )
477477
478478
479- # GP registry to raise exception if duplicate created
480- GP_REGISTRY = {
481- "GaussianProcess" : GaussianProcess ,
482- "GaussianProcessCorrelated" : GaussianProcessCorrelated ,
483- }
484-
485-
486479def create_gp_subclass (
487480 name : str ,
488481 gp_base_class : type [GaussianProcess ],
489482 covar_module_fn : CovarModuleFn ,
490483 mean_module_fn : MeanModuleFn = constant_mean ,
484+ auto_register : bool = True ,
485+ overwrite : bool = True ,
491486 ** fixed_kwargs ,
492487) -> type [GaussianProcess ]:
493488 """
@@ -496,6 +491,9 @@ def create_gp_subclass(
496491 This function creates a subclass of GaussianProcess where certain parameters
497492 are fixed to specific values, reducing the parameter space for tuning.
498493
494+ The created subclass is automatically registered with the main emulator Registry
495+ (unless auto_register=False), making it discoverable by AutoEmulate.
496+
499497 Parameters
500498 ----------
501499 name : str
@@ -506,6 +504,13 @@ def create_gp_subclass(
506504 Covariance module function to use in the subclass.
507505 mean_module_fn : MeanModuleFn
508506 Mean module function to use in the subclass. Defaults to `constant_mean`.
507+ auto_register : bool
508+ Whether to automatically register the created subclass with the main emulator
509+ Registry. Defaults to True.
510+ overwrite : bool
511+ Whether to allow overwriting an existing class with the same name in the
512+ main Registry. Useful for interactive development in notebooks. Defaults to
513+ True.
509514 **fixed_kwargs
510515 Keyword arguments to fix in the subclass. These parameters will be
511516 set to the provided values and excluded from hyperparameter tuning.
@@ -516,17 +521,21 @@ def create_gp_subclass(
516521 A new subclass of GaussianProcess with the specified parameters fixed.
517522 The returned class can be pickled and used like any other GP emulator.
518523
524+ Raises
525+ ------
526+ ValueError
527+ If `name` matches `model_name()` or `short_name()` of an already registered
528+ emulator in the main Registry and `overwrite=False`.
529+
519530 Notes
520531 -----
521- Fixed parameters are automatically excluded from `get_tune_params()` to
522- prevent them from being included in hyperparameter optimization.
532+ - Fixed parameters are automatically excluded from `get_tune_params()` to prevent
533+ them from being included in hyperparameter optimization.
534+ - Pickling: The created subclass is registered in the caller's module namespace,
535+ ensuring it can be pickled and unpickled correctly even when created in downstream
536+ code that uses autoemulate as a dependency.
537+ - If auto_register=True (default), the class is also added to the main Registry.
523538 """
524- if name in GP_REGISTRY :
525- raise ValueError (
526- f"A GP class named '{ name } ' already exists. "
527- f"Use a unique name or delete the existing class from GP_REGISTRY."
528- )
529-
530539 standardize_x = fixed_kwargs .get ("standardize_x" , False )
531540 standardize_y = fixed_kwargs .get ("standardize_y" , True )
532541 fixed_mean_params = fixed_kwargs .get ("fixed_mean_params" , False )
@@ -614,30 +623,50 @@ def get_tune_params():
614623 model training and are not fixed.
615624 """
616625
617- # Set the provided name for the class
626+ # Determine the caller's module for proper pickling support.
627+ # When called from autoemulate itself, use __name__.
628+ # When called from user code, use the caller's module
629+ caller_frame = sys ._getframe (1 )
630+ caller_module_name = caller_frame .f_globals .get ("__name__" , __name__ )
631+
632+ # Set the class name and module
618633 GaussianProcessSubclass .__name__ = name
619634 GaussianProcessSubclass .__qualname__ = name
620- GaussianProcessSubclass .__module__ = __name__
635+ GaussianProcessSubclass .__module__ = caller_module_name
636+
637+ # Register class in the caller's module globals for pickling
638+ # This ensures the class can be pickled/unpickled correctly
639+ caller_frame .f_globals [name ] = GaussianProcessSubclass
640+
641+ # Also register in the caller's module if it's a real module (not __main__)
642+ if caller_module_name in sys .modules and caller_module_name != "__main__" :
643+ setattr (sys .modules [caller_module_name ], name , GaussianProcessSubclass )
644+
645+ # Automatically register with the main emulator Registry if requested
646+ if auto_register :
647+ # Lazy import to avoid circular dependency with __init__.py
648+ from autoemulate .emulators import register # noqa: PLC0415
621649
622- # Register class in the module's globals so can be pickled
623- setattr (sys .modules [__name__ ], name , GaussianProcessSubclass )
624- # Register subclass
625- GP_REGISTRY [name ] = GaussianProcessSubclass
650+ register (GaussianProcessSubclass , overwrite = overwrite )
626651
627652 return GaussianProcessSubclass
628653
629654
655+ # Built-in GP subclasses - auto_register=False as already registered in Registry init:
656+ # autoemulate/emulators/__init__.py
630657GaussianProcessRBF = create_gp_subclass (
631658 "GaussianProcessRBF" ,
632659 GaussianProcess ,
633660 covar_module_fn = rbf_kernel ,
634661 mean_module_fn = constant_mean ,
662+ auto_register = False ,
635663)
636664GaussianProcessMatern32 = create_gp_subclass (
637665 "GaussianProcessMatern32" ,
638666 GaussianProcess ,
639667 covar_module_fn = matern_3_2_kernel ,
640668 mean_module_fn = constant_mean ,
669+ auto_register = False ,
641670)
642671
643672# correlated GP kernels
@@ -646,10 +675,12 @@ def get_tune_params():
646675 GaussianProcessCorrelated ,
647676 covar_module_fn = rbf_kernel ,
648677 mean_module_fn = constant_mean ,
678+ auto_register = False ,
649679)
650680GaussianProcessCorrelatedMatern32 = create_gp_subclass (
651681 "GaussianProcessCorrelatedMatern32" ,
652682 GaussianProcessCorrelated ,
653683 covar_module_fn = matern_3_2_kernel ,
654684 mean_module_fn = constant_mean ,
685+ auto_register = False ,
655686)
0 commit comments