Skip to content
58 changes: 54 additions & 4 deletions src/pybamm/dispatch/entry_points.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import hashlib
import importlib.metadata
import textwrap
import urllib.request
from collections.abc import Callable, Mapping
from pathlib import Path

import pybamm
from pybamm.expression_tree.operations.serialise import Serialise


class EntryPoint(Mapping):
Expand Down Expand Up @@ -109,7 +115,28 @@ def __getattribute__(self, name):
models = EntryPoint(group="pybamm_models")


def Model(model: str, *args, **kwargs):
def get_cache_path(url):
cache_dir = Path.home() / ".pybamm_cache" / "pybamm" / "models"
cache_dir.mkdir(parents=True, exist_ok=True)
file_hash = hashlib.md5(url.encode()).hexdigest()
return cache_dir / f"{file_hash}.json"
Comment on lines +128 to +131
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why MD5 and not SHA-256? :)



def clear_model_cache():
cache_dir = Path.home() / ".pybamm_cache" / "pybamm" / "models"
if cache_dir.exists():
for file in cache_dir.glob("*.json"):
file.unlink()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be handled using platformdirs.user_cache_dir: https://platformdirs.readthedocs.io/en/latest/api.html#cache-directory



def Model(
model=None,
url=None,
battery_model=None,
force_download=False,
*args,
**kwargs,
):
Comment on lines +144 to +150
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can have better typing here.

"""
Returns the loaded model object
Note: This feature is in its experimental phase.
Expand Down Expand Up @@ -137,6 +164,29 @@ def Model(model: str, *args, **kwargs):
>>> pybamm.Model('SPM') # doctest: +SKIP
<pybamm.models.full_battery_models.lithium_ion.spm.SPM object>
"""
model_class = models._get_class(model)

return model_class(*args, **kwargs)
if (model is None and url is None) or (model and url):
raise ValueError("You must provide exactly one of `model` or `url`.")

if url is not None:
if battery_model is None:
battery_model = pybamm.BaseModel()

cache_path = get_cache_path(url)
if not cache_path.exists() or force_download:
try:
print(f"Downloading model from {url}...")
urllib.request.urlretrieve(url, cache_path)
print(f"Model cached at: {cache_path}")
except Exception as e:
raise RuntimeError(f"Failed to download model from URL: {e}") from e
else:
print(f"Using cached model at: {cache_path}")

return Serialise.load_custom_model(str(cache_path), battery_model=battery_model)

if model is not None:
try:
model_class = models._get_class(model)
return model_class(*args, **kwargs)
except Exception as e:
raise ValueError(f"Could not load model '{model}': {e}") from e
Loading