Skip to content

Commit 8a499fb

Browse files
committed
Add tests
1 parent 25174d1 commit 8a499fb

File tree

2 files changed

+79
-3
lines changed

2 files changed

+79
-3
lines changed

autoemulate/emulators/__init__.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from collections.abc import Callable
22
from typing import overload
33

4-
from torch import nn
4+
import torch
55

66
from .base import Emulator, GaussianProcessEmulator
77
from .ensemble import EnsembleMLP, EnsembleMLPDropout
@@ -154,9 +154,9 @@ def register_model(
154154
):
155155
self._gaussian_process_emulators.append(model_cls)
156156

157-
# Check if it's a PyTorch emulator (subclass of nn.Module) + not in PyTorch list
157+
# Check if it's a PyTorch emulator (subclass of torch.nn.Module)
158158
if (
159-
issubclass(model_cls, nn.Module)
159+
issubclass(model_cls, torch.nn.Module)
160160
and model_cls not in self._pytorch_emulators
161161
):
162162
self._pytorch_emulators.append(model_cls)
@@ -258,6 +258,10 @@ def get_emulator_class(name: str) -> type[Emulator]:
258258
def register(model_cls: type[Emulator]) -> type[Emulator]: ...
259259

260260

261+
@overload
262+
def register(model_cls: type[Emulator], *, overwrite: bool) -> type[Emulator]: ...
263+
264+
261265
@overload
262266
def register(
263267
*, overwrite: bool = True

tests/emulators/test_registry.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
"""Tests for the Registry functionality."""
2+
3+
import pytest
4+
from autoemulate.emulators import Registry
5+
from autoemulate.emulators.base import Emulator, GaussianProcessEmulator
6+
7+
8+
def test_register_custom_emulator():
9+
"""Test registering a custom emulator subclass."""
10+
registry = Registry()
11+
12+
class TestEmulator(Emulator): ...
13+
14+
# Register
15+
registry.register_model(TestEmulator, overwrite=True)
16+
17+
# Check it was registered and can be retrieved
18+
assert TestEmulator in registry.all_emulators
19+
retrieved = registry.get_emulator_class("TestEmulator")
20+
assert retrieved == TestEmulator
21+
22+
23+
def test_register_gp_from_factory():
24+
"""Test registering a GP subclass created manually."""
25+
registry = Registry()
26+
27+
# Create a custom GP emulator class
28+
class TestGP(GaussianProcessEmulator): ...
29+
30+
# Register
31+
registry.register_model(TestGP, overwrite=True)
32+
33+
# Check it was registered in the correct lists
34+
assert TestGP in registry.all_emulators
35+
assert TestGP in registry.gaussian_process_emulators
36+
37+
38+
def test_overwrite_flag():
39+
"""Test that overwrite flag controls duplicate registration."""
40+
registry = Registry()
41+
42+
class TestOverwrite(Emulator):
43+
version = 1
44+
45+
@classmethod
46+
def model_name(cls):
47+
return "TestOverwrite"
48+
49+
# Register first version
50+
registry.register_model(TestOverwrite, overwrite=True)
51+
52+
# Create second version with same name
53+
class TestOverwrite2(Emulator):
54+
version = 2
55+
56+
@classmethod
57+
def model_name(cls):
58+
return "TestOverwrite"
59+
60+
# Overwrite should succeed
61+
registry.register_model(TestOverwrite2, overwrite=True)
62+
retrieved = registry.get_emulator_class("TestOverwrite")
63+
assert retrieved.version == 2 # type: ignore as version exists if test passes
64+
65+
# Overwrite=False should raise error
66+
class TestOverwrite3(Emulator):
67+
@classmethod
68+
def model_name(cls):
69+
return "TestOverwrite"
70+
71+
with pytest.raises(ValueError, match="already exists"):
72+
registry.register_model(TestOverwrite3, overwrite=False)

0 commit comments

Comments
 (0)