Implement __getattr__-based lazy loading in __init__.py with a hybrid approach that eagerly imports expression_tree modules (accessed thousands of times during model building) while lazily loading heavy dependencies.
Performance improvements:
- Fresh/cold import: ~11x faster (1.9s vs 21s)
- Cached import: ~13% faster (0.59s vs 0.68s)
- Subsequent solves within same script: identical performance

- Create single source of truth for SUBMODULE_ALIASES in _lazy_config.py
- Standardize external dependencies on lazy.load() at module level:
  - interpolant.py: scipy.interpolate
  - serialise.py: black
  - processed_variable.py: xarray
  - processed_variable_computed.py: scipy.integrate
- Consolidate JAX x64 configuration in __init__.py (removed duplicates from evaluate_python.py and jax_bdf_solver.py)
- Add CI check for stub file freshness in test_on_push.yml
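As a rough illustration of the __getattr__-based approach these commits describe, the sketch below builds a throwaway module object whose attributes are imported on first access and then cached. The names `make_lazy_package` and `LAZY_ATTRS` are hypothetical, and standard-library modules stand in for heavy dependencies; this is not the PR's actual __init__.py.

```python
import importlib
import types

# Hypothetical mapping: public attribute name -> module that defines it.
# Standard-library modules stand in for heavy dependencies here.
LAZY_ATTRS = {"Fraction": "fractions", "Decimal": "decimal"}


def make_lazy_package(name, lazy_attrs):
    """Return a module object that resolves `lazy_attrs` on first access."""
    pkg = types.ModuleType(name)

    def __getattr__(attr):
        # Called only when normal attribute lookup on the module fails (PEP 562).
        if attr not in lazy_attrs:
            raise AttributeError(f"module {name!r} has no attribute {attr!r}")
        module = importlib.import_module(lazy_attrs[attr])
        value = getattr(module, attr)
        setattr(pkg, attr, value)  # cache: later lookups bypass __getattr__
        return value

    pkg.__getattr__ = __getattr__
    return pkg


demo = make_lazy_package("demo", LAZY_ATTRS)
assert "Fraction" not in vars(demo)  # nothing resolved yet
half = demo.Fraction(1, 2)           # first access triggers the import
assert "Fraction" in vars(demo)      # now cached on the module
```

Caching the resolved attribute on the module means the lookup cost is paid once, which is one plausible reason to keep frequently accessed names eager and only defer the heavy ones.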
Benchmarking showed that lazy.load() adds significant overhead due to proxy object creation and attribute interception. Inline imports are faster for both cold and warm starts:
- Cold start total: 13.7s (inline) vs 15.1s (lazy.load) - 10% faster
- Warm start total: 1.06s (inline) vs 1.18s (lazy.load) - 10% faster
- Access lithium_ion: 0.003s (inline) vs 0.094s (lazy.load) - 31x faster

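For readers unfamiliar with the term, an "inline import" here means placing the import statement inside the function that needs it, so the module is loaded on first call rather than at package import time, with no proxy object in between. A minimal sketch, using the standard-library colorsys module as a stand-in for a heavy dependency (the function name `to_hsv` is hypothetical):

```python
def to_hsv(red, green, blue):
    """Convert an RGB triple to HSV, importing the dependency on demand."""
    # Inline import: executed on the first call; later calls find the
    # module already in sys.modules, so the statement is a cheap lookup.
    import colorsys

    return colorsys.rgb_to_hsv(red, green, blue)


hsv = to_hsv(1.0, 0.0, 0.0)
```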
- Remove eager JAX import from __init__.py that defeated lazy loading
- Add get_jax() helper in util.py that lazily imports and configures JAX for float64 precision on first use (except Metal backend)
- Update JAX-using files to use pybamm.get_jax() instead of import jax
- Add --check-exports flag to stub generator that detects missing exports by comparing stub config against modules' __all__ lists
- Update CI to run --check-exports alongside --validate
- Add TestJaxConfiguration test class for the new get_jax() function
- Add missing 'models' export from dispatch module
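The idea behind comparing a stub configuration against a module's __all__ can be sketched as a set difference. The implementation of --check-exports is not shown in this conversation, so the function `find_missing_exports` and the stand-in module below are assumptions made purely for illustration.

```python
import types


def find_missing_exports(configured, module):
    """Return names in module.__all__ that the configuration does not list."""
    exported = set(getattr(module, "__all__", ()))
    return sorted(exported - set(configured))


# Hypothetical stand-in for a submodule that declares its public API.
dispatch = types.ModuleType("dispatch")
dispatch.__all__ = ["parameter_sets", "models"]

# Hypothetical stub configuration that forgot one export.
configured_names = ["parameter_sets"]

missing = find_missing_exports(configured_names, dispatch)
print(missing)
```

A check of this shape would flag an omission like the 'models' export mentioned above before it reaches users.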
Move EAGER_IMPORTS and LAZY_IMPORTS configuration from generate_pyi_stub.py into _lazy_config.py, creating a single source of truth for all import configuration.
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #5364 +/- ##
==========================================
- Coverage 98.03% 97.97% -0.06%
==========================================
Files 327 328 +1
Lines 28688 28655 -33
==========================================
- Hits 28124 28076 -48
- Misses 564 579 +15
agriyakhetarpal left a comment
Hi @BradyPlanden, wow, this is a lot of code for optimising our import time! At the same time, the numbers speak for themselves, so it's nice to see this. I've added a few comments below. It is very difficult to review a change this large in one sitting, so I may well have missed something.
Also, I notice that there are still a few eager imports that you noted as intentional. That kind of undermines the lazy import strategy and (I haven't tracked this) undoes some of the improvements. I wonder if we could do a warm-up on the first solve when we create a simulation instead of those?
A few more things:
- The .pyi files must be distributed both in the sdist and the wheels; for hatchling, these are not handled by default, but I do not remember closely. We will need to add them to [tool.hatch.build.targets.sdist] and so on.
- Considering the size of this change, could you please add something about it to the contributing guide as appropriate?
| - name: Check style | ||
| run: uvx pre-commit run -a | ||
|
|
||
| - name: Check stub file is up to date |
Could you please add this into the .pre-commit-config.yaml file directly as a pre-commit hook? This way we won't need to add another step here and it will be covered by uvx pre-commit run -a above:
- id: stub-file-validation
  name: Check that the stub file is up to date
  entry: python scripts/generate_pyi_stub.py
  language: python
  files: ^src/pybamm/*.py$
  pass_filenames: true

Something like this should work (please validate the regex here).
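One way to validate the suggested pattern is to run it against a few representative paths. pre-commit treats `files` as a Python regular expression, so the check can be done with the `re` module directly. In the sketch below, the second pattern is an assumed alternative rather than something proposed in this thread, and the example paths are hypothetical.

```python
import re

suggested = re.compile(r"^src/pybamm/*.py$")    # pattern as written in the suggestion
candidate = re.compile(r"^src/pybamm/.*\.py$")  # assumed alternative, not from the thread

# Hypothetical paths used only to exercise the patterns.
paths = [
    "src/pybamm/util.py",
    "src/pybamm/models/base_model.py",
    "src/pybamm/__init__.pyi",
]

for path in paths:
    print(path, bool(suggested.search(path)), bool(candidate.search(path)))
```

In the suggested pattern, `/*` means "zero or more slashes" and the unescaped `.` matches any single character, so it does not behave like a shell glob. Whether the hook should also trigger on the .pyi file itself is a separate decision that neither pattern covers.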
| # Add src to path for imports | ||
| REPO_ROOT = Path(__file__).parent.parent | ||
| sys.path.insert(0, str(REPO_ROOT / "src")) |
This should happen before the pybamm._lazy_config import line so it picks up the latest code, right? Could you reorder them?
| self._check_for_bibtex() | ||
| # Dict mapping citations keys to BibTex entries | ||
| self._all_citations: dict[str, str] = dict() | ||
| self._all_citations: dict = dict() |
I think this should be dict[str, Entry], so this is a slight regression in type annotations
|
|
||
| # dispatch | ||
| from .dispatch import parameter_sets as parameter_sets | ||
| from .dispatch import models as models |
This was not in the __init__.py file, could you please double check?
| # Lazily initialized posthog client | ||
| _posthog = None | ||
| _disabled = False | ||
|
|
||
|
|
||
| class MockTelemetry: | ||
| def __init__(self): | ||
| self.disabled = True | ||
| def _get_posthog(): | ||
| """Lazily initialize the posthog client on first use.""" | ||
| global _posthog, _disabled |
This looks like it would induce a race pretty easily. Two threads calling _get_posthog() simultaneously would both see _posthog is None and both attempt initialisation. The same applies to _jax_configured and get_jax. Could you add a lock here? We previously had an issue about something like this back when telemetry had only recently been implemented, and we are working around it with PYBAMM_DISABLE_TELEMETRY. Also, it doesn't look like the thread safety tests are reaching these.
| if platform != "metal": | ||
| assert jax.config.x64_enabled, "JAX x64 should be enabled after get_jax()" | ||
|
|
||
| @pytest.mark.skipif(not __import__("pybamm").has_jax(), reason="JAX not installed") |
This (and the rest of the JAX tests) can simply use pytest.importorskip(): https://docs.pytest.org/en/stable/reference/reference.html#pytest-importorskip
(otherwise we import pybamm and has_jax which will slow down the test collection)
| for module_path, attrs in EAGER_IMPORTS.items(): | ||
| module_name = module_path.split(".")[-1] | ||
| lines.append(f"# {module_name}") | ||
| for attr in attrs: | ||
| lines.append(f"from {module_path} import {attr} as {attr}") | ||
| lines.append("") |
We can make the order deterministic and that'll help reduce PR diffs a bit:
| for module_path, attrs in EAGER_IMPORTS.items(): | |
| module_name = module_path.split(".")[-1] | |
| lines.append(f"# {module_name}") | |
| for attr in attrs: | |
| lines.append(f"from {module_path} import {attr} as {attr}") | |
| lines.append("") | |
| for module_path in sorted(EAGER_IMPORTS): | |
| attrs = EAGER_IMPORTS[module_path] | |
| module_name = module_path.split(".")[-1] | |
| lines.append(f"# {module_name}") | |
| for attr in attrs: | |
| lines.append(f"from {module_path} import {attr} as {attr}") | |
| lines.append("") |
| for module_path, attrs in LAZY_IMPORTS.items(): | ||
| module_name = module_path.split(".")[-1] | ||
| lines.append(f"# {module_name}") | ||
| for attr in attrs: | ||
| lines.append(f"from {module_path} import {attr} as {attr}") | ||
| lines.append("") |
Same here:
| for module_path, attrs in LAZY_IMPORTS.items(): | |
| module_name = module_path.split(".")[-1] | |
| lines.append(f"# {module_name}") | |
| for attr in attrs: | |
| lines.append(f"from {module_path} import {attr} as {attr}") | |
| lines.append("") | |
| for module_path in sorted(LAZY_IMPORTS): | |
| attrs = LAZY_IMPORTS[module_path] | |
| module_name = module_path.split(".")[-1] | |
| lines.append(f"# {module_name}") | |
| for attr in attrs: | |
| lines.append(f"from {module_path} import {attr} as {attr}") | |
| lines.append("") |
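To see why iterating over sorted keys helps, the sketch below wraps the suggested loop in a function and feeds it two configurations that differ only in insertion order. The function name `render_imports` and the example entries are hypothetical; only the loop body mirrors the suggestion.

```python
def render_imports(import_map):
    """Render stub import lines with module paths in sorted order."""
    lines = []
    for module_path in sorted(import_map):
        attrs = import_map[module_path]
        module_name = module_path.split(".")[-1]
        lines.append(f"# {module_name}")
        for attr in attrs:
            lines.append(f"from {module_path} import {attr} as {attr}")
        lines.append("")
    return lines


# Same (hypothetical) content, different insertion order.
config_a = {".util": ["Timer"], ".citations": ["Citations"]}
config_b = {".citations": ["Citations"], ".util": ["Timer"]}

assert render_imports(config_a) == render_imports(config_b)
```

Note that only the module order is made deterministic here; the attributes within each module keep whatever order the configuration lists them in.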
Description
The aim of this PR is to improve the time-to-first-solve performance of cold and warm starts. To achieve this, a lazy import structure closely following SPEC001 from Scientific-Python has been added for internal submodules, and the management of external imports has been updated.
Here is the performance difference on my M4 Pro machine for time-to-first-solve of the standard DFN.
There should not be any breaking changes to the submodule API for end-users. However, the documentation does require updating full paths.
An alternative implementation was presented in #3732; however, this approach greatly simplifies the added code and shouldn't cause massive challenges to developers.
Fixes # (issue)
Type of change
Please add a line in the relevant section of CHANGELOG.md to document the change (include PR #)
Important checks:
Please confirm the following before marking the PR as ready for review:
- nox -s pre-commit
- nox -s tests
- nox -s doctests