
Improve time-to-first-solve #5364

Open
BradyPlanden wants to merge 19 commits into main from feat/improvements-for-time-to-first-solve

@BradyPlanden (Member) commented Jan 26, 2026

Description

This PR aims to improve time-to-first-solve for both cold and warm starts. To achieve this, internal submodules now use a lazy import structure that closely follows Scientific Python's SPEC 1 (lazy loading of submodules and functions), and the management of external imports has been updated.
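For readers unfamiliar with the pattern, the following is a minimal standard-library sketch of module-level `__getattr__` lazy loading (PEP 562), the mechanism that SPEC 1 builds on. It is not the code in this PR: the helper `make_lazy_namespace` and the `eager`/`lazy` mappings are hypothetical names chosen for illustration, and standard-library modules stand in for real submodules.

```python
import importlib
import types


def make_lazy_namespace(name, eager, lazy):
    """Build a module whose `lazy` attributes are imported on first access.

    `eager` and `lazy` both map attribute name -> module path
    (a hypothetical configuration shape, not the PR's).
    """
    namespace = types.ModuleType(name)

    # Eager entries are imported immediately, e.g. hot-path modules.
    for attr, module_path in eager.items():
        setattr(namespace, attr, importlib.import_module(module_path))

    def __getattr__(attr):
        # Only called when normal attribute lookup fails (PEP 562).
        if attr in lazy:
            module = importlib.import_module(lazy[attr])
            # Cache the result so later lookups bypass __getattr__.
            setattr(namespace, attr, module)
            return module
        raise AttributeError(f"module {name!r} has no attribute {attr!r}")

    namespace.__getattr__ = __getattr__
    return namespace


demo = make_lazy_namespace(
    "demo_pkg",
    eager={"math": "math"},
    lazy={"fractions": "fractions"},
)
```

Here `demo.math` is available immediately, while `demo.fractions` is imported only when it is first touched and is then cached on the namespace.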

Here is the performance difference on my M4 Pro machine for time-to-first-solve of the standard DFN.

| Metric | Main | Local | Speedup |
| --- | --- | --- | --- |
| Warm start import | 0.667s | 0.336s | 1.98x |
| Warm start total | 0.787s | 0.500s | 1.57x |
| Cold start import | 3.289s | 1.849s | 1.78x |
| Cold start total | 3.461s | 2.246s | 1.54x |
| Fresh env total | 38.041s | 18.398s | 2.07x |
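The PR does not say how these numbers were collected. One way to take a comparable measurement, sketched here under that caveat, is to time the import in a fresh interpreter so that nothing is already loaded in `sys.modules`. The helper names `time_import` and `speedup` are hypothetical, the timing includes interpreter start-up, and `json` is only a stand-in for the package under test.

```python
import subprocess
import sys
import time


def time_import(module_name, repeats=3):
    """Best wall-clock time (seconds) to import a module in a fresh interpreter."""
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        # A new process guarantees the module is not already imported.
        subprocess.run(
            [sys.executable, "-c", f"import {module_name}"],
            check=True,
        )
        timings.append(time.perf_counter() - start)
    return min(timings)


def speedup(baseline_seconds, new_seconds):
    """Speedup factor in the sense used by the table: baseline / new."""
    return baseline_seconds / new_seconds


elapsed = time_import("json", repeats=1)
print(f"import json: {elapsed:.3f}s")
print(f"cold start import speedup: {speedup(3.289, 1.849):.2f}x")
```

For a per-module breakdown rather than a single total, CPython's `-X importtime` option prints the import time of every module loaded.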

There should not be any breaking changes to the submodule API for end-users. However, the documentation does require updating full paths.

An alternative implementation was presented in #3732; however, this approach greatly simplifies the added code and shouldn't cause massive challenges to developers.

Fixes # (issue)

Type of change

Please add a line in the relevant section of CHANGELOG.md to document the change (include PR #)

Important checks:

Please confirm the following before marking the PR as ready for review:

  • No style issues: nox -s pre-commit
  • All tests pass: nox -s tests
  • The documentation builds: nox -s doctests
  • Code is commented for hard-to-understand areas
  • Tests added that prove fix is effective or that feature works

Implement __getattr__-based lazy loading in __init__.py with a hybrid approach that eagerly imports expression_tree modules (accessed thousands of times during model building) while lazily loading heavy dependencies.

Performance improvements:
 - Fresh/cold import: ~11x faster (1.9s vs 21s)
 - Cached import: ~13% faster (0.59s vs 0.68s)
 - Subsequent solves within same script: identical performance
 - Create single source of truth for SUBMODULE_ALIASES in _lazy_config.py
 - Standardize external dependencies on lazy.load() at module level:
   - interpolant.py: scipy.interpolate
   - serialise.py: black
   - processed_variable.py: xarray
   - processed_variable_computed.py: scipy.integrate
 - Consolidate JAX x64 configuration in __init__.py (removed duplicates from evaluate_python.py and jax_bdf_solver.py)
 - Add CI check for stub file freshness in test_on_push.yml

Benchmarking showed that lazy.load() adds significant overhead due to proxy object creation and attribute interception. Inline imports are faster for both cold and warm starts:

 - Cold start total: 13.7s (inline) vs 15.1s (lazy.load) - 10% faster
 - Warm start total: 1.06s (inline) vs 1.18s (lazy.load) - 10% faster
 - Access lithium_ion: 0.003s (inline) vs 0.094s (lazy.load) - 31x faster

 - Remove eager JAX import from __init__.py that defeated lazy loading
 - Add get_jax() helper in util.py that lazily imports and configures JAX for float64 precision on first use (except Metal backend)
 - Update JAX-using files to use pybamm.get_jax() instead of import jax
 - Add --check-exports flag to stub generator that detects missing exports by comparing stub config against modules' __all__ lists
 - Update CI to run --check-exports alongside --validate
 - Add TestJaxConfiguration test class for the new get_jax() function
 - Add missing 'models' export from dispatch module

Move EAGER_IMPORTS and LAZY_IMPORTS configuration from generate_pyi_stub.py into _lazy_config.py, creating a single source of truth for all import configuration.
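The `--check-exports` idea mentioned above, comparing the stub configuration against each module's `__all__`, can be illustrated roughly as follows. This is only a sketch: the function name `find_missing_exports` and the dictionary shapes are assumptions, not the actual code in `generate_pyi_stub.py`, and the sample data is made up to mirror the missing `models` export.

```python
def find_missing_exports(import_config, module_exports):
    """Return {module_path: [names]} for names listed in a module's __all__
    but absent from the import configuration.

    Both arguments map module path -> list of names (assumed shapes).
    """
    missing = {}
    for module_path, exported in module_exports.items():
        configured = set(import_config.get(module_path, ()))
        absent = sorted(set(exported) - configured)
        if absent:
            missing[module_path] = absent
    return missing


# Illustrative data only: the config forgot to list `models`.
config = {"pybamm.dispatch": ["parameter_sets"]}
exports = {"pybamm.dispatch": ["parameter_sets", "models"]}

print(find_missing_exports(config, exports))
```

A check like this could exit non-zero in CI whenever the returned dictionary is non-empty.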
@BradyPlanden BradyPlanden requested a review from a team as a code owner January 26, 2026 09:47

codecov bot commented Jan 26, 2026

Codecov Report

❌ Patch coverage is 88.00000% with 12 lines in your changes missing coverage. Please review.
✅ Project coverage is 97.97%. Comparing base (5741496) to head (25fe12f).

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| src/pybamm/telemetry.py | 38.46% | 8 Missing ⚠️ |
| src/pybamm/util.py | 78.57% | 3 Missing ⚠️ |
| src/pybamm/solvers/processed_variable_computed.py | 85.71% | 1 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #5364      +/-   ##
==========================================
- Coverage   98.03%   97.97%   -0.06%     
==========================================
  Files         327      328       +1     
  Lines       28688    28655      -33     
==========================================
- Hits        28124    28076      -48     
- Misses        564      579      +15     


@agriyakhetarpal (Member) left a comment

Hi @BradyPlanden, wow, this is a lot of code for optimising our import time! At the same time, the numbers speak for themselves, so it's nice to see this. I've added a few comments below. It is very difficult to review a change this large in one sitting, so I may well have missed something.

Also, I notice that there are still a few eager imports that you noted as intentional. That kind of undermines the lazy import strategy and (I haven't tracked this) undoes some of the improvements. I wonder if we could do a warm-up on the first solve when we create a simulation instead of those?

A few more things:

  • The .pyi files must be distributed both in the sdist and the wheels; for hatchling, these are not handled by default, but I do not remember closely. We will need to add them to [tool.hatch.build.targets.sdist] and so on.
  • Considering the size of this change, could you please add something about it to the contributing guide as appropriate?

- name: Check style
  run: uvx pre-commit run -a

- name: Check stub file is up to date

Could you please add this into the .pre-commit-config.yaml file directly as a pre-commit hook? This way we won't need to add another step here and it will be covered by uvx pre-commit run -a above:

- id: stub-file-validation
  name: Check that the stub file is up to date
  entry: python scripts/generate_pyi_stub.py
  language: python
  files: ^src/pybamm/*.py$
  pass_filenames: true

Something like this should work (please validate the regex here).
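On the request to validate the regex: pre-commit interprets `files` as a Python regular expression, so the pattern can be checked directly with `re.search`. The sketch below compares the pattern as written in the suggestion with a stricter alternative. The stricter pattern is an assumption about the intent (any `.py` file under `src/pybamm/`), and the paths are illustrative.

```python
import re

# Pattern as written in the suggestion: `/*` means "zero or more slashes"
# and the unescaped `.` matches any single character.
SUGGESTED = r"^src/pybamm/*.py$"
# Assumed intent: any .py file anywhere under src/pybamm/.
STRICTER = r"^src/pybamm/.*\.py$"


def matches(pattern, path):
    """True if the pattern is found in the path, using re.search."""
    return re.search(pattern, path) is not None


for path in [
    "src/pybamm/util.py",
    "src/pybamm/solvers/base_solver.py",
    "src/pybamm/__init__.pyi",
    "src/pybamm.py",
]:
    print(f"{path}: suggested={matches(SUGGESTED, path)}, "
          f"stricter={matches(STRICTER, path)}")
```

With the pattern as written, `src/pybamm/util.py` does not match, so the hook would not trigger for it; the stricter pattern matches `.py` files at any depth and excludes `.pyi` files.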

Comment on lines +32 to +34
# Add src to path for imports
REPO_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(REPO_ROOT / "src"))

This should happen before the pybamm._lazy_config import line so it picks up the latest code, right? Could you reorder them?

  self._check_for_bibtex()
  # Dict mapping citations keys to BibTex entries
- self._all_citations: dict[str, str] = dict()
+ self._all_citations: dict = dict()

I think this should be dict[str, Entry], so this is a slight regression in type annotations


# dispatch
from .dispatch import parameter_sets as parameter_sets
from .dispatch import models as models

This was not in the __init__.py file, could you please double check?

Comment on lines +5 to +12
+ # Lazily initialized posthog client
+ _posthog = None
+ _disabled = False
+
+
- class MockTelemetry:
-     def __init__(self):
-         self.disabled = True
+ def _get_posthog():
+     """Lazily initialize the posthog client on first use."""
+     global _posthog, _disabled

This looks like it would induce a race pretty easily. Two threads calling _get_posthog() simultaneously would both see _posthog is None and both attempt initialisation. This is the same for _jax_configured and get_jax. Could you add a lock here? We previously had an issue about something like this back when telemetry was recently implemented and we are working around it by PYBAMM_DISABLE_TELEMETRY. Also, it doesn't look like the thread safety tests are reaching these.

if platform != "metal":
    assert jax.config.x64_enabled, "JAX x64 should be enabled after get_jax()"

@pytest.mark.skipif(not __import__("pybamm").has_jax(), reason="JAX not installed")
@agriyakhetarpal (Member) commented Jan 30, 2026

This (and the rest of the JAX tests) can simply use pytest.importorskip(): https://docs.pytest.org/en/stable/reference/reference.html#pytest-importorskip

(otherwise we import pybamm and has_jax which will slow down the test collection)

Comment on lines +51 to +56
for module_path, attrs in EAGER_IMPORTS.items():
    module_name = module_path.split(".")[-1]
    lines.append(f"# {module_name}")
    for attr in attrs:
        lines.append(f"from {module_path} import {attr} as {attr}")
    lines.append("")

We can make the order deterministic and that'll help reduce PR diffs a bit:

Suggested change
- for module_path, attrs in EAGER_IMPORTS.items():
-     module_name = module_path.split(".")[-1]
-     lines.append(f"# {module_name}")
-     for attr in attrs:
-         lines.append(f"from {module_path} import {attr} as {attr}")
-     lines.append("")
+ for module_path in sorted(EAGER_IMPORTS):
+     attrs = EAGER_IMPORTS[module_path]
+     module_name = module_path.split(".")[-1]
+     lines.append(f"# {module_name}")
+     for attr in attrs:
+         lines.append(f"from {module_path} import {attr} as {attr}")
+     lines.append("")

Comment on lines +67 to +72
for module_path, attrs in LAZY_IMPORTS.items():
    module_name = module_path.split(".")[-1]
    lines.append(f"# {module_name}")
    for attr in attrs:
        lines.append(f"from {module_path} import {attr} as {attr}")
    lines.append("")

Same here:

Suggested change
- for module_path, attrs in LAZY_IMPORTS.items():
-     module_name = module_path.split(".")[-1]
-     lines.append(f"# {module_name}")
-     for attr in attrs:
-         lines.append(f"from {module_path} import {attr} as {attr}")
-     lines.append("")
+ for module_path in sorted(LAZY_IMPORTS):
+     attrs = LAZY_IMPORTS[module_path]
+     module_name = module_path.split(".")[-1]
+     lines.append(f"# {module_name}")
+     for attr in attrs:
+         lines.append(f"from {module_path} import {attr} as {attr}")
+     lines.append("")
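Since both loops have the same shape, the two suggestions could also share one helper. The sketch below is an illustration, not the stub generator's code: `build_import_lines` is a hypothetical name, the sample configuration is made up, and sorting the attribute names as well as the module paths goes one step beyond what the suggestions propose.

```python
def build_import_lines(import_config):
    """Render `from X import Y as Y` stub lines in a stable, sorted order."""
    lines = []
    for module_path in sorted(import_config):
        module_name = module_path.split(".")[-1]
        lines.append(f"# {module_name}")
        # Sorting attrs too is an extra assumption beyond the suggestion.
        for attr in sorted(import_config[module_path]):
            lines.append(f"from {module_path} import {attr} as {attr}")
        lines.append("")
    return lines


# Two configs with the same content but different insertion order.
config_a = {".util": ["Timer"], ".dispatch": ["parameter_sets", "models"]}
config_b = {".dispatch": ["models", "parameter_sets"], ".util": ["Timer"]}

print("\n".join(build_import_lines(config_a)))
```

Because the output depends only on the content of the mapping, reordering entries in the configuration no longer changes the generated stub, which keeps diffs small.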



2 participants