-
-
Notifications
You must be signed in to change notification settings - Fork 738
[GSoC 2025] Load model JSON files from URLs #5137
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
f55a0a9
0b3d671
6ee08b0
5112f8b
ed0c714
3e3658b
fc437a1
00ff44f
4214681
71ae0c4
51bf407
c39d1c0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,12 @@ | ||
| import hashlib | ||
| import importlib.metadata | ||
| import textwrap | ||
| import urllib.request | ||
| from collections.abc import Callable, Mapping | ||
| from pathlib import Path | ||
|
|
||
| import pybamm | ||
| from pybamm.expression_tree.operations.serialise import Serialise | ||
|
|
||
|
|
||
| class EntryPoint(Mapping): | ||
|
|
@@ -109,7 +115,28 @@ def __getattribute__(self, name): | |
| models = EntryPoint(group="pybamm_models") | ||
|
|
||
|
|
||
| def Model(model: str, *args, **kwargs): | ||
| def get_cache_path(url): | ||
| cache_dir = Path.home() / ".pybamm_cache" / "pybamm" / "models" | ||
| cache_dir.mkdir(parents=True, exist_ok=True) | ||
| file_hash = hashlib.md5(url.encode()).hexdigest() | ||
| return cache_dir / f"{file_hash}.json" | ||
|
|
||
|
|
||
| def clear_model_cache(): | ||
| cache_dir = Path.home() / ".pybamm_cache" / "pybamm" / "models" | ||
| if cache_dir.exists(): | ||
| for file in cache_dir.glob("*.json"): | ||
| file.unlink() | ||
|
||
|
|
||
|
|
||
| def Model( | ||
| model=None, | ||
| url=None, | ||
| battery_model=None, | ||
| force_download=False, | ||
| *args, | ||
| **kwargs, | ||
| ): | ||
|
Comment on lines
+144
to
+150
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we can have better typing here. |
||
| """ | ||
| Returns the loaded model object | ||
| Note: This feature is in its experimental phase. | ||
|
|
@@ -137,6 +164,29 @@ def Model(model: str, *args, **kwargs): | |
| >>> pybamm.Model('SPM') # doctest: +SKIP | ||
| <pybamm.models.full_battery_models.lithium_ion.spm.SPM object> | ||
| """ | ||
| model_class = models._get_class(model) | ||
|
|
||
| return model_class(*args, **kwargs) | ||
| if (model is None and url is None) or (model and url): | ||
| raise ValueError("You must provide exactly one of `model` or `url`.") | ||
|
|
||
| if url is not None: | ||
| if battery_model is None: | ||
| battery_model = pybamm.BaseModel() | ||
|
|
||
| cache_path = get_cache_path(url) | ||
| if not cache_path.exists() or force_download: | ||
| try: | ||
| print(f"Downloading model from {url}...") | ||
| urllib.request.urlretrieve(url, cache_path) | ||
| print(f"Model cached at: {cache_path}") | ||
| except Exception as e: | ||
| raise RuntimeError(f"Failed to download model from URL: {e}") from e | ||
| else: | ||
| print(f"Using cached model at: {cache_path}") | ||
|
|
||
| return Serialise.load_custom_model(str(cache_path), battery_model=battery_model) | ||
|
|
||
| if model is not None: | ||
| try: | ||
| model_class = models._get_class(model) | ||
| return model_class(*args, **kwargs) | ||
| except Exception as e: | ||
| raise ValueError(f"Could not load model '{model}': {e}") from e | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why MD5 and not SHA-256? :)