Skip to content
66 changes: 62 additions & 4 deletions src/pybamm/dispatch/entry_points.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
import hashlib
import importlib.metadata
import textwrap
import urllib.request
from collections.abc import Callable, Mapping
from pathlib import Path

from platformdirs import user_cache_dir

from pybamm.expression_tree.operations.serialise import Serialise

APP_NAME = "pybamm"
APP_AUTHOR = "pybamm"


class EntryPoint(Mapping):
Expand Down Expand Up @@ -109,7 +119,35 @@ def __getattribute__(self, name):
models = EntryPoint(group="pybamm_models")


def Model(model: str, *args, **kwargs):
def _get_cache_dir() -> Path:
cache_dir = Path(user_cache_dir(APP_NAME, APP_AUTHOR)) / "models"
cache_dir.mkdir(parents=True, exist_ok=True)
return cache_dir


def get_cache_path(url: str) -> Path:
cache_dir = _get_cache_dir()
file_hash = hashlib.md5(url.encode()).hexdigest()
return cache_dir / f"{file_hash}.json"


def clear_model_cache() -> None:
cache_dir = _get_cache_dir()
for file in cache_dir.glob("*.json"):
try:
file.unlink()
except Exception as e:
# Optional: log error instead of failing silently
print(f"Could not delete {file}: {e}")


def Model(
model=None,
url=None,
force_download=False,
*args,
**kwargs,
):
"""
Returns the loaded model object
Note: This feature is in its experimental phase.
Expand Down Expand Up @@ -137,6 +175,26 @@ def Model(model: str, *args, **kwargs):
>>> pybamm.Model('SPM') # doctest: +SKIP
<pybamm.models.full_battery_models.lithium_ion.spm.SPM object>
"""
model_class = models._get_class(model)

return model_class(*args, **kwargs)
if (model is None and url is None) or (model and url):
raise ValueError("You must provide exactly one of `model` or `url`.")

if url is not None:
cache_path = get_cache_path(url)
if not cache_path.exists() or force_download:
try:
print(f"Downloading model from {url}...")
urllib.request.urlretrieve(url, cache_path)
print(f"Model cached at: {cache_path}")
except Exception as e:
raise RuntimeError(f"Failed to download model from URL: {e}") from e
else:
print(f"Using cached model at: {cache_path}")

return Serialise.load_custom_model(str(cache_path))

if model is not None:
try:
model_class = models._get_class(model)
return model_class(*args, **kwargs)
except Exception as e:
raise ValueError(f"Could not load model '{model}': {e}") from e
122 changes: 121 additions & 1 deletion tests/unit/test_dispatch.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,23 @@
#
# Test dispatching mechanism in entry points
#
import pytest

import pybamm
from pybamm.dispatch.entry_points import (
_get_cache_dir,
clear_model_cache,
get_cache_path,
)

MODEL_URL = "https://raw.githubusercontent.com/pybamm-team/pybamm-reservoir-example/refs/heads/main/dfn.py"


class TestDispatch:
def setup_method(self):
"""Clear cache before each test"""
clear_model_cache()

def test_model_loads_through_entry_points(self):
"""Test that a model loaded through Model() function is actually functional"""
# Load model with build=False to avoid full initialization for faster testing
Expand All @@ -31,5 +44,112 @@ def test_model_function(self):

# Test Model function with options parameter
options = {"thermal": "isothermal"}
model = pybamm.Model("SPM", options)
model = pybamm.Model("SPM", options=options)
assert model.__class__.__name__ == "SPM"

def test_model_value_error(self):
"""Test that Model raises ValueError when given invalid arguments"""

# Neither model nor url provided
with pytest.raises(
ValueError, match="You must provide exactly one of `model` or `url`."
):
pybamm.Model()

# Both model and url provided
with pytest.raises(
ValueError, match="You must provide exactly one of `model` or `url`."
):
pybamm.Model(model="SPM", url="http://example.com/dfn.py")

def test_model_download_runtime_error(self):
"""Test that Model raises RuntimeError when download fails"""

bad_url = "h://example.invalid/model.json"

with pytest.raises(RuntimeError, match="Failed to download model from URL:"):
pybamm.Model(url=bad_url, force_download=True)

def test_invalid_model_name_raises_value_error(self):
"""Test that Model raises ValueError for an invalid model name"""

bad_model = "NonExistentModel123"

with pytest.raises(ValueError, match=f"Could not load model '{bad_model}':"):
pybamm.Model(model=bad_model)

def test_model_download_and_cache(self):
"""Force exception in clear_model_cache without interfering with real cache files"""
cache_dir = _get_cache_dir()
bad_path = cache_dir / "force_exception_test_dir"

try:
bad_path.mkdir(exist_ok=True) # not .json
try:
bad_path.unlink()
except Exception as e:
print(f"Expected error: {e}")
finally:
if bad_path.exists():
bad_path.rmdir()

def test_force_download_overwrites_cache(self):
"""Force an exception when trying to unlink a directory"""
cache_dir = _get_cache_dir()
bad_path = cache_dir / "force_exception_test_dir"

try:
bad_path.mkdir(exist_ok=True)
clear_model_cache()
try:
bad_path.unlink()
except Exception as e:
print(f"Expected error: {e}")
finally:
if bad_path.exists():
bad_path.rmdir()

def test_clear_model_cache_exception_branch(self, capsys):
"""Test that clear_model_cache gracefully handles deletion errors"""
cache_dir = pybamm.dispatch.entry_points._get_cache_dir()
cache_dir.mkdir(parents=True, exist_ok=True)

bad_path = cache_dir / "bad.json"
bad_path.mkdir(exist_ok=True)

try:
clear_model_cache()

captured = capsys.readouterr()
assert "Could not delete" in captured.out
assert "bad.json" in captured.out

assert bad_path.exists()
finally:
bad_path.rmdir()

def test_model_download_and_cache_integration(self, capsys):
"""Integration test using a real model URL"""

cache_path = get_cache_path(MODEL_URL)

# Clean up: if cache path exists as directory, remove it
if cache_path.exists():
if cache_path.is_dir():
import shutil

shutil.rmtree(cache_path)
else:
clear_model_cache()

# First call -> should download and print "Model cached at"
model = pybamm.Model(url=MODEL_URL, force_download=True)
captured = capsys.readouterr()
assert "Model cached at:" in captured.out
assert hasattr(model, "name")

# Second call -> should use cached file
model2 = pybamm.Model(url=MODEL_URL)
captured = capsys.readouterr()
assert "Using cached model at:" in captured.out
assert hasattr(model2, "name")
Loading