diff --git a/.github/workflows/test_on_push.yml b/.github/workflows/test_on_push.yml index d993a2796f..be7552b16d 100644 --- a/.github/workflows/test_on_push.yml +++ b/.github/workflows/test_on_push.yml @@ -40,6 +40,12 @@ jobs: - name: Check style run: uvx pre-commit run -a + - name: Check stub file is up to date + run: | + uv pip install -e . --system + python scripts/generate_pyi_stub.py --validate + python scripts/generate_pyi_stub.py --check + run_unit_tests: runs-on: ${{ matrix.os }} permissions: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 19f1fc47ec..85caab816a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,8 +10,10 @@ repos: - id: ruff-check args: [--fix, --show-fixes] types_or: [python, pyi, jupyter] + exclude: ^src/pybamm/__init__\.pyi$ - id: ruff-format types_or: [python, pyi, jupyter] + exclude: ^src/pybamm/__init__\.pyi$ - repo: https://github.com/adamchainz/blacken-docs rev: "1.20.0" diff --git a/docs/source/api/models/submodels/interface/base_interface.rst b/docs/source/api/models/submodels/interface/base_interface.rst index bb6981dc7b..36013542ba 100644 --- a/docs/source/api/models/submodels/interface/base_interface.rst +++ b/docs/source/api/models/submodels/interface/base_interface.rst @@ -1,5 +1,5 @@ Interface Base Model ==================== -.. autoclass:: pybamm.interface.BaseInterface +.. autoclass:: pybamm.models.submodels.interface.BaseInterface :members: diff --git a/docs/source/api/models/submodels/interface/interface_utilisation/base_utilisation.rst b/docs/source/api/models/submodels/interface/interface_utilisation/base_utilisation.rst index b422b8a3cc..92f4aca982 100644 --- a/docs/source/api/models/submodels/interface/interface_utilisation/base_utilisation.rst +++ b/docs/source/api/models/submodels/interface/interface_utilisation/base_utilisation.rst @@ -1,5 +1,5 @@ Utilisation Base Model ====================== -.. autoclass:: pybamm.interface.interface_utilisation.BaseModel +.. autoclass:: pybamm.models.submodels.interface.interface_utilisation.BaseModel :members: diff --git a/docs/source/api/models/submodels/interface/interface_utilisation/constant_utilisation.rst b/docs/source/api/models/submodels/interface/interface_utilisation/constant_utilisation.rst index 5a3040e53b..9fb471874a 100644 --- a/docs/source/api/models/submodels/interface/interface_utilisation/constant_utilisation.rst +++ b/docs/source/api/models/submodels/interface/interface_utilisation/constant_utilisation.rst @@ -1,5 +1,5 @@ Constant Utilisation ==================== -.. autoclass:: pybamm.interface.interface_utilisation.Constant +.. autoclass:: pybamm.models.submodels.interface.interface_utilisation.Constant :members: diff --git a/docs/source/api/models/submodels/interface/interface_utilisation/current_driven_utilisation.rst b/docs/source/api/models/submodels/interface/interface_utilisation/current_driven_utilisation.rst index 0554aaeabd..d0910c3d9c 100644 --- a/docs/source/api/models/submodels/interface/interface_utilisation/current_driven_utilisation.rst +++ b/docs/source/api/models/submodels/interface/interface_utilisation/current_driven_utilisation.rst @@ -1,5 +1,5 @@ CurrentDriven Utilisation ========================= -.. autoclass:: pybamm.interface.interface_utilisation.CurrentDriven +.. autoclass:: pybamm.models.submodels.interface.interface_utilisation.CurrentDriven :members: diff --git a/docs/source/api/models/submodels/interface/interface_utilisation/full_utilisation.rst b/docs/source/api/models/submodels/interface/interface_utilisation/full_utilisation.rst index d1f3191d07..3a45a9cba7 100644 --- a/docs/source/api/models/submodels/interface/interface_utilisation/full_utilisation.rst +++ b/docs/source/api/models/submodels/interface/interface_utilisation/full_utilisation.rst @@ -1,5 +1,5 @@ Full Utilisation ================ -.. autoclass:: pybamm.interface.interface_utilisation.Full +.. autoclass:: pybamm.models.submodels.interface.interface_utilisation.Full :members: diff --git a/docs/source/api/models/submodels/interface/kinetics/base_kinetics.rst b/docs/source/api/models/submodels/interface/kinetics/base_kinetics.rst index b17fb73a4d..4ae753d2cc 100644 --- a/docs/source/api/models/submodels/interface/kinetics/base_kinetics.rst +++ b/docs/source/api/models/submodels/interface/kinetics/base_kinetics.rst @@ -1,5 +1,5 @@ Base Kinetics ============= -.. autoclass:: pybamm.kinetics.BaseKinetics +.. autoclass:: pybamm.models.submodels.interface.kinetics.BaseKinetics :members: diff --git a/docs/source/api/models/submodels/interface/kinetics/butler_volmer.rst b/docs/source/api/models/submodels/interface/kinetics/butler_volmer.rst index 522418a42f..7e257249a8 100644 --- a/docs/source/api/models/submodels/interface/kinetics/butler_volmer.rst +++ b/docs/source/api/models/submodels/interface/kinetics/butler_volmer.rst @@ -1,8 +1,8 @@ Butler Volmer ============= -.. autoclass:: pybamm.kinetics.SymmetricButlerVolmer +.. autoclass:: pybamm.models.submodels.interface.kinetics.SymmetricButlerVolmer :members: -.. autoclass:: pybamm.kinetics.AsymmetricButlerVolmer +.. autoclass:: pybamm.models.submodels.interface.kinetics.AsymmetricButlerVolmer :members: diff --git a/docs/source/api/models/submodels/interface/kinetics/diffusion_limited.rst b/docs/source/api/models/submodels/interface/kinetics/diffusion_limited.rst index 3f2ba6d58e..fa1f89b241 100644 --- a/docs/source/api/models/submodels/interface/kinetics/diffusion_limited.rst +++ b/docs/source/api/models/submodels/interface/kinetics/diffusion_limited.rst @@ -1,5 +1,5 @@ Diffusion-limited ================= -.. autoclass:: pybamm.kinetics.DiffusionLimited +.. autoclass:: pybamm.models.submodels.interface.kinetics.DiffusionLimited :members: diff --git a/docs/source/api/models/submodels/interface/kinetics/inverse_kinetics/base_inverse.rst b/docs/source/api/models/submodels/interface/kinetics/inverse_kinetics/base_inverse.rst index 19389c7e73..6e9f6e59d6 100644 --- a/docs/source/api/models/submodels/interface/kinetics/inverse_kinetics/base_inverse.rst +++ b/docs/source/api/models/submodels/interface/kinetics/inverse_kinetics/base_inverse.rst @@ -1,5 +1,5 @@ Base Inverse Kinetics ===================== -.. autoclass:: pybamm.kinetics.BaseInverseKinetics +.. autoclass:: pybamm.models.submodels.interface.kinetics.BaseInverseKinetics :members: diff --git a/docs/source/api/models/submodels/interface/kinetics/inverse_kinetics/inverse_butler_volmer.rst b/docs/source/api/models/submodels/interface/kinetics/inverse_kinetics/inverse_butler_volmer.rst index 60bd3c7c0d..f1b318214c 100644 --- a/docs/source/api/models/submodels/interface/kinetics/inverse_kinetics/inverse_butler_volmer.rst +++ b/docs/source/api/models/submodels/interface/kinetics/inverse_kinetics/inverse_butler_volmer.rst @@ -1,5 +1,5 @@ Inverse Butler-Volmer ===================== -.. autoclass:: pybamm.kinetics.InverseButlerVolmer +.. autoclass:: pybamm.models.submodels.interface.kinetics.InverseButlerVolmer :members: diff --git a/docs/source/api/models/submodels/interface/kinetics/inverse_kinetics/inverse_linear.rst b/docs/source/api/models/submodels/interface/kinetics/inverse_kinetics/inverse_linear.rst index 7218c1a8e2..a3a08f257c 100644 --- a/docs/source/api/models/submodels/interface/kinetics/inverse_kinetics/inverse_linear.rst +++ b/docs/source/api/models/submodels/interface/kinetics/inverse_kinetics/inverse_linear.rst @@ -1,5 +1,5 @@ Inverse Linear ============== -.. autoclass:: pybamm.kinetics.InverseLinear +.. autoclass:: pybamm.models.submodels.interface.kinetics.InverseLinear :members: diff --git a/docs/source/api/models/submodels/interface/kinetics/linear.rst b/docs/source/api/models/submodels/interface/kinetics/linear.rst index 22c02cbe21..cb8ff4fa85 100644 --- a/docs/source/api/models/submodels/interface/kinetics/linear.rst +++ b/docs/source/api/models/submodels/interface/kinetics/linear.rst @@ -1,5 +1,5 @@ Linear ====== -.. autoclass:: pybamm.kinetics.Linear +.. autoclass:: pybamm.models.submodels.interface.kinetics.Linear :members: diff --git a/docs/source/api/models/submodels/interface/kinetics/marcus.rst b/docs/source/api/models/submodels/interface/kinetics/marcus.rst index 6252f22e8b..caae0e17b6 100644 --- a/docs/source/api/models/submodels/interface/kinetics/marcus.rst +++ b/docs/source/api/models/submodels/interface/kinetics/marcus.rst @@ -1,5 +1,5 @@ Marcus ====== -.. autoclass:: pybamm.kinetics.Marcus +.. autoclass:: pybamm.models.submodels.interface.kinetics.Marcus :members: diff --git a/docs/source/api/models/submodels/interface/kinetics/msmr_butler_volmer.rst b/docs/source/api/models/submodels/interface/kinetics/msmr_butler_volmer.rst index 18bea7ee7a..7cf33de2fe 100644 --- a/docs/source/api/models/submodels/interface/kinetics/msmr_butler_volmer.rst +++ b/docs/source/api/models/submodels/interface/kinetics/msmr_butler_volmer.rst @@ -1,5 +1,5 @@ MSMR Butler Volmer ================== -.. autoclass:: pybamm.kinetics.MSMRButlerVolmer +.. autoclass:: pybamm.models.submodels.interface.kinetics.MSMRButlerVolmer :members: diff --git a/docs/source/api/models/submodels/interface/kinetics/no_reaction.rst b/docs/source/api/models/submodels/interface/kinetics/no_reaction.rst index 573b9a2117..dab7a0b513 100644 --- a/docs/source/api/models/submodels/interface/kinetics/no_reaction.rst +++ b/docs/source/api/models/submodels/interface/kinetics/no_reaction.rst @@ -1,5 +1,5 @@ NoReaction ========== -.. autoclass:: pybamm.kinetics.NoReaction +.. autoclass:: pybamm.models.submodels.interface.kinetics.NoReaction :members: diff --git a/docs/source/api/models/submodels/interface/kinetics/tafel.rst b/docs/source/api/models/submodels/interface/kinetics/tafel.rst index 1969ab5b53..f4bdce03ec 100644 --- a/docs/source/api/models/submodels/interface/kinetics/tafel.rst +++ b/docs/source/api/models/submodels/interface/kinetics/tafel.rst @@ -1,5 +1,5 @@ Tafel ===== -.. autoclass:: pybamm.kinetics.ForwardTafel +.. autoclass:: pybamm.models.submodels.interface.kinetics.ForwardTafel :members: diff --git a/docs/source/api/models/submodels/interface/kinetics/total_main_kinetics.rst b/docs/source/api/models/submodels/interface/kinetics/total_main_kinetics.rst index 36b17e2564..5b7fddceca 100644 --- a/docs/source/api/models/submodels/interface/kinetics/total_main_kinetics.rst +++ b/docs/source/api/models/submodels/interface/kinetics/total_main_kinetics.rst @@ -2,5 +2,5 @@ Total Main Kinetics =================== -.. autoclass:: pybamm.kinetics.TotalMainKinetics +.. autoclass:: pybamm.models.submodels.interface.kinetics.TotalMainKinetics :members: diff --git a/docs/source/api/models/submodels/interface/lithium_plating/base_plating.rst b/docs/source/api/models/submodels/interface/lithium_plating/base_plating.rst index 612bc49002..b8ff363e28 100644 --- a/docs/source/api/models/submodels/interface/lithium_plating/base_plating.rst +++ b/docs/source/api/models/submodels/interface/lithium_plating/base_plating.rst @@ -1,7 +1,7 @@ Base Plating ============ -.. autoclass:: pybamm.lithium_plating.BasePlating +.. autoclass:: pybamm.models.submodels.interface.lithium_plating.BasePlating :members: .. footbibliography:: diff --git a/docs/source/api/models/submodels/interface/lithium_plating/no_plating.rst b/docs/source/api/models/submodels/interface/lithium_plating/no_plating.rst index 97cb027a95..9ef8c83b65 100644 --- a/docs/source/api/models/submodels/interface/lithium_plating/no_plating.rst +++ b/docs/source/api/models/submodels/interface/lithium_plating/no_plating.rst @@ -1,5 +1,5 @@ No Plating ========== -.. autoclass:: pybamm.lithium_plating.NoPlating +.. autoclass:: pybamm.models.submodels.interface.lithium_plating.NoPlating :members: diff --git a/docs/source/api/models/submodels/interface/lithium_plating/plating.rst b/docs/source/api/models/submodels/interface/lithium_plating/plating.rst index a358f0b10f..3e7daa5e52 100644 --- a/docs/source/api/models/submodels/interface/lithium_plating/plating.rst +++ b/docs/source/api/models/submodels/interface/lithium_plating/plating.rst @@ -1,7 +1,7 @@ Plating ======= -.. autoclass:: pybamm.lithium_plating.Plating +.. autoclass:: pybamm.models.submodels.interface.lithium_plating.Plating :members: .. footbibliography:: diff --git a/docs/source/api/models/submodels/interface/open_circuit_potential/base_ocp.rst b/docs/source/api/models/submodels/interface/open_circuit_potential/base_ocp.rst index bda54d9c2d..38c1cd0294 100644 --- a/docs/source/api/models/submodels/interface/open_circuit_potential/base_ocp.rst +++ b/docs/source/api/models/submodels/interface/open_circuit_potential/base_ocp.rst @@ -1,8 +1,8 @@ Base Open Circuit Potential =========================== -.. autoclass:: pybamm.open_circuit_potential.BaseOpenCircuitPotential +.. autoclass:: pybamm.models.submodels.interface.open_circuit_potential.BaseOpenCircuitPotential :members: -.. autoclass:: pybamm.open_circuit_potential.BaseHysteresisOpenCircuitPotential +.. autoclass:: pybamm.models.submodels.interface.open_circuit_potential.BaseHysteresisOpenCircuitPotential :members: diff --git a/docs/source/api/models/submodels/interface/open_circuit_potential/current_sigmoid_ocp.rst b/docs/source/api/models/submodels/interface/open_circuit_potential/current_sigmoid_ocp.rst index 832cc57603..a0cc109a2f 100644 --- a/docs/source/api/models/submodels/interface/open_circuit_potential/current_sigmoid_ocp.rst +++ b/docs/source/api/models/submodels/interface/open_circuit_potential/current_sigmoid_ocp.rst @@ -1,5 +1,5 @@ Current Sigmoid Open Circuit Potential ====================================== -.. autoclass:: pybamm.open_circuit_potential.CurrentSigmoidOpenCircuitPotential +.. autoclass:: pybamm.models.submodels.interface.open_circuit_potential.CurrentSigmoidOpenCircuitPotential :members: diff --git a/docs/source/api/models/submodels/interface/open_circuit_potential/msmr_ocp.rst b/docs/source/api/models/submodels/interface/open_circuit_potential/msmr_ocp.rst index f2106367d2..3874c88df7 100644 --- a/docs/source/api/models/submodels/interface/open_circuit_potential/msmr_ocp.rst +++ b/docs/source/api/models/submodels/interface/open_circuit_potential/msmr_ocp.rst @@ -2,7 +2,7 @@ MSMR Open Circuit Potential =========================== -.. autoclass:: pybamm.open_circuit_potential.MSMROpenCircuitPotential +.. autoclass:: pybamm.models.submodels.interface.open_circuit_potential.MSMROpenCircuitPotential :members: .. footbibliography:: diff --git a/docs/source/api/models/submodels/interface/open_circuit_potential/one_state_differential_capacity_hysteresis_ocp.rst b/docs/source/api/models/submodels/interface/open_circuit_potential/one_state_differential_capacity_hysteresis_ocp.rst index 6a4b90ddd4..8f247e1ca5 100644 --- a/docs/source/api/models/submodels/interface/open_circuit_potential/one_state_differential_capacity_hysteresis_ocp.rst +++ b/docs/source/api/models/submodels/interface/open_circuit_potential/one_state_differential_capacity_hysteresis_ocp.rst @@ -1,7 +1,7 @@ One State Differential Capacity Hysteresis Open Circuit Potential ================================================================= -.. autoclass:: pybamm.open_circuit_potential.OneStateDifferentialCapacityHysteresisOpenCircuitPotential +.. autoclass:: pybamm.models.submodels.interface.open_circuit_potential.OneStateDifferentialCapacityHysteresisOpenCircuitPotential :members: .. footbibliography:: diff --git a/docs/source/api/models/submodels/interface/open_circuit_potential/one_state_hysteresis_ocp.rst b/docs/source/api/models/submodels/interface/open_circuit_potential/one_state_hysteresis_ocp.rst index 3427045731..9468c4b382 100644 --- a/docs/source/api/models/submodels/interface/open_circuit_potential/one_state_hysteresis_ocp.rst +++ b/docs/source/api/models/submodels/interface/open_circuit_potential/one_state_hysteresis_ocp.rst @@ -1,7 +1,7 @@ One-state hysteresis open-circuit potential =========================================== -.. autoclass:: pybamm.open_circuit_potential.OneStateHysteresisOpenCircuitPotential +.. autoclass:: pybamm.models.submodels.interface.open_circuit_potential.OneStateHysteresisOpenCircuitPotential :members: .. footbibliography:: diff --git a/docs/source/api/models/submodels/interface/open_circuit_potential/single_ocp.rst b/docs/source/api/models/submodels/interface/open_circuit_potential/single_ocp.rst index b5ac5b0e8e..311d6729d5 100644 --- a/docs/source/api/models/submodels/interface/open_circuit_potential/single_ocp.rst +++ b/docs/source/api/models/submodels/interface/open_circuit_potential/single_ocp.rst @@ -2,5 +2,5 @@ Single Open Circuit Potential ============================= -.. autoclass:: pybamm.open_circuit_potential.SingleOpenCircuitPotential +.. autoclass:: pybamm.models.submodels.interface.open_circuit_potential.SingleOpenCircuitPotential :members: diff --git a/docs/source/api/models/submodels/interface/sei/base_sei.rst b/docs/source/api/models/submodels/interface/sei/base_sei.rst index 2e42e0eea9..1962f41de7 100644 --- a/docs/source/api/models/submodels/interface/sei/base_sei.rst +++ b/docs/source/api/models/submodels/interface/sei/base_sei.rst @@ -1,5 +1,5 @@ SEI Base Model ============== -.. autoclass:: pybamm.sei.BaseModel +.. autoclass:: pybamm.models.submodels.interface.sei.BaseModel :members: diff --git a/docs/source/api/models/submodels/interface/sei/constant_sei.rst b/docs/source/api/models/submodels/interface/sei/constant_sei.rst index a38652daa8..201a7b70c4 100644 --- a/docs/source/api/models/submodels/interface/sei/constant_sei.rst +++ b/docs/source/api/models/submodels/interface/sei/constant_sei.rst @@ -1,5 +1,5 @@ Constant SEI ============ -.. autoclass:: pybamm.sei.ConstantSEI +.. autoclass:: pybamm.models.submodels.interface.sei.ConstantSEI :members: diff --git a/docs/source/api/models/submodels/interface/sei/no_sei.rst b/docs/source/api/models/submodels/interface/sei/no_sei.rst index 69e93f5d2e..2afe1bc07b 100644 --- a/docs/source/api/models/submodels/interface/sei/no_sei.rst +++ b/docs/source/api/models/submodels/interface/sei/no_sei.rst @@ -1,5 +1,5 @@ No SEI ====== -.. autoclass:: pybamm.sei.NoSEI +.. autoclass:: pybamm.models.submodels.interface.sei.NoSEI :members: diff --git a/docs/source/api/models/submodels/interface/sei/sei_growth.rst b/docs/source/api/models/submodels/interface/sei/sei_growth.rst index d43cfade93..87837d29e0 100644 --- a/docs/source/api/models/submodels/interface/sei/sei_growth.rst +++ b/docs/source/api/models/submodels/interface/sei/sei_growth.rst @@ -1,7 +1,7 @@ SEI Growth ========== -.. autoclass:: pybamm.sei.SEIGrowth +.. autoclass:: pybamm.models.submodels.interface.sei.SEIGrowth :members: .. footbibliography:: diff --git a/docs/source/api/models/submodels/interface/sei/total_sei.rst b/docs/source/api/models/submodels/interface/sei/total_sei.rst index 4210bcd47f..5af9425773 100644 --- a/docs/source/api/models/submodels/interface/sei/total_sei.rst +++ b/docs/source/api/models/submodels/interface/sei/total_sei.rst @@ -1,5 +1,5 @@ Total SEI ========= -.. autoclass:: pybamm.sei.TotalSEI +.. autoclass:: pybamm.models.submodels.interface.sei.TotalSEI :members: diff --git a/docs/source/api/models/submodels/interface/total_interfacial_current.rst b/docs/source/api/models/submodels/interface/total_interfacial_current.rst index abb65459b5..560d24055d 100644 --- a/docs/source/api/models/submodels/interface/total_interfacial_current.rst +++ b/docs/source/api/models/submodels/interface/total_interfacial_current.rst @@ -1,5 +1,5 @@ Total Interfacial Current Model =============================== -.. autoclass:: pybamm.interface.TotalInterfacialCurrent +.. autoclass:: pybamm.models.submodels.interface.TotalInterfacialCurrent :members: diff --git a/pyproject.toml b/pyproject.toml index 5a503ecfc8..c3e6ae759b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ dependencies = [ "posthog", "pyyaml", "platformdirs", + "lazy_loader>=0.4", ] [project.urls] diff --git a/scripts/generate_pyi_stub.py b/scripts/generate_pyi_stub.py new file mode 100755 index 0000000000..ded922c0a1 --- /dev/null +++ b/scripts/generate_pyi_stub.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python +""" +Generate the __init__.pyi stub file for PyBaMM. + +This script generates the stub file that enables: +1. IDE autocomplete for lazy-loaded attributes +2. Type checking support +3. lazy_loader integration for lazy imports + +Usage: + python scripts/generate_pyi_stub.py [options] + +Options: + --check Don't write the file, just check if it would change (for CI) + --validate Validate that all configured imports exist in the modules + +Examples: + python scripts/generate_pyi_stub.py # Generate/update stub + python scripts/generate_pyi_stub.py --validate # Validate imports exist + python scripts/generate_pyi_stub.py --check # CI check (exit 1 if outdated) +""" + +from __future__ import annotations + +import argparse +import importlib +import sys +from pathlib import Path + +from pybamm._lazy_config import EAGER_IMPORTS, LAZY_IMPORTS, SUBMODULE_ALIASES + +# Add src to path for imports +REPO_ROOT = Path(__file__).parent.parent +sys.path.insert(0, str(REPO_ROOT / "src")) + + +def generate_stub_content() -> str: + """Generate the complete stub file content.""" + lines = [ + "# PyBaMM stub file for IDE support and type hints", + "# Auto-generated by scripts/generate_pyi_stub.py - do not edit manually", + "", + "# Version", + "from .version import __version__ as __version__", + "", + "# EAGERLY LOADED (imported at module load time)", + "", + ] + + # Eager imports + for module_path, attrs in EAGER_IMPORTS.items(): + module_name = module_path.split(".")[-1] + lines.append(f"# {module_name}") + for attr in attrs: + lines.append(f"from {module_path} import {attr} as {attr}") + lines.append("") + + # Config module + lines.append("# Config module") + lines.append("from . import config as config") + lines.append("") + + # Lazy imports + lines.append("# LAZILY LOADED (via lazy_loader stub mechanism)") + lines.append("") + + for module_path, attrs in LAZY_IMPORTS.items(): + module_name = module_path.split(".")[-1] + lines.append(f"# {module_name}") + for attr in attrs: + lines.append(f"from {module_path} import {attr} as {attr}") + lines.append("") + + # Submodule aliases + lines.append("# SUBMODULE ALIASES") + lines.append("") + + for alias, path in sorted(SUBMODULE_ALIASES.items()): + parts = path.rsplit(".", 1) + if len(parts) == 2 and parts[0]: + parent_path, module_name = parts + lines.append(f"from {parent_path} import {module_name} as {alias}") + else: + module_name = path.lstrip(".") + lines.append(f"from . import {module_name} as {alias}") + + return "\n".join(lines) + "\n" + + +def validate_imports() -> list[str]: + """Validate that all configured imports actually exist.""" + errors = [] + + for label, imports in [("EAGER", EAGER_IMPORTS), ("LAZY", LAZY_IMPORTS)]: + for module_path, attrs in imports.items(): + try: + module = importlib.import_module(f"pybamm{module_path}") + for attr in attrs: + if not hasattr(module, attr): + errors.append(f"{label}: {module_path}.{attr} not found") + except ImportError as e: + errors.append(f"{label}: Cannot import {module_path}: {e}") + + for alias, module_path in SUBMODULE_ALIASES.items(): + try: + importlib.import_module(f"pybamm{module_path}") + except ImportError as e: + errors.append(f"SUBMODULE: {alias} -> {module_path}: {e}") + + return errors + + +def main(): + parser = argparse.ArgumentParser( + description="Generate __init__.pyi stub file for PyBaMM" + ) + parser.add_argument( + "--check", + action="store_true", + help="Check if stub would change (for CI), exit 1 if different", + ) + parser.add_argument( + "--validate", + action="store_true", + help="Validate that all configured imports exist", + ) + args = parser.parse_args() + + stub_path = REPO_ROOT / "src" / "pybamm" / "__init__.pyi" + + if args.validate: + print("Validating imports...") + errors = validate_imports() + if errors: + print("Validation errors:") + for error in errors: + print(f" - {error}") + sys.exit(1) + print("All imports validated successfully!") + return + + content = generate_stub_content() + + if args.check: + if stub_path.exists(): + existing = stub_path.read_text() + if existing == content: + print("Stub file is up to date.") + sys.exit(0) + else: + print( + "Stub file is out of date. Run 'python scripts/generate_pyi_stub.py' to update." + ) + sys.exit(1) + else: + print("Stub file does not exist.") + sys.exit(1) + + stub_path.write_text(content) + print(f"Generated {stub_path}") + + eager_count = sum(len(attrs) for attrs in EAGER_IMPORTS.values()) + lazy_count = sum(len(attrs) for attrs in LAZY_IMPORTS.values()) + submodule_count = len(SUBMODULE_ALIASES) + print(f" - {eager_count} eagerly loaded exports") + print(f" - {lazy_count} lazily loaded exports") + print(f" - {submodule_count} submodule aliases") + print(f" - Total: {eager_count + lazy_count + submodule_count} exports") + + +if __name__ == "__main__": + main() diff --git a/src/pybamm/__init__.py b/src/pybamm/__init__.py index 734a710aca..7c25399330 100644 --- a/src/pybamm/__init__.py +++ b/src/pybamm/__init__.py @@ -1,256 +1,104 @@ +# Lazy loading implementation using lazy_loader package +# Essential imports only - everything else is lazily loaded via stub file +# +# The stub file (__init__.pyi) is auto-generated. To regenerate: +# python scripts/generate_pyi_stub.py +# +# To validate all imports are correct: +# python scripts/generate_pyi_stub.py --validate +# +# For CI checks (exits non-zero if stub is outdated): +# python scripts/generate_pyi_stub.py --check + +import importlib + +import lazy_loader as lazy + from pybamm.version import __version__ -# Utility classes and methods -from .util import root_dir -from .util import Timer, TimerTime, FuzzyDict -from .util import ( - root_dir, - load, - is_constant_and_can_evaluate, -) -from .util import ( - get_parameters_filepath, - has_jax, - import_optional_dependency, -) +# Core utilities that are lightweight and commonly needed from .logger import logger, set_logging_level, get_new_logger from .settings import settings -from .citations import Citations, citations, print_citations from . import config -# Classes for the Expression Tree +# These need to be imported eagerly to shadow the submodule names +from .citations import Citations, citations, print_citations + +# Expression tree modules are accessed thousands of times during model building +# so we eagerly import them to avoid __getattr__ overhead from .expression_tree.symbol import * from .expression_tree.binary_operators import * from .expression_tree.concatenations import * -from .expression_tree.array import Array, linspace, meshgrid -from .expression_tree.matrix import Matrix from .expression_tree.unary_operators import * from .expression_tree.averages import * from .expression_tree.averages import _BaseAverage from .expression_tree.broadcasts import * from .expression_tree.functions import * -from .expression_tree.interpolant import Interpolant +from .expression_tree.interpolant import Interpolant # needed before discrete_time_sum from .expression_tree.discrete_time_sum import * -from .expression_tree.input_parameter import InputParameter -from .expression_tree.parameter import Parameter, FunctionParameter -from .expression_tree.scalar import Scalar, Constant from .expression_tree.variable import * from .expression_tree.coupled_variable import * from .expression_tree.independent_variable import * -from .expression_tree.independent_variable import t -from .expression_tree.vector import Vector -from .expression_tree.tensor_field import TensorField -from .expression_tree.vector_field import VectorField -from .expression_tree.state_vector import StateVectorBase, StateVector, StateVectorDot - from .expression_tree.exceptions import * +from .expression_tree.scalar import * +from .expression_tree.state_vector import * +from .expression_tree.tensor_field import * +from .expression_tree.parameter import * +from .expression_tree.input_parameter import * +from .expression_tree.array import * +from .expression_tree.vector_field import * +from .expression_tree.matrix import * +from .expression_tree.vector import * -# Operations -from .expression_tree.operations.evaluate_python import ( - find_symbols, - id_to_python_variable, - to_python, - EvaluatorPython, -) - -from .expression_tree.operations.evaluate_python import EvaluatorJax -from .expression_tree.operations.evaluate_python import JaxCooMatrix - -from .expression_tree.operations.jacobian import Jacobian -from .expression_tree.operations.convert_to_casadi import CasadiConverter -from .expression_tree.operations.unpack_symbols import SymbolUnpacker -from .expression_tree.operations.serialise import Serialise,ExpressionFunctionParameter - -# Model classes -from .models.base_model import BaseModel, ModelSolutionObservability -from .models.symbol_processor import SymbolProcessor -from .models.event import Event -from .models.event import EventType - -# Battery models -from .models.full_battery_models.base_battery_model import ( - BaseBatteryModel, - BatteryModelOptions, -) -from .models.full_battery_models import lead_acid -from .models.full_battery_models import lithium_ion -from .models.full_battery_models import equivalent_circuit -from .models.full_battery_models import sodium_ion - -# Submodel classes -from .models.submodels.base_submodel import BaseSubModel +# Lazy loading via stub file - get the base __getattr__ and __dir__ +_lazy_getattr, _lazy_dir, _stub_all = lazy.attach_stub(__name__, __file__) -from .models.submodels import ( - active_material, - convection, - current_collector, - electrolyte_conductivity, - electrolyte_diffusion, - electrode, - external_circuit, - interface, - oxygen_diffusion, - particle, - porosity, - thermal, - transport_efficiency, - particle_mechanics, - equivalent_circuit_elements, -) -from .models.submodels.interface import kinetics -from .models.submodels.interface import sei -from .models.submodels.interface import lithium_plating -from .models.submodels.interface import interface_utilisation -from .models.submodels.interface import open_circuit_potential +# These are submodules that we want to expose at the top level (e.g., pybamm.lithium_ion) +from ._lazy_config import SUBMODULE_ALIASES as _SUBMODULE_ALIASES -# Geometry -from .geometry.geometry import Geometry -from .geometry.battery_geometry import battery_geometry +# Cache for loaded submodule aliases +_loaded_submodules: dict[str, object] = {} -from .expression_tree.independent_variable import KNOWN_COORD_SYS -from .geometry import standard_spatial_vars -# Parameter classes and methods -from .parameters.parameter_values import ParameterValues, scalarize_dict, arrayize_dict -from .parameters import constants -from .parameters.geometric_parameters import geometric_parameters, GeometricParameters -from .parameters.electrical_parameters import ( - electrical_parameters, - ElectricalParameters, -) -from .parameters.thermal_parameters import thermal_parameters, ThermalParameters -from .parameters.lithium_ion_parameters import LithiumIonParameters -from .parameters.lead_acid_parameters import LeadAcidParameters -from .parameters.ecm_parameters import EcmParameters -from .parameters.size_distribution_parameters import * +def __getattr__(name: str) -> object: + """Custom __getattr__ that handles submodule aliases and falls back to lazy_loader.""" + # Fast path: check submodule cache + if name in _loaded_submodules: + return _loaded_submodules[name] -# Mesh and Discretisation classes -from .discretisations.discretisation import Discretisation -from .discretisations.discretisation import has_bc_of_form -from .meshes.meshes import Mesh, SubMesh, MeshGenerator -from .meshes.zero_dimensional_submesh import SubMesh0D -from .meshes.one_dimensional_submeshes import ( - SubMesh1D, - Uniform1DSubMesh, - Exponential1DSubMesh, - Chebyshev1DSubMesh, - UserSupplied1DSubMesh, - SpectralVolume1DSubMesh, - SymbolicUniform1DSubMesh, -) -from .meshes.two_dimensional_submeshes import ( - SubMesh2D, - Uniform2DSubMesh, -) -from .meshes.scikit_fem_submeshes import ( - ScikitSubMesh2D, - ScikitUniform2DSubMesh, - ScikitExponential2DSubMesh, - ScikitChebyshev2DSubMesh, - UserSupplied2DSubMesh, -) + # Check if it's a submodule alias + if name in _SUBMODULE_ALIASES: + module_path = _SUBMODULE_ALIASES[name] + module = importlib.import_module(module_path, package="pybamm") + _loaded_submodules[name] = module + return module -from .meshes.scikit_fem_submeshes_3d import ( - ScikitFemSubMesh3D, - ScikitFemGenerator3D, - UserSuppliedSubmesh3D, -) + # Fall back to lazy_loader's __getattr__ + return _lazy_getattr(name) -# Serialisation -from .models.base_model import load_model -# Spatial Methods -from .spatial_methods.spatial_method import SpatialMethod -from .spatial_methods.zero_dimensional_method import ZeroDimensionalSpatialMethod -from .spatial_methods.finite_volume import FiniteVolume -from .spatial_methods.finite_volume_2d import FiniteVolume2D -from .spatial_methods.spectral_volume import SpectralVolume -from .spatial_methods.scikit_finite_element import ScikitFiniteElement -from .spatial_methods.scikit_finite_element_3d import ScikitFiniteElement3D +def __dir__() -> list[str]: + """Include submodule aliases in dir() output.""" + return sorted(set(_lazy_dir()) | set(_SUBMODULE_ALIASES.keys())) -# Solver classes -from .solvers.solution import Solution, EmptySolution, make_cycle_solution -from .solvers.processed_variable_time_integral import ProcessedVariableTimeIntegral -from .solvers.processed_variable import ProcessedVariable, ProcessedVariable2DFVM, process_variable -from .solvers.processed_variable_computed import ProcessedVariableComputed -from .solvers.processed_variable import ProcessedVariableUnstructured -from .solvers.summary_variable import SummaryVariables -from .solvers.base_solver import BaseSolver -from .solvers.dummy_solver import DummySolver -from .solvers.algebraic_solver import AlgebraicSolver -from .solvers.casadi_solver import CasadiSolver -from .solvers.casadi_algebraic_solver import CasadiAlgebraicSolver -from .solvers.scipy_solver import ScipySolver -from .solvers.composite_solver import CompositeSolver -from .solvers.jax_solver import JaxSolver -from .solvers.jax_bdf_solver import jax_bdf_integrate +# Extend __all__ to include submodule aliases +__all__ = _stub_all + list(_SUBMODULE_ALIASES.keys()) -from .solvers.idaklu_jax import IDAKLUJax -from .solvers.idaklu_solver import IDAKLUSolver -# Experiments -from .experiment.experiment import Experiment -from . import experiment -from .experiment import step - -# Plotting -from .plotting.quick_plot import QuickPlot, close_plots, QuickPlotAxes -from .plotting.plot import plot -from .plotting.plot2D import plot2D -from .plotting.plot_voltage_components import plot_voltage_components -from .plotting.plot_thermal_components import plot_thermal_components -from .plotting.plot_summary_variables import plot_summary_variables -from .plotting.dynamic_plot import dynamic_plot -from .plotting.plot_3d_cross_section import plot_3d_cross_section -from .plotting.plot_3d_heatmap import plot_3d_heatmap - -# Simulation -from .simulation import Simulation, load_sim, is_notebook - -# Batch Study -from .batch_study import BatchStudy - -# Callbacks, telemetry, config -from . import callbacks, telemetry, config - -# Pybamm Data manager using pooch -from .pybamm_data import DataLoader - -from .dispatch import parameter_sets, Model - -# Fix Casadi import +# Fix Casadi import - this needs to happen at import time import os import pathlib import sysconfig os.environ["CASADIPATH"] = str(pathlib.Path(sysconfig.get_path("purelib")) / "casadi") -__all__ = [ - "batch_study", - "callbacks", - "citations", - "config", - "discretisations", - "experiment", - "expression_tree", - "geometry", - "input", - "logger", - "meshes", - "models", - "parameters", - "plotting", - "settings", - "simulation", - "solvers", - "spatial_methods", - "telemetry", - "type_definitions", - "util", - "version", - "pybamm_data", - "dispatch", -] - config.generate() + +# Eagerly load core simulation modules to optimize first-solve performance +from . import simulation # noqa: F401, E402 +from . import solvers # noqa: F401, E402 +from .parameters import parameter_values # noqa: F401, E402 +from . import meshes # noqa: F401, E402 +from . import spatial_methods # noqa: F401, E402 +from .expression_tree.operations import convert_to_casadi # noqa: F401, E402 diff --git a/src/pybamm/__init__.pyi b/src/pybamm/__init__.pyi new file mode 100644 index 0000000000..777f4a41db --- /dev/null +++ b/src/pybamm/__init__.pyi @@ -0,0 +1,524 @@ +# PyBaMM stub file for IDE support and type hints +# Auto-generated by scripts/generate_pyi_stub.py - do not edit manually + +# Version +from .version import __version__ as __version__ + +# EAGERLY LOADED (imported at module load time) + +# logger +from .logger import logger as logger +from .logger import set_logging_level as set_logging_level +from .logger import get_new_logger as get_new_logger + +# settings +from .settings import settings as settings + +# citations +from .citations import Citations as Citations +from .citations import citations as citations +from .citations import print_citations as print_citations + +# symbol +from .expression_tree.symbol import Symbol as Symbol +from .expression_tree.symbol import domain_size as domain_size +from .expression_tree.symbol import create_object_of_size as create_object_of_size +from .expression_tree.symbol import evaluate_for_shape_using_domain as evaluate_for_shape_using_domain +from .expression_tree.symbol import is_constant as is_constant +from .expression_tree.symbol import is_scalar_zero as is_scalar_zero +from .expression_tree.symbol import is_scalar_one as is_scalar_one +from .expression_tree.symbol import is_scalar_minus_one as is_scalar_minus_one +from .expression_tree.symbol import is_matrix_zero as is_matrix_zero +from .expression_tree.symbol import is_matrix_one as is_matrix_one +from .expression_tree.symbol import is_matrix_minus_one as is_matrix_minus_one +from .expression_tree.symbol import simplify_if_constant as simplify_if_constant +from .expression_tree.symbol import convert_to_symbol as convert_to_symbol + +# binary_operators +from .expression_tree.binary_operators import BinaryOperator as BinaryOperator +from .expression_tree.binary_operators import Power as Power +from .expression_tree.binary_operators import Addition as Addition +from .expression_tree.binary_operators import Subtraction as Subtraction +from .expression_tree.binary_operators import Multiplication as Multiplication +from .expression_tree.binary_operators import KroneckerProduct as KroneckerProduct +from .expression_tree.binary_operators import TensorProduct as TensorProduct +from .expression_tree.binary_operators import MatrixMultiplication as MatrixMultiplication +from .expression_tree.binary_operators import Division as Division +from .expression_tree.binary_operators import Inner as Inner +from .expression_tree.binary_operators import Equality as Equality +from .expression_tree.binary_operators import EqualHeaviside as EqualHeaviside +from .expression_tree.binary_operators import NotEqualHeaviside as NotEqualHeaviside +from .expression_tree.binary_operators import Modulo as Modulo +from .expression_tree.binary_operators import Minimum as Minimum +from .expression_tree.binary_operators import Maximum as Maximum +from .expression_tree.binary_operators import softplus as softplus +from .expression_tree.binary_operators import softminus as softminus +from .expression_tree.binary_operators import sigmoid as sigmoid +from .expression_tree.binary_operators import source as source + +# concatenations +from .expression_tree.concatenations import Concatenation as Concatenation +from .expression_tree.concatenations import NumpyConcatenation as NumpyConcatenation +from .expression_tree.concatenations import DomainConcatenation as DomainConcatenation +from .expression_tree.concatenations import SparseStack as SparseStack +from .expression_tree.concatenations import ConcatenationVariable as ConcatenationVariable +from .expression_tree.concatenations import concatenation as concatenation +from .expression_tree.concatenations import numpy_concatenation as numpy_concatenation + +# unary_operators +from .expression_tree.unary_operators import UnaryOperator as UnaryOperator +from .expression_tree.unary_operators import Negate as Negate +from .expression_tree.unary_operators import AbsoluteValue as AbsoluteValue +from .expression_tree.unary_operators import Transpose as Transpose +from .expression_tree.unary_operators import Sign as Sign +from .expression_tree.unary_operators import Floor as Floor +from .expression_tree.unary_operators import Ceiling as Ceiling +from .expression_tree.unary_operators import Index as Index +from .expression_tree.unary_operators import SpatialOperator as SpatialOperator +from .expression_tree.unary_operators import Gradient as Gradient +from .expression_tree.unary_operators import Divergence as Divergence +from .expression_tree.unary_operators import Laplacian as Laplacian +from .expression_tree.unary_operators import GradientSquared as GradientSquared +from .expression_tree.unary_operators import Mass as Mass +from .expression_tree.unary_operators import BoundaryMass as BoundaryMass +from .expression_tree.unary_operators import Integral as Integral +from .expression_tree.unary_operators import BaseIndefiniteIntegral as BaseIndefiniteIntegral +from .expression_tree.unary_operators import IndefiniteIntegral as IndefiniteIntegral +from .expression_tree.unary_operators import BackwardIndefiniteIntegral as BackwardIndefiniteIntegral +from .expression_tree.unary_operators import DefiniteIntegralVector as DefiniteIntegralVector +from .expression_tree.unary_operators import BoundaryIntegral as BoundaryIntegral +from .expression_tree.unary_operators import OneDimensionalIntegral as OneDimensionalIntegral +from .expression_tree.unary_operators import DeltaFunction as DeltaFunction +from .expression_tree.unary_operators import BoundaryOperator as BoundaryOperator +from .expression_tree.unary_operators import BoundaryValue as BoundaryValue +from .expression_tree.unary_operators import BoundaryMeshSize as BoundaryMeshSize +from .expression_tree.unary_operators import ExplicitTimeIntegral as ExplicitTimeIntegral +from .expression_tree.unary_operators import BoundaryGradient as BoundaryGradient +from .expression_tree.unary_operators import EvaluateAt as EvaluateAt +from .expression_tree.unary_operators import UpwindDownwind as UpwindDownwind +from .expression_tree.unary_operators import UpwindDownwind2D as UpwindDownwind2D +from .expression_tree.unary_operators import NodeToEdge2D as NodeToEdge2D +from .expression_tree.unary_operators import Magnitude as Magnitude +from .expression_tree.unary_operators import Upwind as Upwind +from .expression_tree.unary_operators import Downwind as Downwind +from .expression_tree.unary_operators import NotConstant as NotConstant +from .expression_tree.unary_operators import grad as grad +from .expression_tree.unary_operators import div as div +from .expression_tree.unary_operators import laplacian as laplacian +from .expression_tree.unary_operators import grad_squared as grad_squared +from .expression_tree.unary_operators import surf as surf +from .expression_tree.unary_operators import boundary_value as boundary_value +from .expression_tree.unary_operators import sign as sign +from .expression_tree.unary_operators import upwind as upwind +from .expression_tree.unary_operators import downwind as downwind + +# averages +from .expression_tree.averages import _BaseAverage as _BaseAverage +from .expression_tree.averages import XAverage as XAverage +from .expression_tree.averages import YZAverage as YZAverage +from .expression_tree.averages import ZAverage as ZAverage +from .expression_tree.averages import RAverage as RAverage +from .expression_tree.averages import SizeAverage as SizeAverage +from .expression_tree.averages import x_average as x_average +from .expression_tree.averages import yz_average as yz_average +from .expression_tree.averages import z_average as z_average +from .expression_tree.averages import r_average as r_average +from .expression_tree.averages import size_average as size_average + +# broadcasts +from .expression_tree.broadcasts import Broadcast as Broadcast +from .expression_tree.broadcasts import PrimaryBroadcast as PrimaryBroadcast +from .expression_tree.broadcasts import PrimaryBroadcastToEdges as PrimaryBroadcastToEdges +from .expression_tree.broadcasts import SecondaryBroadcast as SecondaryBroadcast +from .expression_tree.broadcasts import SecondaryBroadcastToEdges as SecondaryBroadcastToEdges +from .expression_tree.broadcasts import TertiaryBroadcast as TertiaryBroadcast +from .expression_tree.broadcasts import TertiaryBroadcastToEdges as TertiaryBroadcastToEdges +from .expression_tree.broadcasts import FullBroadcast as FullBroadcast +from .expression_tree.broadcasts import FullBroadcastToEdges as FullBroadcastToEdges +from .expression_tree.broadcasts import ones_like as ones_like +from .expression_tree.broadcasts import zeros_like as zeros_like + +# functions +from .expression_tree.functions import Function as Function +from .expression_tree.functions import SpecificFunction as SpecificFunction +from .expression_tree.functions import Arcsinh as Arcsinh +from .expression_tree.functions import Arctan as Arctan +from .expression_tree.functions import Cos as Cos +from .expression_tree.functions import Cosh as Cosh +from .expression_tree.functions import Erf as Erf +from .expression_tree.functions import Exp as Exp +from .expression_tree.functions import Log as Log +from .expression_tree.functions import Max as Max +from .expression_tree.functions import Min as Min +from .expression_tree.functions import Sin as Sin +from .expression_tree.functions import Sinh as Sinh +from .expression_tree.functions import Sqrt as Sqrt +from .expression_tree.functions import Tanh as Tanh +from .expression_tree.functions import arcsinh as arcsinh +from .expression_tree.functions import arctan as arctan +from .expression_tree.functions import cos as cos +from .expression_tree.functions import cosh as cosh +from .expression_tree.functions import erf as erf +from .expression_tree.functions import exp as exp +from .expression_tree.functions import log as log +from .expression_tree.functions import log10 as log10 +from .expression_tree.functions import max as max +from .expression_tree.functions import min as min +from .expression_tree.functions import sin as sin +from .expression_tree.functions import sinh as sinh +from .expression_tree.functions import sqrt as sqrt +from .expression_tree.functions import tanh as tanh + +# interpolant +from .expression_tree.interpolant import Interpolant as Interpolant + +# discrete_time_sum +from .expression_tree.discrete_time_sum import DiscreteTimeData as DiscreteTimeData +from .expression_tree.discrete_time_sum import DiscreteTimeSum as DiscreteTimeSum + +# variable +from .expression_tree.variable import VariableBase as VariableBase +from .expression_tree.variable import Variable as Variable +from .expression_tree.variable import VariableDot as VariableDot + +# coupled_variable +from .expression_tree.coupled_variable import CoupledVariable as CoupledVariable + +# independent_variable +from .expression_tree.independent_variable import IndependentVariable as IndependentVariable +from .expression_tree.independent_variable import Time as Time +from .expression_tree.independent_variable import SpatialVariable as SpatialVariable +from .expression_tree.independent_variable import SpatialVariableEdge as SpatialVariableEdge +from .expression_tree.independent_variable import t as t +from .expression_tree.independent_variable import KNOWN_COORD_SYS as KNOWN_COORD_SYS + +# exceptions +from .expression_tree.exceptions import DomainError as DomainError +from .expression_tree.exceptions import OptionError as OptionError +from .expression_tree.exceptions import OptionWarning as OptionWarning +from .expression_tree.exceptions import GeometryError as GeometryError +from .expression_tree.exceptions import ModelError as ModelError +from .expression_tree.exceptions import SolverError as SolverError +from .expression_tree.exceptions import SolverWarning as SolverWarning +from .expression_tree.exceptions import ShapeError as ShapeError +from .expression_tree.exceptions import ModelWarning as ModelWarning +from .expression_tree.exceptions import DiscretisationError as DiscretisationError +from .expression_tree.exceptions import InvalidModelJSONError as InvalidModelJSONError + +# scalar +from .expression_tree.scalar import Scalar as Scalar +from .expression_tree.scalar import Constant as Constant + +# state_vector +from .expression_tree.state_vector import StateVectorBase as StateVectorBase +from .expression_tree.state_vector import StateVector as StateVector +from .expression_tree.state_vector import StateVectorDot as StateVectorDot + +# tensor_field +from .expression_tree.tensor_field import TensorField as TensorField + +# parameter +from .expression_tree.parameter import Parameter as Parameter +from .expression_tree.parameter import FunctionParameter as FunctionParameter + +# input_parameter +from .expression_tree.input_parameter import InputParameter as InputParameter + +# array +from .expression_tree.array import Array as Array +from .expression_tree.array import linspace as linspace +from .expression_tree.array import meshgrid as meshgrid + +# vector_field +from .expression_tree.vector_field import VectorField as VectorField + +# matrix +from .expression_tree.matrix import Matrix as Matrix + +# vector +from .expression_tree.vector import Vector as Vector + +# Config module +from . import config as config + +# LAZILY LOADED (via lazy_loader stub mechanism) + +# util +from .util import root_dir as root_dir +from .util import Timer as Timer +from .util import TimerTime as TimerTime +from .util import FuzzyDict as FuzzyDict +from .util import load as load +from .util import is_constant_and_can_evaluate as is_constant_and_can_evaluate +from .util import get_parameters_filepath as get_parameters_filepath +from .util import has_jax as has_jax +from .util import get_jax as get_jax +from .util import import_optional_dependency as import_optional_dependency + +# evaluate_python +from .expression_tree.operations.evaluate_python import find_symbols as find_symbols +from .expression_tree.operations.evaluate_python import id_to_python_variable as id_to_python_variable +from .expression_tree.operations.evaluate_python import to_python as to_python +from .expression_tree.operations.evaluate_python import EvaluatorPython as EvaluatorPython +from .expression_tree.operations.evaluate_python import EvaluatorJax as EvaluatorJax +from .expression_tree.operations.evaluate_python import JaxCooMatrix as JaxCooMatrix + +# jacobian +from .expression_tree.operations.jacobian import Jacobian as Jacobian + +# convert_to_casadi +from .expression_tree.operations.convert_to_casadi import CasadiConverter as CasadiConverter + +# unpack_symbols +from .expression_tree.operations.unpack_symbols import SymbolUnpacker as SymbolUnpacker + +# serialise +from .expression_tree.operations.serialise import Serialise as Serialise +from .expression_tree.operations.serialise import ExpressionFunctionParameter as ExpressionFunctionParameter + +# base_model +from .models.base_model import BaseModel as BaseModel +from .models.base_model import ModelSolutionObservability as ModelSolutionObservability +from .models.base_model import load_model as load_model + +# symbol_processor +from .models.symbol_processor import SymbolProcessor as SymbolProcessor + +# event +from .models.event import Event as Event +from .models.event import EventType as EventType + +# base_battery_model +from .models.full_battery_models.base_battery_model import BaseBatteryModel as BaseBatteryModel +from .models.full_battery_models.base_battery_model import BatteryModelOptions as BatteryModelOptions + +# base_submodel +from .models.submodels.base_submodel import BaseSubModel as BaseSubModel + +# geometry +from .geometry.geometry import Geometry as Geometry + +# battery_geometry +from .geometry.battery_geometry import battery_geometry as battery_geometry + +# parameter_values +from .parameters.parameter_values import ParameterValues as ParameterValues +from .parameters.parameter_values import scalarize_dict as scalarize_dict +from .parameters.parameter_values import arrayize_dict as arrayize_dict + +# geometric_parameters +from .parameters.geometric_parameters import geometric_parameters as geometric_parameters +from .parameters.geometric_parameters import GeometricParameters as GeometricParameters + +# electrical_parameters +from .parameters.electrical_parameters import electrical_parameters as electrical_parameters +from .parameters.electrical_parameters import ElectricalParameters as ElectricalParameters + +# thermal_parameters +from .parameters.thermal_parameters import thermal_parameters as thermal_parameters +from .parameters.thermal_parameters import ThermalParameters as ThermalParameters + +# lithium_ion_parameters +from .parameters.lithium_ion_parameters import LithiumIonParameters as LithiumIonParameters + +# lead_acid_parameters +from .parameters.lead_acid_parameters import LeadAcidParameters as LeadAcidParameters + +# ecm_parameters +from .parameters.ecm_parameters import EcmParameters as EcmParameters + +# size_distribution_parameters +from .parameters.size_distribution_parameters import get_size_distribution_parameters as get_size_distribution_parameters +from .parameters.size_distribution_parameters import lognormal as lognormal + +# discretisation +from .discretisations.discretisation import Discretisation as Discretisation +from .discretisations.discretisation import has_bc_of_form as has_bc_of_form + +# meshes +from .meshes.meshes import Mesh as Mesh +from .meshes.meshes import SubMesh as SubMesh +from .meshes.meshes import MeshGenerator as MeshGenerator + +# zero_dimensional_submesh +from .meshes.zero_dimensional_submesh import SubMesh0D as SubMesh0D + +# one_dimensional_submeshes +from .meshes.one_dimensional_submeshes import SubMesh1D as SubMesh1D +from .meshes.one_dimensional_submeshes import Uniform1DSubMesh as Uniform1DSubMesh +from .meshes.one_dimensional_submeshes import Exponential1DSubMesh as Exponential1DSubMesh +from .meshes.one_dimensional_submeshes import Chebyshev1DSubMesh as Chebyshev1DSubMesh +from .meshes.one_dimensional_submeshes import UserSupplied1DSubMesh as UserSupplied1DSubMesh +from .meshes.one_dimensional_submeshes import SpectralVolume1DSubMesh as SpectralVolume1DSubMesh +from .meshes.one_dimensional_submeshes import SymbolicUniform1DSubMesh as SymbolicUniform1DSubMesh + +# two_dimensional_submeshes +from .meshes.two_dimensional_submeshes import SubMesh2D as SubMesh2D +from .meshes.two_dimensional_submeshes import Uniform2DSubMesh as Uniform2DSubMesh + +# scikit_fem_submeshes +from .meshes.scikit_fem_submeshes import ScikitSubMesh2D as ScikitSubMesh2D +from .meshes.scikit_fem_submeshes import ScikitUniform2DSubMesh as ScikitUniform2DSubMesh +from .meshes.scikit_fem_submeshes import ScikitExponential2DSubMesh as ScikitExponential2DSubMesh +from .meshes.scikit_fem_submeshes import ScikitChebyshev2DSubMesh as ScikitChebyshev2DSubMesh +from .meshes.scikit_fem_submeshes import UserSupplied2DSubMesh as UserSupplied2DSubMesh + +# scikit_fem_submeshes_3d +from .meshes.scikit_fem_submeshes_3d import ScikitFemSubMesh3D as ScikitFemSubMesh3D +from .meshes.scikit_fem_submeshes_3d import ScikitFemGenerator3D as ScikitFemGenerator3D +from .meshes.scikit_fem_submeshes_3d import UserSuppliedSubmesh3D as UserSuppliedSubmesh3D + +# spatial_method +from .spatial_methods.spatial_method import SpatialMethod as SpatialMethod + +# zero_dimensional_method +from .spatial_methods.zero_dimensional_method import ZeroDimensionalSpatialMethod as ZeroDimensionalSpatialMethod + +# finite_volume +from .spatial_methods.finite_volume import FiniteVolume as FiniteVolume + +# finite_volume_2d +from .spatial_methods.finite_volume_2d import FiniteVolume2D as FiniteVolume2D + +# spectral_volume +from .spatial_methods.spectral_volume import SpectralVolume as SpectralVolume + +# scikit_finite_element +from .spatial_methods.scikit_finite_element import ScikitFiniteElement as ScikitFiniteElement + +# scikit_finite_element_3d +from .spatial_methods.scikit_finite_element_3d import ScikitFiniteElement3D as ScikitFiniteElement3D + +# solution +from .solvers.solution import Solution as Solution +from .solvers.solution import EmptySolution as EmptySolution +from .solvers.solution import make_cycle_solution as make_cycle_solution + +# processed_variable_time_integral +from .solvers.processed_variable_time_integral import ProcessedVariableTimeIntegral as ProcessedVariableTimeIntegral + +# processed_variable +from .solvers.processed_variable import ProcessedVariable as ProcessedVariable +from .solvers.processed_variable import ProcessedVariable2DFVM as ProcessedVariable2DFVM +from .solvers.processed_variable import process_variable as process_variable +from .solvers.processed_variable import ProcessedVariableUnstructured as ProcessedVariableUnstructured + +# processed_variable_computed +from .solvers.processed_variable_computed import ProcessedVariableComputed as ProcessedVariableComputed + +# summary_variable +from .solvers.summary_variable import SummaryVariables as SummaryVariables + +# base_solver +from .solvers.base_solver import BaseSolver as BaseSolver + +# dummy_solver +from .solvers.dummy_solver import DummySolver as DummySolver + +# algebraic_solver +from .solvers.algebraic_solver import AlgebraicSolver as AlgebraicSolver + +# casadi_solver +from .solvers.casadi_solver import CasadiSolver as CasadiSolver + +# casadi_algebraic_solver +from .solvers.casadi_algebraic_solver import CasadiAlgebraicSolver as CasadiAlgebraicSolver + +# scipy_solver +from .solvers.scipy_solver import ScipySolver as ScipySolver + +# composite_solver +from .solvers.composite_solver import CompositeSolver as CompositeSolver + +# jax_solver +from .solvers.jax_solver import JaxSolver as JaxSolver + +# jax_bdf_solver +from .solvers.jax_bdf_solver import jax_bdf_integrate as jax_bdf_integrate + +# idaklu_jax +from .solvers.idaklu_jax import IDAKLUJax as IDAKLUJax + +# idaklu_solver +from .solvers.idaklu_solver import IDAKLUSolver as IDAKLUSolver + +# experiment +from .experiment.experiment import Experiment as Experiment + +# quick_plot +from .plotting.quick_plot import QuickPlot as QuickPlot +from .plotting.quick_plot import close_plots as close_plots +from .plotting.quick_plot import QuickPlotAxes as QuickPlotAxes + +# plot +from .plotting.plot import plot as plot + +# plot2D +from .plotting.plot2D import plot2D as plot2D + +# plot_voltage_components +from .plotting.plot_voltage_components import plot_voltage_components as plot_voltage_components + +# plot_thermal_components +from .plotting.plot_thermal_components import plot_thermal_components as plot_thermal_components + +# plot_summary_variables +from .plotting.plot_summary_variables import plot_summary_variables as plot_summary_variables + +# dynamic_plot +from .plotting.dynamic_plot import dynamic_plot as dynamic_plot + +# plot_3d_cross_section +from .plotting.plot_3d_cross_section import plot_3d_cross_section as plot_3d_cross_section + +# plot_3d_heatmap +from .plotting.plot_3d_heatmap import plot_3d_heatmap as plot_3d_heatmap + +# simulation +from .simulation import Simulation as Simulation +from .simulation import load_sim as load_sim +from .simulation import is_notebook as is_notebook + +# batch_study +from .batch_study import BatchStudy as BatchStudy + +# pybamm_data +from .pybamm_data import DataLoader as DataLoader + +# dispatch +from .dispatch import parameter_sets as parameter_sets +from .dispatch import models as models +from .dispatch import Model as Model + +# SUBMODULE ALIASES + +from .models.submodels import active_material as active_material +from . import callbacks as callbacks +from .parameters import constants as constants +from .models.submodels import convection as convection +from .models.submodels import current_collector as current_collector +from .models.submodels import electrode as electrode +from .models.submodels import electrolyte_conductivity as electrolyte_conductivity +from .models.submodels import electrolyte_diffusion as electrolyte_diffusion +from .models.full_battery_models import equivalent_circuit as equivalent_circuit +from .models.submodels import equivalent_circuit_elements as equivalent_circuit_elements +from . import experiment as experiment +from .models.submodels import external_circuit as external_circuit +from .models.submodels import interface as interface +from .models.submodels.interface import interface_utilisation as interface_utilisation +from .models.submodels.interface import kinetics as kinetics +from .models.full_battery_models import lead_acid as lead_acid +from .models.full_battery_models import lithium_ion as lithium_ion +from .models.submodels.interface import lithium_plating as lithium_plating +from .models.submodels.interface import open_circuit_potential as open_circuit_potential +from .models.submodels import oxygen_diffusion as oxygen_diffusion +from .models.submodels import particle as particle +from .models.submodels import particle_mechanics as particle_mechanics +from .models.submodels import porosity as porosity +from .models.submodels.interface import sei as sei +from .models.full_battery_models import sodium_ion as sodium_ion +from .geometry import standard_spatial_vars as standard_spatial_vars +from .experiment import step as step +from . import telemetry as telemetry +from .models.submodels import thermal as thermal +from .models.submodels import transport_efficiency as transport_efficiency diff --git a/src/pybamm/_lazy_config.py b/src/pybamm/_lazy_config.py new file mode 100644 index 0000000000..98036aa13f --- /dev/null +++ b/src/pybamm/_lazy_config.py @@ -0,0 +1,367 @@ +""" +Single source of truth for lazy loading configuration. + +Used by both __init__.py (runtime) and generate_pyi_stub.py (stub generation). + +This module defines: +- EAGER_IMPORTS: Attributes imported at module load time (wildcard imports) +- LAZY_IMPORTS: Attributes loaded on first access via lazy_loader +- SUBMODULE_ALIASES: Nested submodules exposed at top level (e.g., pybamm.lithium_ion) +""" + +# Eagerly loaded items (imported at module load time via wildcard imports) +EAGER_IMPORTS: dict[str, list[str]] = { + # Logger and settings + ".logger": ["logger", "set_logging_level", "get_new_logger"], + ".settings": ["settings"], + ".citations": ["Citations", "citations", "print_citations"], + # Expression tree modules (wildcard imports in __init__.py) + ".expression_tree.symbol": [ + "Symbol", + "domain_size", + "create_object_of_size", + "evaluate_for_shape_using_domain", + "is_constant", + "is_scalar_zero", + "is_scalar_one", + "is_scalar_minus_one", + "is_matrix_zero", + "is_matrix_one", + "is_matrix_minus_one", + "simplify_if_constant", + "convert_to_symbol", + ], + ".expression_tree.binary_operators": [ + "BinaryOperator", + "Power", + "Addition", + "Subtraction", + "Multiplication", + "KroneckerProduct", + "TensorProduct", + "MatrixMultiplication", + "Division", + "Inner", + "Equality", + "EqualHeaviside", + "NotEqualHeaviside", + "Modulo", + "Minimum", + "Maximum", + "softplus", + "softminus", + "sigmoid", + "source", + ], + ".expression_tree.concatenations": [ + "Concatenation", + "NumpyConcatenation", + "DomainConcatenation", + "SparseStack", + "ConcatenationVariable", + "concatenation", + "numpy_concatenation", + ], + ".expression_tree.unary_operators": [ + "UnaryOperator", + "Negate", + "AbsoluteValue", + "Transpose", + "Sign", + "Floor", + "Ceiling", + "Index", + "SpatialOperator", + "Gradient", + "Divergence", + "Laplacian", + "GradientSquared", + "Mass", + "BoundaryMass", + "Integral", + "BaseIndefiniteIntegral", + "IndefiniteIntegral", + "BackwardIndefiniteIntegral", + "DefiniteIntegralVector", + "BoundaryIntegral", + "OneDimensionalIntegral", + "DeltaFunction", + "BoundaryOperator", + "BoundaryValue", + "BoundaryMeshSize", + "ExplicitTimeIntegral", + "BoundaryGradient", + "EvaluateAt", + "UpwindDownwind", + "UpwindDownwind2D", + "NodeToEdge2D", + "Magnitude", + "Upwind", + "Downwind", + "NotConstant", + "grad", + "div", + "laplacian", + "grad_squared", + "surf", + "boundary_value", + "sign", + "upwind", + "downwind", + ], + ".expression_tree.averages": [ + "_BaseAverage", + "XAverage", + "YZAverage", + "ZAverage", + "RAverage", + "SizeAverage", + "x_average", + "yz_average", + "z_average", + "r_average", + "size_average", + ], + ".expression_tree.broadcasts": [ + "Broadcast", + "PrimaryBroadcast", + "PrimaryBroadcastToEdges", + "SecondaryBroadcast", + "SecondaryBroadcastToEdges", + "TertiaryBroadcast", + "TertiaryBroadcastToEdges", + "FullBroadcast", + "FullBroadcastToEdges", + "ones_like", + "zeros_like", + ], + ".expression_tree.functions": [ + "Function", + "SpecificFunction", + "Arcsinh", + "Arctan", + "Cos", + "Cosh", + "Erf", + "Exp", + "Log", + "Max", + "Min", + "Sin", + "Sinh", + "Sqrt", + "Tanh", + "arcsinh", + "arctan", + "cos", + "cosh", + "erf", + "exp", + "log", + "log10", + "max", + "min", + "sin", + "sinh", + "sqrt", + "tanh", + ], + ".expression_tree.interpolant": ["Interpolant"], + ".expression_tree.discrete_time_sum": ["DiscreteTimeData", "DiscreteTimeSum"], + ".expression_tree.variable": ["VariableBase", "Variable", "VariableDot"], + ".expression_tree.coupled_variable": ["CoupledVariable"], + ".expression_tree.independent_variable": [ + "IndependentVariable", + "Time", + "SpatialVariable", + "SpatialVariableEdge", + "t", + "KNOWN_COORD_SYS", + ], + ".expression_tree.exceptions": [ + "DomainError", + "OptionError", + "OptionWarning", + "GeometryError", + "ModelError", + "SolverError", + "SolverWarning", + "ShapeError", + "ModelWarning", + "DiscretisationError", + "InvalidModelJSONError", + ], + ".expression_tree.scalar": ["Scalar", "Constant"], + ".expression_tree.state_vector": [ + "StateVectorBase", + "StateVector", + "StateVectorDot", + ], + ".expression_tree.tensor_field": ["TensorField"], + ".expression_tree.parameter": ["Parameter", "FunctionParameter"], + ".expression_tree.input_parameter": ["InputParameter"], + ".expression_tree.array": ["Array", "linspace", "meshgrid"], + ".expression_tree.vector_field": ["VectorField"], + ".expression_tree.matrix": ["Matrix"], + ".expression_tree.vector": ["Vector"], +} + +# Lazily loaded attributes (loaded on first access via lazy_loader) +LAZY_IMPORTS: dict[str, list[str]] = { + ".util": [ + "root_dir", + "Timer", + "TimerTime", + "FuzzyDict", + "load", + "is_constant_and_can_evaluate", + "get_parameters_filepath", + "has_jax", + "get_jax", + "import_optional_dependency", + ], + ".expression_tree.operations.evaluate_python": [ + "find_symbols", + "id_to_python_variable", + "to_python", + "EvaluatorPython", + "EvaluatorJax", + "JaxCooMatrix", + ], + ".expression_tree.operations.jacobian": ["Jacobian"], + ".expression_tree.operations.convert_to_casadi": ["CasadiConverter"], + ".expression_tree.operations.unpack_symbols": ["SymbolUnpacker"], + ".expression_tree.operations.serialise": [ + "Serialise", + "ExpressionFunctionParameter", + ], + ".models.base_model": ["BaseModel", "ModelSolutionObservability", "load_model"], + ".models.symbol_processor": ["SymbolProcessor"], + ".models.event": ["Event", "EventType"], + ".models.full_battery_models.base_battery_model": [ + "BaseBatteryModel", + "BatteryModelOptions", + ], + ".models.submodels.base_submodel": ["BaseSubModel"], + ".geometry.geometry": ["Geometry"], + ".geometry.battery_geometry": ["battery_geometry"], + ".parameters.parameter_values": [ + "ParameterValues", + "scalarize_dict", + "arrayize_dict", + ], + ".parameters.geometric_parameters": ["geometric_parameters", "GeometricParameters"], + ".parameters.electrical_parameters": [ + "electrical_parameters", + "ElectricalParameters", + ], + ".parameters.thermal_parameters": ["thermal_parameters", "ThermalParameters"], + ".parameters.lithium_ion_parameters": ["LithiumIonParameters"], + ".parameters.lead_acid_parameters": ["LeadAcidParameters"], + ".parameters.ecm_parameters": ["EcmParameters"], + ".parameters.size_distribution_parameters": [ + "get_size_distribution_parameters", + "lognormal", + ], + ".discretisations.discretisation": ["Discretisation", "has_bc_of_form"], + ".meshes.meshes": ["Mesh", "SubMesh", "MeshGenerator"], + ".meshes.zero_dimensional_submesh": ["SubMesh0D"], + ".meshes.one_dimensional_submeshes": [ + "SubMesh1D", + "Uniform1DSubMesh", + "Exponential1DSubMesh", + "Chebyshev1DSubMesh", + "UserSupplied1DSubMesh", + "SpectralVolume1DSubMesh", + "SymbolicUniform1DSubMesh", + ], + ".meshes.two_dimensional_submeshes": ["SubMesh2D", "Uniform2DSubMesh"], + ".meshes.scikit_fem_submeshes": [ + "ScikitSubMesh2D", + "ScikitUniform2DSubMesh", + "ScikitExponential2DSubMesh", + "ScikitChebyshev2DSubMesh", + "UserSupplied2DSubMesh", + ], + ".meshes.scikit_fem_submeshes_3d": [ + "ScikitFemSubMesh3D", + "ScikitFemGenerator3D", + "UserSuppliedSubmesh3D", + ], + ".spatial_methods.spatial_method": ["SpatialMethod"], + ".spatial_methods.zero_dimensional_method": ["ZeroDimensionalSpatialMethod"], + ".spatial_methods.finite_volume": ["FiniteVolume"], + ".spatial_methods.finite_volume_2d": ["FiniteVolume2D"], + ".spatial_methods.spectral_volume": ["SpectralVolume"], + ".spatial_methods.scikit_finite_element": ["ScikitFiniteElement"], + ".spatial_methods.scikit_finite_element_3d": ["ScikitFiniteElement3D"], + ".solvers.solution": ["Solution", "EmptySolution", "make_cycle_solution"], + ".solvers.processed_variable_time_integral": ["ProcessedVariableTimeIntegral"], + ".solvers.processed_variable": [ + "ProcessedVariable", + "ProcessedVariable2DFVM", + "process_variable", + "ProcessedVariableUnstructured", + ], + ".solvers.processed_variable_computed": ["ProcessedVariableComputed"], + ".solvers.summary_variable": ["SummaryVariables"], + ".solvers.base_solver": ["BaseSolver"], + ".solvers.dummy_solver": ["DummySolver"], + ".solvers.algebraic_solver": ["AlgebraicSolver"], + ".solvers.casadi_solver": ["CasadiSolver"], + ".solvers.casadi_algebraic_solver": ["CasadiAlgebraicSolver"], + ".solvers.scipy_solver": ["ScipySolver"], + ".solvers.composite_solver": ["CompositeSolver"], + ".solvers.jax_solver": ["JaxSolver"], + ".solvers.jax_bdf_solver": ["jax_bdf_integrate"], + ".solvers.idaklu_jax": ["IDAKLUJax"], + ".solvers.idaklu_solver": ["IDAKLUSolver"], + ".experiment.experiment": ["Experiment"], + ".plotting.quick_plot": ["QuickPlot", "close_plots", "QuickPlotAxes"], + ".plotting.plot": ["plot"], + ".plotting.plot2D": ["plot2D"], + ".plotting.plot_voltage_components": ["plot_voltage_components"], + ".plotting.plot_thermal_components": ["plot_thermal_components"], + ".plotting.plot_summary_variables": ["plot_summary_variables"], + ".plotting.dynamic_plot": ["dynamic_plot"], + ".plotting.plot_3d_cross_section": ["plot_3d_cross_section"], + ".plotting.plot_3d_heatmap": ["plot_3d_heatmap"], + ".simulation": ["Simulation", "load_sim", "is_notebook"], + ".batch_study": ["BatchStudy"], + ".pybamm_data": ["DataLoader"], + ".dispatch": ["parameter_sets", "models", "Model"], +} + +# Submodule aliases - nested submodules exposed at top level +SUBMODULE_ALIASES: dict[str, str] = { + "lead_acid": ".models.full_battery_models.lead_acid", + "lithium_ion": ".models.full_battery_models.lithium_ion", + "equivalent_circuit": ".models.full_battery_models.equivalent_circuit", + "sodium_ion": ".models.full_battery_models.sodium_ion", + "active_material": ".models.submodels.active_material", + "convection": ".models.submodels.convection", + "current_collector": ".models.submodels.current_collector", + "electrolyte_conductivity": ".models.submodels.electrolyte_conductivity", + "electrolyte_diffusion": ".models.submodels.electrolyte_diffusion", + "electrode": ".models.submodels.electrode", + "external_circuit": ".models.submodels.external_circuit", + "interface": ".models.submodels.interface", + "oxygen_diffusion": ".models.submodels.oxygen_diffusion", + "particle": ".models.submodels.particle", + "porosity": ".models.submodels.porosity", + "thermal": ".models.submodels.thermal", + "transport_efficiency": ".models.submodels.transport_efficiency", + "particle_mechanics": ".models.submodels.particle_mechanics", + "equivalent_circuit_elements": ".models.submodels.equivalent_circuit_elements", + "kinetics": ".models.submodels.interface.kinetics", + "sei": ".models.submodels.interface.sei", + "lithium_plating": ".models.submodels.interface.lithium_plating", + "interface_utilisation": ".models.submodels.interface.interface_utilisation", + "open_circuit_potential": ".models.submodels.interface.open_circuit_potential", + "standard_spatial_vars": ".geometry.standard_spatial_vars", + "constants": ".parameters.constants", + "experiment": ".experiment", + "step": ".experiment.step", + "callbacks": ".callbacks", + "telemetry": ".telemetry", +} diff --git a/src/pybamm/citations.py b/src/pybamm/citations.py index 58a0ffede0..75e5e6d714 100644 --- a/src/pybamm/citations.py +++ b/src/pybamm/citations.py @@ -32,7 +32,9 @@ class Citations: def __init__(self): self._check_for_bibtex() # Dict mapping citations keys to BibTex entries - self._all_citations: dict[str, str] = dict() + self._all_citations: dict = dict() + # Cache for string representations (populated lazily) + self._citation_strings: dict[str, str] = dict() self.read_citations() self._reset() @@ -82,13 +84,25 @@ def _add_citation(self, key, entry): if not isinstance(key, str) or not isinstance(entry, Entry): raise TypeError() - # Warn if overwriting a previous citation - new_citation = entry.to_string("bibtex") - if key in self._all_citations and new_citation != self._all_citations[key]: - warnings.warn(f"Replacing citation for {key}", stacklevel=2) - - # Add to database - self._all_citations[key] = new_citation + # Store entry object -- defer to_string until citation is actually used + if key in self._all_citations: + # Only warn if actually different (compare lazily) + old_str = self._get_citation_string(key) + new_str = entry.to_string("bibtex") + if new_str != old_str: + warnings.warn(f"Replacing citation for {key}", stacklevel=2) + self._citation_strings[key] = new_str + + # Add entry object to database + self._all_citations[key] = entry + + def _get_citation_string(self, key): + """Get the BibTeX string for a citation, caching the result.""" + if key not in self._citation_strings: + entry = self._all_citations.get(key) + if entry is not None: + self._citation_strings[key] = entry.to_string("bibtex") + return self._citation_strings.get(key) def _add_citation_tag(self, key, entry): """Adds a tag for a citation key in the dict, which represents the name of the @@ -98,7 +112,7 @@ class that called :meth:`register`""" @property def _cited(self): """Return a list of the BibTeX entries that have been cited""" - return [self._all_citations[key] for key in self._papers_to_cite] + return [self._get_citation_string(key) for key in self._papers_to_cite] def register(self, key): """Register a paper to be cited, one at a time. The intended use is that diff --git a/src/pybamm/expression_tree/interpolant.py b/src/pybamm/expression_tree/interpolant.py index 9324c489e8..7ea8f01046 100644 --- a/src/pybamm/expression_tree/interpolant.py +++ b/src/pybamm/expression_tree/interpolant.py @@ -9,7 +9,6 @@ import numpy as np import numpy.typing as npt -from scipy import interpolate import pybamm @@ -125,6 +124,8 @@ def __init__( ) # Create interpolating function + from scipy import interpolate + if len(x) == 1: self.dimension = 1 if interpolator == "linear": diff --git a/src/pybamm/expression_tree/operations/evaluate_python.py b/src/pybamm/expression_tree/operations/evaluate_python.py index 897e0c3702..4634334554 100644 --- a/src/pybamm/expression_tree/operations/evaluate_python.py +++ b/src/pybamm/expression_tree/operations/evaluate_python.py @@ -13,11 +13,7 @@ import pybamm if pybamm.has_jax(): - import jax - - platform = jax.lib.xla_bridge.get_backend().platform.casefold() - if platform != "metal": - jax.config.update("jax_enable_x64", True) + jax = pybamm.get_jax() class JaxCooMatrix: diff --git a/src/pybamm/expression_tree/operations/serialise.py b/src/pybamm/expression_tree/operations/serialise.py index 269e881929..90172fedad 100644 --- a/src/pybamm/expression_tree/operations/serialise.py +++ b/src/pybamm/expression_tree/operations/serialise.py @@ -12,7 +12,6 @@ from enum import Enum from pathlib import Path -import black import numpy as np import pybamm @@ -58,6 +57,8 @@ def to_source(self): src += f" return {expression.to_equation()}" + import black + formatted_src = black.format_str(src, mode=black.FileMode()) return formatted_src diff --git a/src/pybamm/solvers/casadi_solver.py b/src/pybamm/solvers/casadi_solver.py index 12bf01dad8..dcd3ec2cd5 100644 --- a/src/pybamm/solvers/casadi_solver.py +++ b/src/pybamm/solvers/casadi_solver.py @@ -2,7 +2,6 @@ import casadi import numpy as np -from scipy.interpolate import interp1d import pybamm @@ -424,6 +423,8 @@ def integer_bisect(): t_event = np.nanmin(t_events) # create interpolant to evaluate y in the current integration # window + from scipy.interpolate import interp1d + y_sol = interp1d(sol.t, sol.y, kind="linear") y_event = y_sol(t_event) diff --git a/src/pybamm/solvers/idaklu_jax.py b/src/pybamm/solvers/idaklu_jax.py index 3a0aaee94d..afff2e2555 100644 --- a/src/pybamm/solvers/idaklu_jax.py +++ b/src/pybamm/solvers/idaklu_jax.py @@ -15,7 +15,7 @@ logger = logging.getLogger("pybamm.solvers.idaklu_jax") if pybamm.has_jax(): - import jax + jax = pybamm.get_jax() from jax import lax try: diff --git a/src/pybamm/solvers/jax_bdf_solver.py b/src/pybamm/solvers/jax_bdf_solver.py index 9263b5fbd4..aa363f0028 100644 --- a/src/pybamm/solvers/jax_bdf_solver.py +++ b/src/pybamm/solvers/jax_bdf_solver.py @@ -10,7 +10,7 @@ if pybamm.has_jax(): import functools - import jax + jax = pybamm.get_jax() import jax.numpy as jnp from jax import core, dtypes from jax.api_util import flatten_fun_nokwargs @@ -32,10 +32,6 @@ def split_list(lst, indices): result.append(lst[start:]) return result - platform = jax.lib.xla_bridge.get_backend().platform.casefold() - if platform != "metal": - jax.config.update("jax_enable_x64", True) - MAX_ORDER = 5 NEWTON_MAXITER = 4 ROOT_SOLVE_MAXITER = 15 diff --git a/src/pybamm/solvers/jax_solver.py b/src/pybamm/solvers/jax_solver.py index e1d3354027..d1a7c9ae25 100644 --- a/src/pybamm/solvers/jax_solver.py +++ b/src/pybamm/solvers/jax_solver.py @@ -5,7 +5,7 @@ import pybamm if pybamm.has_jax(): - import jax + jax = pybamm.get_jax() import jax.numpy as jnp from jax.experimental.ode import odeint diff --git a/src/pybamm/solvers/processed_variable.py b/src/pybamm/solvers/processed_variable.py index 6c83857ec9..50e37c4241 100644 --- a/src/pybamm/solvers/processed_variable.py +++ b/src/pybamm/solvers/processed_variable.py @@ -2,8 +2,6 @@ import casadi import numpy as np -import xarray as xr -from pybammsolvers import idaklu import pybamm @@ -122,6 +120,8 @@ def observe_raw(self): return self._observe_postfix(self._observe_raw(), t) def _setup_inputs(self, t, full_range): + import pybammsolvers.idaklu as idaklu + pybamm.logger.debug("Setting up C++ interpolation inputs") ts = self.all_ts @@ -167,6 +167,8 @@ def _setup_inputs(self, t, full_range): return ts, ys, yps, funcs, inputs, is_f_contiguous def _observe_hermite(self, t): + import pybammsolvers.idaklu as idaklu + pybamm.logger.debug("Observing and Hermite interpolating the variable") ts, ys, yps, funcs, inputs, _ = self._setup_inputs(t, full_range=False) @@ -174,6 +176,8 @@ def _observe_hermite(self, t): return idaklu.observe_hermite_interp(t, ts, ys, yps, inputs, funcs, shapes) def _observe_raw(self): + import pybammsolvers.idaklu as idaklu + pybamm.logger.debug("Observing the variable raw data") t = self.t_pts ts, ys, _, funcs, inputs, is_f_contiguous = self._setup_inputs( @@ -310,6 +314,8 @@ def _xr_interpolate( Evaluate the variable at arbitrary *dimensional* t (and x, r, y, z and/or R), using interpolation """ + import xarray as xr + if observe_raw: if not self.xr_array_raw_initialized: self._xr_array_raw = xr.DataArray(entries_for_interp, coords=coords) diff --git a/src/pybamm/solvers/processed_variable_computed.py b/src/pybamm/solvers/processed_variable_computed.py index b3a47b4340..9d6e8d0edb 100644 --- a/src/pybamm/solvers/processed_variable_computed.py +++ b/src/pybamm/solvers/processed_variable_computed.py @@ -5,8 +5,6 @@ import casadi import numpy as np -import xarray as xr -from scipy.integrate import cumulative_trapezoid import pybamm @@ -245,9 +243,13 @@ def initialise_time_independent(self): self.dimensions = 0 def initialise_0D(self): + import xarray as xr + entries = self.unroll_0D() if self.cumtrapz_ic is not None: + from scipy.integrate import cumulative_trapezoid + entries = cumulative_trapezoid( entries, self.t_pts, initial=float(self.cumtrapz_ic) ) @@ -259,6 +261,8 @@ def initialise_0D(self): self.dimensions = 0 def initialise_1D(self): + import xarray as xr + entries = self.unroll_1D() # Get node and edge values @@ -319,6 +323,8 @@ def initialise_2D(self): """ Initialise a 2D object that depends on x and r, x and z, x and R, or R and r. """ + import xarray as xr + first_dim_nodes = self.mesh.nodes first_dim_edges = self.mesh.edges second_dim_nodes = self.base_variables[0].secondary_mesh.nodes @@ -450,6 +456,8 @@ def initialise_2D(self): ) def initialise_2D_scikit_fem(self): + import xarray as xr + y_sol = self.mesh.edges["y"] len_y = len(y_sol) z_sol = self.mesh.edges["z"] @@ -481,6 +489,8 @@ def initialise_3D(self): """ Initialise a 3D object that depends on x, r, and R. """ + import xarray as xr + first_dim_nodes = self.mesh.nodes first_dim_edges = self.mesh.edges second_dim_nodes = self.base_variables[0].secondary_mesh.nodes @@ -629,6 +639,8 @@ def initialise_3D(self): ) def initialise_3D_scikit_fem(self): + import xarray as xr + x_nodes = self.mesh.nodes x_edges = self.mesh.edges y_sol = self.base_variables[0].secondary_mesh.edges["y"] diff --git a/src/pybamm/solvers/solution.py b/src/pybamm/solvers/solution.py index 4c92a4216c..de0458534b 100644 --- a/src/pybamm/solvers/solution.py +++ b/src/pybamm/solvers/solution.py @@ -11,8 +11,6 @@ import casadi import numpy as np -import pandas as pd -from scipy.io import savemat import pybamm @@ -778,6 +776,8 @@ def save_data( "['Electrolyte concentration'], to_format='matlab, " "short_names={'Electrolyte concentration': 'c_e'})" ) + from scipy.io import savemat + savemat(filename, data) elif to_format == "csv": for name, var in data.items(): @@ -785,6 +785,8 @@ def save_data( raise ValueError( f"only 0D variables can be saved to csv, but '{name}' is {var.ndim - 1}D" ) + import pandas as pd + df = pd.DataFrame(data) return df.to_csv(filename, index=False) elif to_format == "json": diff --git a/src/pybamm/telemetry.py b/src/pybamm/telemetry.py index 71957e9bd2..09642a38da 100644 --- a/src/pybamm/telemetry.py +++ b/src/pybamm/telemetry.py @@ -1,42 +1,56 @@ import sys -from posthog import Posthog - import pybamm +# Lazily initialized posthog client +_posthog = None +_disabled = False + -class MockTelemetry: - def __init__(self): - self.disabled = True +def _get_posthog(): + """Lazily initialize the posthog client on first use.""" + global _posthog, _disabled - @staticmethod - def capture(**kwargs): # pragma: no cover - pass + if _posthog is not None: + return _posthog + + if pybamm.config.check_opt_out(): + _disabled = True + return None + # Import posthog only when needed (this pulls in requests, urllib3, etc.) + from posthog import Posthog -if pybamm.config.check_opt_out(): - _posthog = MockTelemetry() -else: # pragma: no cover _posthog = Posthog( # this is the public, write only API key, so it's ok to include it here project_api_key="phc_acTt7KxmvBsAxaE0NyRd5WfJyNxGvBq1U9HnlQSztmb", host="https://us.i.posthog.com", ) _posthog.log.setLevel("CRITICAL") + return _posthog def disable(): - _posthog.disabled = True + global _disabled + _disabled = True + if _posthog is not None: + _posthog.disabled = True def capture(event): # pragma: no cover - if pybamm.config.is_running_tests() or _posthog.disabled: + global _disabled + + if pybamm.config.is_running_tests() or _disabled: return if pybamm.config.check_opt_out(): disable() return + posthog = _get_posthog() + if posthog is None: + return + config = pybamm.config.read() if config: properties = { @@ -44,4 +58,4 @@ def capture(event): # pragma: no cover "pybamm_version": pybamm.__version__, } user_id = config["uuid"] - _posthog.capture(distinct_id=user_id, event=event, properties=properties) + posthog.capture(distinct_id=user_id, event=event, properties=properties) diff --git a/src/pybamm/util.py b/src/pybamm/util.py index 6149d06139..0b2ef3161c 100644 --- a/src/pybamm/util.py +++ b/src/pybamm/util.py @@ -362,6 +362,58 @@ def has_jax(): ) +# Track whether JAX has been configured +_jax_configured = False + + +def get_jax(): + """ + Import and configure JAX for PyBaMM use. + + This function lazily imports JAX and configures it for float64 precision + (except on Metal backend where float64 is not supported). The configuration + is only applied once, on first call. + + Returns + ------- + module or None + The configured jax module if available, None otherwise + + Raises + ------ + ModuleNotFoundError + If JAX is not installed (use has_jax() to check first) + + Examples + -------- + >>> if pybamm.has_jax(): + ... jax = pybamm.get_jax() + ... # JAX is now configured and ready to use + """ + global _jax_configured + + if not has_jax(): + raise ModuleNotFoundError( + "JAX is not installed. See " + "https://docs.pybamm.org/en/latest/source/user_guide/installation/gnu-linux-mac.html#optional-jaxsolver" + ) + + import jax + + if not _jax_configured: + try: + platform = jax.lib.xla_bridge.get_backend().platform.casefold() + if platform != "metal": + jax.config.update("jax_enable_x64", True) + except Exception: + # If we can't get the backend + # try to enable x64 + jax.config.update("jax_enable_x64", True) + _jax_configured = True + + return jax + + def is_constant_and_can_evaluate(symbol): """ Returns True if symbol is constant and evaluation does not raise any errors. diff --git a/tests/unit/test_lazy_imports.py b/tests/unit/test_lazy_imports.py new file mode 100644 index 0000000000..6e78c6a169 --- /dev/null +++ b/tests/unit/test_lazy_imports.py @@ -0,0 +1,300 @@ +""" +Tests for the lazy import mechanism in PyBaMM using lazy_loader. +""" + +import types + +import pytest + + +class TestLazyImports: + """Tests for lazy import functionality.""" + + def test_lazy_imports_resolve(self): + """Test that lazy imports can be resolved via getattr.""" + import pybamm + + # Test a selection of lazy imports + lazy_imports = [ + "CasadiSolver", + "Simulation", + "ParameterValues", + "Mesh", + "Solution", + "Timer", + "FuzzyDict", + "root_dir", + "lithium_ion", + "lead_acid", + "experiment", + ] + + failed_imports = [] + for name in lazy_imports: + try: + attr = getattr(pybamm, name) + assert attr is not None, f"{name} resolved to None" + except (ImportError, AttributeError) as e: + failed_imports.append((name, str(e))) + + if failed_imports: + msg = "\n".join(f" {name}: {error}" for name, error in failed_imports) + pytest.fail(f"Failed to resolve lazy imports:\n{msg}") + + def test_caching_works(self): + """Test that accessing an attribute twice returns the same object.""" + import pybamm + + # Access a lazy attribute twice + first_access = pybamm.CasadiSolver + second_access = pybamm.CasadiSolver + + # Should be the exact same object (cached) + assert first_access is second_access + + def test_dir_includes_lazy_imports(self): + """Test that dir(pybamm) includes lazy import names.""" + import pybamm + + pybamm_dir = dir(pybamm) + + # Check some known lazy imports are in dir() + expected_in_dir = [ + "CasadiSolver", + "Simulation", + "ParameterValues", + "Mesh", + "Solution", + "lithium_ion", + "lead_acid", + ] + + missing = [] + for name in expected_in_dir: + if name not in pybamm_dir: + missing.append(name) + + if missing: + pytest.fail(f"Missing from dir(pybamm): {missing}") + + def test_undefined_attribute_raises_error(self): + """Test that accessing undefined attributes raises AttributeError.""" + import pybamm + + with pytest.raises( + AttributeError, match="this_attribute_definitely_does_not_exist_xyz123" + ): + _ = pybamm.this_attribute_definitely_does_not_exist_xyz123 + + def test_lazy_module_imports(self): + """Test that lazy module imports return ModuleType.""" + import pybamm + + module_imports = [ + "lithium_ion", + "lead_acid", + "experiment", + "callbacks", + "telemetry", + "constants", + ] + + for name in module_imports: + attr = getattr(pybamm, name) + assert isinstance(attr, types.ModuleType), ( + f"{name} should be a module, got {type(attr)}" + ) + + def test_lazy_class_imports(self): + """Test that lazy class imports return type (class objects).""" + import pybamm + + # Test some known class imports + class_names = [ + "CasadiSolver", + "Simulation", + "ParameterValues", + "Mesh", + "Solution", + ] + + for name in class_names: + attr = getattr(pybamm, name) + assert isinstance(attr, type), f"{name} should be a class, got {type(attr)}" + + def test_known_coord_sys_accessible(self): + """Test that KNOWN_COORD_SYS is accessible (from eager import).""" + import pybamm + + # Should be accessible without error + coord_sys = pybamm.KNOWN_COORD_SYS + assert coord_sys is not None + assert isinstance(coord_sys, set | frozenset | list | tuple) + + def test_t_accessible(self): + """Test that pybamm.t is accessible (from eager import).""" + import pybamm + + # Should be accessible without error + t = pybamm.t + assert t is not None + + def test_size_distribution_parameters_accessible(self): + """Test that size distribution parameters (from wildcard module) are accessible.""" + import pybamm + + # These come from the size_distribution_parameters module via the stub + assert hasattr(pybamm, "get_size_distribution_parameters") + assert hasattr(pybamm, "lognormal") + + get_size_dist = pybamm.get_size_distribution_parameters + lognormal = pybamm.lognormal + + assert callable(get_size_dist) + assert callable(lognormal) + + +class TestThreadSafety: + """Tests for thread-safe lazy loading.""" + + def test_concurrent_access(self): + """Test that concurrent access to lazy imports is thread-safe.""" + import threading + + import pybamm + + errors = [] + results = [] + lock = threading.Lock() + + def access_attr(): + try: + # Access various lazy attributes + solver = pybamm.CasadiSolver + sim = pybamm.Simulation + param = pybamm.ParameterValues + + with lock: + results.append((solver, sim, param)) + except Exception as e: + with lock: + errors.append(e) + + # Create multiple threads + threads = [threading.Thread(target=access_attr) for _ in range(50)] + + # Start all threads + for t in threads: + t.start() + + # Wait for all threads to complete + for t in threads: + t.join() + + # Check no errors occurred + assert not errors, f"Errors during concurrent access: {errors}" + + # Check all threads got the same objects (due to caching) + if results: + first_result = results[0] + for result in results[1:]: + assert result[0] is first_result[0], "CasadiSolver not cached properly" + assert result[1] is first_result[1], "Simulation not cached properly" + assert result[2] is first_result[2], ( + "ParameterValues not cached properly" + ) + + +class TestJaxConfiguration: + """Tests for JAX lazy configuration.""" + + def test_get_jax_available(self): + """Test that get_jax is accessible.""" + import pybamm + + assert hasattr(pybamm, "get_jax") + assert callable(pybamm.get_jax) + + def test_get_jax_returns_none_when_unavailable(self): + """Test that get_jax raises ModuleNotFoundError when JAX is not installed.""" + import pybamm + + if not pybamm.has_jax(): + with pytest.raises(ModuleNotFoundError): + pybamm.get_jax() + + @pytest.mark.skipif(not __import__("pybamm").has_jax(), reason="JAX not installed") + def test_get_jax_returns_jax_module(self): + """Test that get_jax returns the jax module when available.""" + import pybamm + + jax = pybamm.get_jax() + assert jax is not None + assert hasattr(jax, "numpy") + assert hasattr(jax, "config") + + @pytest.mark.skipif(not __import__("pybamm").has_jax(), reason="JAX not installed") + def test_get_jax_configures_x64(self): + """Test that get_jax configures JAX for float64 precision.""" + import pybamm + + jax = pybamm.get_jax() + # After calling get_jax, x64 should be enabled (unless on Metal) + platform = jax.lib.xla_bridge.get_backend().platform.casefold() + if platform != "metal": + assert jax.config.x64_enabled, "JAX x64 should be enabled after get_jax()" + + @pytest.mark.skipif(not __import__("pybamm").has_jax(), reason="JAX not installed") + def test_get_jax_caches_configuration(self): + """Test that get_jax only configures JAX once.""" + import pybamm + from pybamm import util + + # Reset the configuration flag for testing + original_configured = util._jax_configured + util._jax_configured = False + + try: + # First call should configure + jax1 = pybamm.get_jax() + assert util._jax_configured is True + + # Second call should return same module without reconfiguring + jax2 = pybamm.get_jax() + assert jax1 is jax2 + finally: + # Restore original state + util._jax_configured = original_configured + + +class TestStubFileIntegrity: + """Tests to verify the stub file is properly configured.""" + + def test_all_is_list(self): + """Test that __all__ is a list.""" + import pybamm + + assert isinstance(pybamm.__all__, list) + + def test_all_contains_expected_items(self): + """Test that __all__ contains expected items.""" + import pybamm + + expected_items = [ + "CasadiSolver", + "Simulation", + "ParameterValues", + "Solution", + "Experiment", + "lithium_ion", + ] + + for item in expected_items: + assert item in pybamm.__all__, f"{item} not in __all__" + + def test_version_accessible(self): + """Test that __version__ is accessible.""" + import pybamm + + assert hasattr(pybamm, "__version__") + assert isinstance(pybamm.__version__, str)