diff --git a/examples/scripts/SPM_compare_particle_grid.py b/examples/scripts/SPM_compare_particle_grid.py index 44a6f84edb..0fa7cdb6c1 100644 --- a/examples/scripts/SPM_compare_particle_grid.py +++ b/examples/scripts/SPM_compare_particle_grid.py @@ -2,6 +2,7 @@ # Compare different discretisations in the particle # import argparse +from typing import Any import numpy as np import pybamm import matplotlib.pyplot as plt @@ -48,7 +49,7 @@ disc.process_model(model) # solve model -solutions = [None] * len(models) +solutions: list[Any] = [None] * len(models) t_eval = np.linspace(0, 3600, 100) for i, model in enumerate(models): solutions[i] = model.default_solver.solve(model, t_eval) diff --git a/examples/scripts/SPMe_step.py b/examples/scripts/SPMe_step.py index f277c0e790..56ff16d2b2 100644 --- a/examples/scripts/SPMe_step.py +++ b/examples/scripts/SPMe_step.py @@ -43,13 +43,14 @@ time += dt # plot -time_in_seconds = solution["Time [s]"].entries -step_time_in_seconds = step_solution["Time [s]"].entries -voltage = solution["Voltage [V]"].entries -step_voltage = step_solution["Voltage [V]"].entries -plt.plot(time_in_seconds, voltage, "b-", label="SPMe (continuous solve)") -plt.plot(step_time_in_seconds, step_voltage, "ro", label="SPMe (stepped solve)") -plt.xlabel(r"$t$") -plt.ylabel("Voltage [V]") -plt.legend() -plt.show() +if step_solution is not None: + time_in_seconds = solution["Time [s]"].entries + step_time_in_seconds = step_solution["Time [s]"].entries + voltage = solution["Voltage [V]"].entries + step_voltage = step_solution["Voltage [V]"].entries + plt.plot(time_in_seconds, voltage, "b-", label="SPMe (continuous solve)") + plt.plot(step_time_in_seconds, step_voltage, "ro", label="SPMe (stepped solve)") + plt.xlabel(r"$t$") + plt.ylabel("Voltage [V]") + plt.legend() + plt.show() diff --git a/examples/scripts/heat_equation.py b/examples/scripts/heat_equation.py index fd01b37f97..4c2ac99ca4 100644 --- a/examples/scripts/heat_equation.py +++ b/examples/scripts/heat_equation.py @@ -5,6 +5,7 @@ import pybamm import numpy as np import matplotlib.pyplot as plt +import numpy.typing as npt # Numerical solution ---------------------------------------------------------- @@ -106,7 +107,7 @@ def T_exact(x, t): # Plot ------------------------------------------------------------------------ x_nodes = mesh["rod"].nodes # numerical gridpoints xx = np.linspace(0, 2, 101) # fine mesh to plot exact solution -plot_times = np.linspace(0, 1, 5) +plot_times: npt.NDArray = np.linspace(0, 1, 5) plt.figure(figsize=(15, 8)) cmap = plt.get_cmap("inferno") diff --git a/examples/scripts/minimal_example_of_lookup_tables.py b/examples/scripts/minimal_example_of_lookup_tables.py index 335e9961ac..8ceda74b23 100644 --- a/examples/scripts/minimal_example_of_lookup_tables.py +++ b/examples/scripts/minimal_example_of_lookup_tables.py @@ -34,12 +34,12 @@ def process_2D(name, data): D_s_n_data = process_2D("Negative particle diffusivity [m2.s-1]", df) -def D_s_n(sto, T): +def D_s_n_func(sto, T): name, (x, y) = D_s_n_data return pybamm.Interpolant(x, y, [T, sto], name) -parameter_values["Negative particle diffusivity [m2.s-1]"] = D_s_n +parameter_values["Negative particle diffusivity [m2.s-1]"] = D_s_n_func k_n = parameter_values["Negative electrode exchange-current density [A.m-2]"] diff --git a/pyproject.toml b/pyproject.toml index 4d19bd304d..9a78e168f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -263,6 +263,10 @@ concurrency = ["multiprocessing"] ignore_missing_imports = true allow_redefinition = true disable_error_code = ["call-overload", "operator"] +strict = false +warn_unreachable = true +enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"] +exclude = "^(build/|docs/conf\\.py)$" [[tool.mypy.overrides]] module = [ diff --git a/src/pybamm/experiment/experiment.py b/src/pybamm/experiment/experiment.py index ce44457cb2..b9cd4b4d9d 100644 --- a/src/pybamm/experiment/experiment.py +++ b/src/pybamm/experiment/experiment.py @@ -40,7 +40,9 @@ class Experiment: def __init__( self, - operating_conditions: list[str | tuple[str] | BaseStep], + operating_conditions: list[ + str | tuple[str, ...] | tuple[str | BaseStep] | BaseStep + ], period: str | None = None, temperature: float | None = None, termination: list[str] | None = None, diff --git a/src/pybamm/experiment/step/base_step.py b/src/pybamm/experiment/step/base_step.py index 0895cdfaa6..46a367f4f5 100644 --- a/src/pybamm/experiment/step/base_step.py +++ b/src/pybamm/experiment/step/base_step.py @@ -140,7 +140,7 @@ def __init__( self.value = pybamm.Interpolant( t, y, - pybamm.t - pybamm.InputParameter("start time"), + pybamm.t - pybamm.InputParameter("start time"), # type: ignore[arg-type] name="Drive Cycle", ) self.period = np.diff(t).min() diff --git a/src/pybamm/expression_tree/binary_operators.py b/src/pybamm/expression_tree/binary_operators.py index efd9874664..4dd82a2c71 100644 --- a/src/pybamm/expression_tree/binary_operators.py +++ b/src/pybamm/expression_tree/binary_operators.py @@ -2,7 +2,6 @@ # Binary operator classes # from __future__ import annotations -import numbers import numpy as np import numpy.typing as npt @@ -35,7 +34,7 @@ def _preprocess_binary( right = pybamm.Vector(right) # Check both left and right are pybamm Symbols - if not (isinstance(left, pybamm.Symbol) and isinstance(right, pybamm.Symbol)): + if not (isinstance(left, pybamm.Symbol) and isinstance(right, pybamm.Symbol)): # type: ignore[redundant-expr] raise NotImplementedError( f"BinaryOperator not implemented for symbols of type {type(left)} and {type(right)}" ) @@ -114,6 +113,9 @@ def __str__(self): right_str = f"{self.right!s}" return f"{left_str} {self.name} {right_str}" + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return self.__class__(self.name, left, right) # pragma: no cover + def create_copy( self, new_children: list[pybamm.Symbol] | None = None, @@ -128,7 +130,7 @@ def create_copy( children = self._children_for_copying(new_children) if not perform_simplifications: - out = self.__class__(children[0], children[1]) + out = self._new_instance(children[0], children[1]) else: # creates a new instance using the overloaded binary operator to perform # additional simplifications, rather than just calling the constructor @@ -225,6 +227,9 @@ def __init__( """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("**", left, right) + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return Power(left, right) + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" # apply chain rule and power rule @@ -274,6 +279,9 @@ def __init__( """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("+", left, right) + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return Addition(left, right) + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" return self.left.diff(variable) + self.right.diff(variable) @@ -301,6 +309,9 @@ def __init__( super().__init__("-", left, right) + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return Subtraction(left, right) + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" return self.left.diff(variable) - self.right.diff(variable) @@ -330,6 +341,9 @@ def __init__( super().__init__("*", left, right) + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return Multiplication(left, right) + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" # apply product rule @@ -370,6 +384,9 @@ def __init__( """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("@", left, right) + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return MatrixMultiplication(left, right) # pragma: no cover + def diff(self, variable): """See :meth:`pybamm.Symbol.diff()`.""" # We shouldn't need this @@ -419,6 +436,9 @@ def __init__( """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("/", left, right) + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return Division(left, right) + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" # apply quotient rule @@ -467,6 +487,9 @@ def __init__( """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("inner product", left, right) + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return Inner(left, right) # pragma: no cover + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" # apply product rule @@ -544,6 +567,9 @@ def __init__( """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__("==", left, right) + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return Equality(left, right) + def diff(self, variable): """See :meth:`pybamm.Symbol.diff()`.""" # Equality should always be multiplied by something else so hopefully don't @@ -601,6 +627,10 @@ def __init__( ): """See :meth:`pybamm.BinaryOperator.__init__()`.""" super().__init__(name, left, right) + self.name = name + + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return _Heaviside(self.name, left, right) # pragma: no cover def diff(self, variable): """See :meth:`pybamm.Symbol.diff()`.""" @@ -679,6 +709,9 @@ def __init__( ): super().__init__("%", left, right) + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return Modulo(left, right) + def _diff(self, variable: pybamm.Symbol): """See :meth:`pybamm.Symbol._diff()`.""" # apply chain rule and power rule @@ -721,6 +754,9 @@ def __init__( ): super().__init__("minimum", left, right) + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return Minimum(left, right) + def __str__(self): """See :meth:`pybamm.Symbol.__str__()`.""" return f"minimum({self.left!s}, {self.right!s})" @@ -765,6 +801,9 @@ def __init__( ): super().__init__("maximum", left, right) + def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol: + return Maximum(left, right) + def __str__(self): """See :meth:`pybamm.Symbol.__str__()`.""" return f"maximum({self.left!s}, {self.right!s})" @@ -1539,7 +1578,7 @@ def source( corresponding to a source term in the bulk. """ # Broadcast if left is number - if isinstance(left, numbers.Number): + if isinstance(left, (int, float)): left = pybamm.PrimaryBroadcast(left, "current collector") # force type cast for mypy diff --git a/src/pybamm/expression_tree/broadcasts.py b/src/pybamm/expression_tree/broadcasts.py index 6045c3f3e8..1fabef127c 100644 --- a/src/pybamm/expression_tree/broadcasts.py +++ b/src/pybamm/expression_tree/broadcasts.py @@ -78,8 +78,7 @@ def _from_json(cls, snippet): ) def _unary_new_copy(self, child: pybamm.Symbol, perform_simplifications=True): - """See :meth:`pybamm.UnaryOperator._unary_new_copy()`.""" - return self.__class__(child, self.broadcast_domain) + pass # pragma: no cover class PrimaryBroadcast(Broadcast): @@ -191,6 +190,10 @@ def reduce_one_dimension(self): """Reduce the broadcast by one dimension.""" return self.orphans[0] + def _unary_new_copy(self, child: pybamm.Symbol, perform_simplifications=True): + """See :meth:`pybamm.UnaryOperator._unary_new_copy()`.""" + return self.__class__(child, self.broadcast_domain) + class PrimaryBroadcastToEdges(PrimaryBroadcast): """A primary broadcast onto the edges of the domain.""" @@ -321,6 +324,10 @@ def reduce_one_dimension(self): """Reduce the broadcast by one dimension.""" return self.orphans[0] + def _unary_new_copy(self, child: pybamm.Symbol, perform_simplifications=True): + """See :meth:`pybamm.UnaryOperator._unary_new_copy()`.""" + return self.__class__(child, self.broadcast_domain) + class SecondaryBroadcastToEdges(SecondaryBroadcast): """A secondary broadcast onto the edges of a domain.""" @@ -438,6 +445,10 @@ def reduce_one_dimension(self): """Reduce the broadcast by one dimension.""" raise NotImplementedError + def _unary_new_copy(self, child: pybamm.Symbol, perform_simplifications=True): + """See :meth:`pybamm.UnaryOperator._unary_new_copy()`.""" + return self.__class__(child, self.broadcast_domain) + class TertiaryBroadcastToEdges(TertiaryBroadcast): """A tertiary broadcast onto the edges of a domain.""" @@ -463,7 +474,7 @@ def __init__( self, child_input: Numeric | pybamm.Symbol, broadcast_domain: DomainType = None, - auxiliary_domains: AuxiliaryDomainType = None, + auxiliary_domains: AuxiliaryDomainType | str = None, broadcast_domains: DomainsType = None, name: str | None = None, ): diff --git a/src/pybamm/expression_tree/concatenations.py b/src/pybamm/expression_tree/concatenations.py index dc30cf4b5e..254065cccc 100644 --- a/src/pybamm/expression_tree/concatenations.py +++ b/src/pybamm/expression_tree/concatenations.py @@ -474,7 +474,7 @@ def __init__(self, *children, name: Optional[str] = None): if name is None: # Name is the intersection of the children names (should usually make sense # if the children have been named consistently) - name = intersect(children[0].name, children[1].name) + name = intersect(children[0].name, children[1].name) or "" for child in children[2:]: name = intersect(name, child.name) if len(name) == 0: @@ -515,7 +515,7 @@ def substrings(s: str): yield s[i : j + 1] -def intersect(s1: str, s2: str): +def intersect(s1: str, s2: str) -> str: # find all the common strings between two strings all_intersects = set(substrings(s1)) & set(substrings(s2)) # intersect is the longest such intercept @@ -526,7 +526,7 @@ def intersect(s1: str, s2: str): return intersect.lstrip().rstrip() -def simplified_concatenation(*children, name: Optional[str] = None): +def simplified_concatenation(*children, name=None): """Perform simplifications on a concatenation.""" # remove children that are None children = list(filter(lambda x: x is not None, children)) diff --git a/src/pybamm/expression_tree/functions.py b/src/pybamm/expression_tree/functions.py index a5a999a092..3dbfc0f422 100644 --- a/src/pybamm/expression_tree/functions.py +++ b/src/pybamm/expression_tree/functions.py @@ -7,7 +7,7 @@ import numpy.typing as npt from scipy import special import sympy -from typing import Callable +from typing import Callable, cast from collections.abc import Sequence from typing_extensions import TypeVar @@ -33,7 +33,7 @@ class Function(pybamm.Symbol): def __init__( self, function: Callable, - *children: pybamm.Symbol, + *children: pybamm.Symbol | float | int, name: str | None = None, differentiated_function: Callable | None = None, ): @@ -43,6 +43,7 @@ def __init__( if isinstance(child, (float, int, np.number)): children[idx] = pybamm.Scalar(child) + children = cast(Sequence[pybamm.Symbol], children) if name is not None: self.name = name else: diff --git a/src/pybamm/expression_tree/operations/serialise.py b/src/pybamm/expression_tree/operations/serialise.py index 0507b3304e..153a9f52f8 100644 --- a/src/pybamm/expression_tree/operations/serialise.py +++ b/src/pybamm/expression_tree/operations/serialise.py @@ -1,4 +1,5 @@ from __future__ import annotations +from typing import Any import pybamm from datetime import datetime @@ -20,8 +21,11 @@ def __init__(self): class _SymbolEncoder(json.JSONEncoder): """Converts PyBaMM symbols into a JSON-serialisable format""" - def default(self, node: dict): - node_dict = {"py/object": str(type(node))[8:-2], "py/id": id(node)} + def default(self, node: dict | pybamm.Symbol): + node_dict: dict[str, Any] = { + "py/object": str(type(node))[8:-2], + "py/id": id(node), + } if isinstance(node, pybamm.Symbol): node_dict.update(node.to_json()) # this doesn't include children node_dict["children"] = [] @@ -46,7 +50,7 @@ def default(self, node: dict): class _MeshEncoder(json.JSONEncoder): """Converts PyBaMM meshes into a JSON-serialisable format""" - def default(self, node: pybamm.Mesh): + def default(self, node: pybamm.Mesh | pybamm.SubMesh): node_dict = {"py/object": str(type(node))[8:-2], "py/id": id(node)} if isinstance(node, pybamm.Mesh): node_dict.update(node.to_json()) @@ -61,12 +65,9 @@ def default(self, node: pybamm.Mesh): return node_dict if isinstance(node, pybamm.SubMesh): - node_dict.update(node.to_json()) + node_dict.update(node.to_json()) # type: ignore[attr-defined] return node_dict - node_dict["json"] = json.JSONEncoder.default(self, node) # pragma: no cover - return node_dict # pragma: no cover - class _Empty: """A dummy class to aid deserialisation""" diff --git a/src/pybamm/expression_tree/parameter.py b/src/pybamm/expression_tree/parameter.py index 14560da0b8..176a775443 100644 --- a/src/pybamm/expression_tree/parameter.py +++ b/src/pybamm/expression_tree/parameter.py @@ -5,7 +5,8 @@ import sys import numpy as np -from typing import Literal +from typing import Literal, cast +from collections.abc import Sequence import sympy @@ -97,7 +98,7 @@ class FunctionParameter(pybamm.Symbol): def __init__( self, name: str, - inputs: dict[str, pybamm.Symbol], + inputs: dict[str, pybamm.Symbol | float | int], diff_variable: pybamm.Symbol | None = None, print_name="calculate", ) -> None: @@ -110,6 +111,7 @@ def __init__( if isinstance(child, (float, int, np.number)): children_list[idx] = pybamm.Scalar(child) + children_list = cast(Sequence[pybamm.Symbol], children_list) domains = self.get_children_domains(children_list) super().__init__(name, children=children_list, domains=domains) diff --git a/src/pybamm/expression_tree/symbol.py b/src/pybamm/expression_tree/symbol.py index 34ca9d627b..6e87b18ac2 100644 --- a/src/pybamm/expression_tree/symbol.py +++ b/src/pybamm/expression_tree/symbol.py @@ -67,7 +67,9 @@ def create_object_of_size(size: int, typ="vector"): return np.nan * np.ones((size, size)) -def evaluate_for_shape_using_domain(domains: dict[str, list[str] | str], typ="vector"): +def evaluate_for_shape_using_domain( + domains: dict[str, list[str] | str] | list[str], typ="vector" +): """ Return a vector of the appropriate shape, based on the domains. Domain 'sizes' can clash, but are unlikely to, and won't cause failures if they do. @@ -965,7 +967,9 @@ def to_casadi( """ return pybamm.CasadiConverter(casadi_symbols).convert(self, t, y, y_dot, inputs) - def _children_for_copying(self, children: list[Symbol] | None = None) -> Symbol: + def _children_for_copying( + self, children: list[Symbol] | None = None + ) -> list[Symbol]: """ Gets existing children for a symbol being copied if they aren't provided. """ diff --git a/src/pybamm/expression_tree/unary_operators.py b/src/pybamm/expression_tree/unary_operators.py index 2f998c47d6..0fbc2ccfe5 100644 --- a/src/pybamm/expression_tree/unary_operators.py +++ b/src/pybamm/expression_tree/unary_operators.py @@ -9,7 +9,7 @@ import sympy import pybamm from pybamm.util import import_optional_dependency -from pybamm.type_definitions import DomainsType +from pybamm.type_definitions import DomainsType, Numeric class UnaryOperator(pybamm.Symbol): @@ -32,7 +32,7 @@ class UnaryOperator(pybamm.Symbol): def __init__( self, name: str, - child: pybamm.Symbol, + child: pybamm.Symbol | Numeric, domains: DomainsType = None, ): if isinstance(child, (float, int, np.number)): diff --git a/src/pybamm/expression_tree/variable.py b/src/pybamm/expression_tree/variable.py index 4d08686245..062f10b6df 100644 --- a/src/pybamm/expression_tree/variable.py +++ b/src/pybamm/expression_tree/variable.py @@ -61,12 +61,12 @@ def __init__( domains: DomainsType = None, bounds: tuple[pybamm.Symbol] | None = None, print_name: str | None = None, - scale: float | pybamm.Symbol | None = 1, - reference: float | pybamm.Symbol | None = 0, + scale: float | int | pybamm.Symbol | None = 1, + reference: float | int | pybamm.Symbol | None = 0, ): - if isinstance(scale, numbers.Number): + if isinstance(scale, (float, int)): scale = pybamm.Scalar(scale) - if isinstance(reference, numbers.Number): + if isinstance(reference, (float, int)): reference = pybamm.Scalar(reference) self._scale = scale self._reference = reference @@ -88,7 +88,7 @@ def bounds(self): return self._bounds @bounds.setter - def bounds(self, values: tuple[Numeric, Numeric]): + def bounds(self, values: tuple[Numeric, Numeric] | None): if values is None: values = (-np.inf, np.inf) else: diff --git a/src/pybamm/models/base_model.py b/src/pybamm/models/base_model.py index b5670320b4..c8258869f9 100644 --- a/src/pybamm/models/base_model.py +++ b/src/pybamm/models/base_model.py @@ -77,6 +77,8 @@ def __init__(self, name="Unnamed model"): self.use_jacobian = True self.convert_to_format = "casadi" + self.calculate_sensitivities = [] + # Model is not initially discretised self.is_discretised = False self.y_slices = None diff --git a/src/pybamm/plotting/quick_plot.py b/src/pybamm/plotting/quick_plot.py index ee146a2002..6a4fce9a04 100644 --- a/src/pybamm/plotting/quick_plot.py +++ b/src/pybamm/plotting/quick_plot.py @@ -126,7 +126,7 @@ def __init__( # Set colors, linestyles, figsize, axis limits # call LoopList to make sure list index never runs out if colors is None: - self.colors = LoopList(colors or ["r", "b", "k", "g", "m", "c"]) + self.colors = LoopList(["r", "b", "k", "g", "m", "c"]) else: self.colors = LoopList(colors) self.linestyles = LoopList(linestyles or ["-", ":", "--", "-."]) diff --git a/src/pybamm/solvers/base_solver.py b/src/pybamm/solvers/base_solver.py index 49e9b928ae..3c4014323b 100644 --- a/src/pybamm/solvers/base_solver.py +++ b/src/pybamm/solvers/base_solver.py @@ -94,8 +94,8 @@ def supports_parallel_solve(self): def requires_explicit_sensitivities(self): return True - @root_method.setter - def root_method(self, method): + @root_method.setter # type: ignore[attr-defined, no-redef] + def root_method(self, method) -> None: if method == "casadi": method = pybamm.CasadiAlgebraicSolver(self.root_tol) elif isinstance(method, str): @@ -1122,7 +1122,7 @@ def _set_sens_initial_conditions_from( """ ninputs = len(model.calculate_sensitivities) - initial_conditions = tuple([] for _ in range(ninputs)) + initial_conditions: tuple = tuple([] for _ in range(ninputs)) solution = solution.last_state for var in model.initial_conditions: final_state = solution[var.name] @@ -1143,10 +1143,10 @@ def _set_sens_initial_conditions_from( slices = [y_slices[symbol][0] for symbol in model.initial_conditions.keys()] # sort equations according to slices - concatenated_initial_conditions = [ + concatenated_initial_conditions = tuple( casadi.vertcat(*[eq for _, eq in sorted(zip(slices, init))]) for init in initial_conditions - ] + ) return concatenated_initial_conditions def process_t_interp(self, t_interp): diff --git a/src/pybamm/solvers/idaklu_jax.py b/src/pybamm/solvers/idaklu_jax.py index ef505570fa..2d5dae7dc7 100644 --- a/src/pybamm/solvers/idaklu_jax.py +++ b/src/pybamm/solvers/idaklu_jax.py @@ -259,7 +259,7 @@ def f_isolated(*args, **kwargs): def jax_value( self, - t: npt.NDArray = None, + t: Union[npt.NDArray, None] = None, inputs: Union[dict, None] = None, output_variables: Union[list[str], None] = None, ): @@ -292,7 +292,7 @@ def jax_value( def jax_grad( self, - t: npt.NDArray = None, + t: Union[npt.NDArray, None] = None, inputs: Union[dict, None] = None, output_variables: Union[list[str], None] = None, ): @@ -465,13 +465,11 @@ def _jax_vjp_impl( logger.debug(f" py:invar: {type(invar)}, {invar}") logger.debug(f" py:primals: {type(primals)}, {primals}") - t = primals[0] + t = np.asarray(primals[0]) inputs = primals[1:] if isinstance(invar, float): invar = round(invar) - if isinstance(t, float): - t = np.array(t) if t.ndim == 0 or (t.ndim == 1 and t.shape[0] == 1): # scalar time input diff --git a/src/pybamm/solvers/processed_variable_time_integral.py b/src/pybamm/solvers/processed_variable_time_integral.py index ce41c1796e..077b079775 100644 --- a/src/pybamm/solvers/processed_variable_time_integral.py +++ b/src/pybamm/solvers/processed_variable_time_integral.py @@ -7,7 +7,7 @@ @dataclass class ProcessedVariableTimeIntegral: method: Literal["discrete", "continuous"] - initial_condition: npt.NDArray + initial_condition: Union[npt.NDArray, float] discrete_times: Optional[npt.NDArray] @staticmethod diff --git a/src/pybamm/solvers/solution.py b/src/pybamm/solvers/solution.py index 4f17c60d94..d96a667344 100644 --- a/src/pybamm/solvers/solution.py +++ b/src/pybamm/solvers/solution.py @@ -160,7 +160,7 @@ def __init__( def has_sensitivities(self) -> bool: if isinstance(self._all_sensitivities, bool): return self._all_sensitivities - elif isinstance(self._all_sensitivities, dict): + else: return len(self._all_sensitivities) > 0 def extract_explicit_sensitivities(self): diff --git a/src/pybamm/solvers/summary_variable.py b/src/pybamm/solvers/summary_variable.py index 4c3da92a42..2594ce3f67 100644 --- a/src/pybamm/solvers/summary_variable.py +++ b/src/pybamm/solvers/summary_variable.py @@ -4,7 +4,7 @@ from __future__ import annotations import pybamm import numpy as np -from typing import Any +from typing import Any, cast class SummaryVariables: @@ -40,12 +40,15 @@ def __init__( ): self.user_inputs = user_inputs or {} self.esoh_solver = esoh_solver - self._variables = {} # Store computed variables - self.cycle_number = np.array([]) + # Store computed variables + self._variables: dict[str, float | list[float]] = {} + self.cycle_number = np.array([]) + self.cycles: list[SummaryVariables] | None = None + self._all_variables: list[str] | None = None model = solution.all_models[0] self._possible_variables = model.summary_variables # minus esoh variables - self._esoh_variables = None # Store eSOH variable names + self._esoh_variables: list[str] | None = None # Store eSOH variable names # Flag if eSOH calculations are needed self.calc_esoh = ( @@ -69,7 +72,7 @@ def _initialize_for_cycles(self, cycle_summary_variables: list[SummaryVariables] self.first_state = None self.last_state = None self.cycles = cycle_summary_variables - self.cycle_number = np.arange(1, len(self.cycles) + 1) + self.cycle_number = np.arange(1, len(self.cycles) + 1, dtype=float) first_cycle = self.cycles[0] self.calc_esoh = first_cycle.calc_esoh self.esoh_solver = first_cycle.esoh_solver @@ -81,25 +84,27 @@ def all_variables(self) -> list[str]: Return names of all possible summary variables, including eSOH variables if appropriate. """ - try: + if self._all_variables is not None: return self._all_variables - except AttributeError: - base_vars = self._possible_variables.copy() - base_vars.extend( - f"Change in {var[0].lower() + var[1:]}" - for var in self._possible_variables - ) + base_vars = self._possible_variables.copy() + base_vars.extend( + f"Change in {var[0].lower() + var[1:]}" for var in self._possible_variables + ) - if self.calc_esoh: - base_vars.extend(self.esoh_variables) + if self.calc_esoh: + base_vars.extend(self.esoh_variables) - self._all_variables = base_vars - return self._all_variables + self._all_variables = cast(list[str], base_vars) + return self._all_variables @property def esoh_variables(self) -> list[str] | None: """Return names of all eSOH variables.""" - if self.calc_esoh and self._esoh_variables is None: + if ( + self.esoh_solver is not None + and self.calc_esoh + and self._esoh_variables is None + ): esoh_model = self.esoh_solver._get_electrode_soh_sims_full().model esoh_vars = list(esoh_model.variables.keys()) self._esoh_variables = esoh_vars @@ -123,7 +128,7 @@ def __getitem__(self, key: str) -> float | list[float]: # return it if it exists return self._variables[key] elif key == "Cycle number": - return self.cycle_number + return cast(list[float], self.cycle_number.tolist()) elif key not in self.all_variables: # check it's listed as a summary variable raise KeyError(f"Variable '{key}' is not a summary variable.") @@ -148,10 +153,11 @@ def update(self, var: str): def _update_multiple_cycles(self, var: str, var_lowercase: str): """Creates aggregated summary variables for where more than one cycle exists.""" - var_cycle = [cycle[var] for cycle in self.cycles] - change_var_cycle = [ - cycle[f"Change in {var_lowercase}"] for cycle in self.cycles - ] + cycles = cast(list[SummaryVariables], self.cycles) + var_cycle = cast(list[float], [cycle[var] for cycle in cycles]) + change_var_cycle = cast( + list[float], [cycle[f"Change in {var_lowercase}"] for cycle in cycles] + ) self._variables[var] = var_cycle self._variables[f"Change in {var_lowercase}"] = change_var_cycle @@ -180,8 +186,9 @@ def _get_esoh_variables(self) -> dict[str, float]: Q_p = self.last_state["Positive electrode capacity [A.h]"].data[0] Q_Li = self.last_state["Total lithium capacity in particles [A.h]"].data[0] all_inputs = {**self.user_inputs, "Q_n": Q_n, "Q_p": Q_p, "Q_Li": Q_Li} + esoh_solver = cast(pybamm.lithium_ion.ElectrodeSOHSolver, self.esoh_solver) try: - esoh_sol = self.esoh_solver.solve(inputs=all_inputs) + esoh_sol = esoh_solver.solve(inputs=all_inputs) except pybamm.SolverError as error: # pragma: no cover raise pybamm.SolverError( "Could not solve for eSOH summary variables" diff --git a/src/pybamm/telemetry.py b/src/pybamm/telemetry.py index 3825738d47..ac5103139b 100644 --- a/src/pybamm/telemetry.py +++ b/src/pybamm/telemetry.py @@ -1,3 +1,4 @@ +from typing import cast from posthog import Posthog import pybamm import sys @@ -20,7 +21,7 @@ def capture(**kwargs): # pragma: no cover project_api_key="phc_acTt7KxmvBsAxaE0NyRd5WfJyNxGvBq1U9HnlQSztmb", host="https://us.i.posthog.com", ) - _posthog.log.setLevel("CRITICAL") + cast(Posthog, _posthog).log.setLevel("CRITICAL") def disable(): diff --git a/src/pybamm/util.py b/src/pybamm/util.py index dcab37b0dc..b1a7c0dd80 100644 --- a/src/pybamm/util.py +++ b/src/pybamm/util.py @@ -152,7 +152,7 @@ def search( Default is 0.4 """ - if not isinstance(keys, (str, list)) or not all( + if not isinstance(keys, (str, list)) or not all( # type: ignore[redundant-expr] isinstance(k, str) for k in keys ): msg = f"'keys' must be a string or a list of strings, got {type(keys)}" diff --git a/tests/unit/test_expression_tree/test_binary_operators.py b/tests/unit/test_expression_tree/test_binary_operators.py index eba3ca1bbd..ceed90fda2 100644 --- a/tests/unit/test_expression_tree/test_binary_operators.py +++ b/tests/unit/test_expression_tree/test_binary_operators.py @@ -11,7 +11,7 @@ import pybamm import sympy -EMPTY_DOMAINS = { +EMPTY_DOMAINS: dict[str, list[str]] = { "primary": [], "secondary": [], "tertiary": [], diff --git a/tests/unit/test_expression_tree/test_operations/test_copy.py b/tests/unit/test_expression_tree/test_operations/test_copy.py index f0d59a1fe1..67a884771c 100644 --- a/tests/unit/test_expression_tree/test_operations/test_copy.py +++ b/tests/unit/test_expression_tree/test_operations/test_copy.py @@ -79,6 +79,7 @@ def test_symbol_create_copy_new_children(self): a * b, a / b, a**b, + b % a, pybamm.minimum(a, b), pybamm.maximum(a, b), pybamm.Equality(a, b), @@ -89,12 +90,15 @@ def test_symbol_create_copy_new_children(self): b * a, b / a, b**a, + b % a, pybamm.minimum(b, a), pybamm.maximum(b, a), pybamm.Equality(b, a), ], ): - new_symbol = symbol_ab.create_copy(new_children=[b, a]) + new_symbol = symbol_ab.create_copy( + new_children=[b, a], perform_simplifications=False + ) assert new_symbol == symbol_ba assert new_symbol.print_name == symbol_ba.print_name diff --git a/tests/unit/test_solvers/test_solution.py b/tests/unit/test_solvers/test_solution.py index 78a8dca58e..4070348d70 100644 --- a/tests/unit/test_solvers/test_solution.py +++ b/tests/unit/test_solvers/test_solution.py @@ -255,7 +255,7 @@ def test_copy_with_computed_variables(self): assert ( sol1._variables[k] == sol2._variables[k] for k in sol1._variables.keys() - ) + ) is not None assert sol2.variables_returned is True def test_last_state(self):