Skip to content
3 changes: 2 additions & 1 deletion examples/scripts/SPM_compare_particle_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Compare different discretisations in the particle
#
import argparse
from typing import Any
import numpy as np
import pybamm
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -48,7 +49,7 @@
disc.process_model(model)

# solve model
solutions = [None] * len(models)
solutions: list[Any] = [None] * len(models)
t_eval = np.linspace(0, 3600, 100)
for i, model in enumerate(models):
solutions[i] = model.default_solver.solve(model, t_eval)
Expand Down
21 changes: 11 additions & 10 deletions examples/scripts/SPMe_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,14 @@
time += dt

# plot
time_in_seconds = solution["Time [s]"].entries
step_time_in_seconds = step_solution["Time [s]"].entries
voltage = solution["Voltage [V]"].entries
step_voltage = step_solution["Voltage [V]"].entries
plt.plot(time_in_seconds, voltage, "b-", label="SPMe (continuous solve)")
plt.plot(step_time_in_seconds, step_voltage, "ro", label="SPMe (stepped solve)")
plt.xlabel(r"$t$")
plt.ylabel("Voltage [V]")
plt.legend()
plt.show()
if step_solution is not None:
time_in_seconds = solution["Time [s]"].entries
step_time_in_seconds = step_solution["Time [s]"].entries
voltage = solution["Voltage [V]"].entries
step_voltage = step_solution["Voltage [V]"].entries
plt.plot(time_in_seconds, voltage, "b-", label="SPMe (continuous solve)")
plt.plot(step_time_in_seconds, step_voltage, "ro", label="SPMe (stepped solve)")
plt.xlabel(r"$t$")
plt.ylabel("Voltage [V]")
plt.legend()
plt.show()
2 changes: 1 addition & 1 deletion examples/scripts/heat_equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def T_exact(x, t):
# Plot ------------------------------------------------------------------------
x_nodes = mesh["rod"].nodes # numerical gridpoints
xx = np.linspace(0, 2, 101) # fine mesh to plot exact solution
plot_times = np.linspace(0, 1, 5)
plot_times: np.ndarray = np.linspace(0, 1, 5)

plt.figure(figsize=(15, 8))
cmap = plt.get_cmap("inferno")
Expand Down
4 changes: 2 additions & 2 deletions examples/scripts/minimal_example_of_lookup_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ def process_2D(name, data):
D_s_n_data = process_2D("Negative particle diffusivity [m2.s-1]", df)


def D_s_n(sto, T):
def D_s_n_func(sto, T):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why was this changed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because their was a variable in this file with name D_s_n so mypy gave error function and variable both have same name

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the motive behind redefining the variable here was to show how one can pass both a constant value and a function as a parameter value. I would say we should revert this and not run mypy on examples at all (see my comment below).

name, (x, y) = D_s_n_data
return pybamm.Interpolant(x, y, [T, sto], name)


parameter_values["Negative particle diffusivity [m2.s-1]"] = D_s_n
parameter_values["Negative particle diffusivity [m2.s-1]"] = D_s_n_func

k_n = parameter_values["Negative electrode exchange-current density [A.m-2]"]

Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,9 @@ concurrency = ["multiprocessing"]
ignore_missing_imports = true
allow_redefinition = true
disable_error_code = ["call-overload", "operator"]
strict = false
enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
exclude = "^(build/|docs/conf\\.py)$"

[[tool.mypy.overrides]]
module = [
Expand Down
4 changes: 3 additions & 1 deletion src/pybamm/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ class Experiment:

def __init__(
self,
operating_conditions: list[str | tuple[str] | BaseStep],
operating_conditions: list[
str | tuple[str, ...] | tuple[str | BaseStep] | BaseStep
],
period: str | None = None,
temperature: float | None = None,
termination: list[str] | None = None,
Expand Down
2 changes: 1 addition & 1 deletion src/pybamm/experiment/step/base_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def __init__(
self.value = pybamm.Interpolant(
t,
y,
pybamm.t - pybamm.InputParameter("start time"),
pybamm.t - pybamm.InputParameter("start time"), # type: ignore[arg-type]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the type of this expression? Isn't it pybamm.Symbol?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yup you're right but it is is not a Sequence[Symbol] or a Time, so mypy throws error that's why I've added it inside a list

name="Drive Cycle",
)
self.period = np.diff(t).min()
Expand Down
49 changes: 44 additions & 5 deletions src/pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Binary operator classes
#
from __future__ import annotations
import numbers

import numpy as np
import sympy
Expand Down Expand Up @@ -33,8 +32,8 @@ def _preprocess_binary(
raise ValueError("right must be a 1D array")
right = pybamm.Vector(right)

# Check both left and right are pybamm Symbols
if not (isinstance(left, pybamm.Symbol) and isinstance(right, pybamm.Symbol)):
# Check right is pybamm Symbol
if not isinstance(right, pybamm.Symbol):
raise NotImplementedError(
f"BinaryOperator not implemented for symbols of type {type(left)} and {type(right)}"
)
Expand Down Expand Up @@ -113,6 +112,9 @@ def __str__(self):
right_str = f"{self.right!s}"
return f"{left_str} {self.name} {right_str}"

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please add some information on why this was added?

Copy link
Member Author

@Rishab87 Rishab87 Mar 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've replaced self.__ class __ with new_instance because earlier when we were using self.__ class __ it showed a third arg was not getting passed:

error: Missing positional argument "right_child" in call to "BinaryOperator"  [call-arg]

but this function was always getting called from instance of its child classes which don't need to pass 3 arguments, so i thought it was better to make a new_instance method which can be overrided in child classes

I've already added this in the PR description of previous sp-check-guidelines PR, should I add it here too? Or follow some different approach

return self.__class__(self.name, left, right) # pragma: no cover

def create_copy(
self,
new_children: list[pybamm.Symbol] | None = None,
Expand All @@ -127,7 +129,7 @@ def create_copy(
children = self._children_for_copying(new_children)

if not perform_simplifications:
out = self.__class__(children[0], children[1])
out = self._new_instance(children[0], children[1])
else:
# creates a new instance using the overloaded binary operator to perform
# additional simplifications, rather than just calling the constructor
Expand Down Expand Up @@ -224,6 +226,9 @@ def __init__(
"""See :meth:`pybamm.BinaryOperator.__init__()`."""
super().__init__("**", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Power(left, right)

def _diff(self, variable: pybamm.Symbol):
"""See :meth:`pybamm.Symbol._diff()`."""
# apply chain rule and power rule
Expand Down Expand Up @@ -273,6 +278,9 @@ def __init__(
"""See :meth:`pybamm.BinaryOperator.__init__()`."""
super().__init__("+", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Addition(left, right)

def _diff(self, variable: pybamm.Symbol):
"""See :meth:`pybamm.Symbol._diff()`."""
return self.left.diff(variable) + self.right.diff(variable)
Expand Down Expand Up @@ -300,6 +308,9 @@ def __init__(

super().__init__("-", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Subtraction(left, right)

def _diff(self, variable: pybamm.Symbol):
"""See :meth:`pybamm.Symbol._diff()`."""
return self.left.diff(variable) - self.right.diff(variable)
Expand Down Expand Up @@ -329,6 +340,9 @@ def __init__(

super().__init__("*", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Multiplication(left, right)

def _diff(self, variable: pybamm.Symbol):
"""See :meth:`pybamm.Symbol._diff()`."""
# apply product rule
Expand Down Expand Up @@ -369,6 +383,9 @@ def __init__(
"""See :meth:`pybamm.BinaryOperator.__init__()`."""
super().__init__("@", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return MatrixMultiplication(left, right) # pragma: no cover

def diff(self, variable):
"""See :meth:`pybamm.Symbol.diff()`."""
# We shouldn't need this
Expand Down Expand Up @@ -418,6 +435,9 @@ def __init__(
"""See :meth:`pybamm.BinaryOperator.__init__()`."""
super().__init__("/", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Division(left, right)

def _diff(self, variable: pybamm.Symbol):
"""See :meth:`pybamm.Symbol._diff()`."""
# apply quotient rule
Expand Down Expand Up @@ -466,6 +486,9 @@ def __init__(
"""See :meth:`pybamm.BinaryOperator.__init__()`."""
super().__init__("inner product", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Inner(left, right) # pragma: no cover

def _diff(self, variable: pybamm.Symbol):
"""See :meth:`pybamm.Symbol._diff()`."""
# apply product rule
Expand Down Expand Up @@ -543,6 +566,9 @@ def __init__(
"""See :meth:`pybamm.BinaryOperator.__init__()`."""
super().__init__("==", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Equality(left, right)

def diff(self, variable):
"""See :meth:`pybamm.Symbol.diff()`."""
# Equality should always be multiplied by something else so hopefully don't
Expand Down Expand Up @@ -600,6 +626,10 @@ def __init__(
):
"""See :meth:`pybamm.BinaryOperator.__init__()`."""
super().__init__(name, left, right)
self.name = name

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return _Heaviside(self.name, left, right) # pragma: no cover

def diff(self, variable):
"""See :meth:`pybamm.Symbol.diff()`."""
Expand Down Expand Up @@ -678,6 +708,9 @@ def __init__(
):
super().__init__("%", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Modulo(left, right)

def _diff(self, variable: pybamm.Symbol):
"""See :meth:`pybamm.Symbol._diff()`."""
# apply chain rule and power rule
Expand Down Expand Up @@ -720,6 +753,9 @@ def __init__(
):
super().__init__("minimum", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Minimum(left, right)

def __str__(self):
"""See :meth:`pybamm.Symbol.__str__()`."""
return f"minimum({self.left!s}, {self.right!s})"
Expand Down Expand Up @@ -764,6 +800,9 @@ def __init__(
):
super().__init__("maximum", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Maximum(left, right)

def __str__(self):
"""See :meth:`pybamm.Symbol.__str__()`."""
return f"maximum({self.left!s}, {self.right!s})"
Expand Down Expand Up @@ -1538,7 +1577,7 @@ def source(
corresponding to a source term in the bulk.
"""
# Broadcast if left is number
if isinstance(left, numbers.Number):
if isinstance(left, (int, float)):
left = pybamm.PrimaryBroadcast(left, "current collector")

# force type cast for mypy
Expand Down
17 changes: 14 additions & 3 deletions src/pybamm/expression_tree/broadcasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,7 @@ def _from_json(cls, snippet):
)

def _unary_new_copy(self, child: pybamm.Symbol, perform_simplifications=True):
"""See :meth:`pybamm.UnaryOperator._unary_new_copy()`."""
return self.__class__(child, self.broadcast_domain)
pass # pragma: no cover
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why was this removed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In src/pybamm/expression_tree/broadcasts.py I've forced child classes of Broadcast to implement _unary_new_copy because earlier we were using self.broadcast_domain in this function in Broadcast class but it does not have any attribute self.broadcast_domain, this function was only getting called by instance of their child classes which has self.broadcast_domain property.



class PrimaryBroadcast(Broadcast):
Expand Down Expand Up @@ -191,6 +190,10 @@ def reduce_one_dimension(self):
"""Reduce the broadcast by one dimension."""
return self.orphans[0]

def _unary_new_copy(self, child: pybamm.Symbol, perform_simplifications=True):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are these functions being repeated?

Copy link
Member Author

@Rishab87 Rishab87 Apr 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned above now it needs to be overrided in child classes

"""See :meth:`pybamm.UnaryOperator._unary_new_copy()`."""
return self.__class__(child, self.broadcast_domain)


class PrimaryBroadcastToEdges(PrimaryBroadcast):
"""A primary broadcast onto the edges of the domain."""
Expand Down Expand Up @@ -321,6 +324,10 @@ def reduce_one_dimension(self):
"""Reduce the broadcast by one dimension."""
return self.orphans[0]

def _unary_new_copy(self, child: pybamm.Symbol, perform_simplifications=True):
"""See :meth:`pybamm.UnaryOperator._unary_new_copy()`."""
return self.__class__(child, self.broadcast_domain)


class SecondaryBroadcastToEdges(SecondaryBroadcast):
"""A secondary broadcast onto the edges of a domain."""
Expand Down Expand Up @@ -438,6 +445,10 @@ def reduce_one_dimension(self):
"""Reduce the broadcast by one dimension."""
raise NotImplementedError

def _unary_new_copy(self, child: pybamm.Symbol, perform_simplifications=True):
"""See :meth:`pybamm.UnaryOperator._unary_new_copy()`."""
return self.__class__(child, self.broadcast_domain)


class TertiaryBroadcastToEdges(TertiaryBroadcast):
"""A tertiary broadcast onto the edges of a domain."""
Expand All @@ -463,7 +474,7 @@ def __init__(
self,
child_input: Numeric | pybamm.Symbol,
broadcast_domain: DomainType = None,
auxiliary_domains: AuxiliaryDomainType = None,
auxiliary_domains: AuxiliaryDomainType | str = None,
broadcast_domains: DomainsType = None,
name: str | None = None,
):
Expand Down
4 changes: 2 additions & 2 deletions src/pybamm/expression_tree/concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ def substrings(s: str):
yield s[i : j + 1]


def intersect(s1: str, s2: str):
def intersect(s1: str, s2: str) -> str:
# find all the common strings between two strings
all_intersects = set(substrings(s1)) & set(substrings(s2))
# intersect is the longest such intercept
Expand All @@ -525,7 +525,7 @@ def intersect(s1: str, s2: str):
return intersect.lstrip().rstrip()


def simplified_concatenation(*children, name: Optional[str] = None):
def simplified_concatenation(*children, name=None):
"""Perform simplifications on a concatenation."""
# remove children that are None
children = list(filter(lambda x: x is not None, children))
Expand Down
4 changes: 3 additions & 1 deletion src/pybamm/expression_tree/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,7 +964,9 @@ def to_casadi(
"""
return pybamm.CasadiConverter(casadi_symbols).convert(self, t, y, y_dot, inputs)

def _children_for_copying(self, children: list[Symbol] | None = None) -> Symbol:
def _children_for_copying(
self, children: list[Symbol] | None = None
) -> list[Symbol]:
"""
Gets existing children for a symbol being copied if they aren't provided.
"""
Expand Down
2 changes: 2 additions & 0 deletions src/pybamm/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def __init__(self, name="Unnamed model"):
self.use_jacobian = True
self.convert_to_format = "casadi"

self.calculate_sensitivities = []
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is this variable coming from?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In base solver got this error:

src/pybamm/solvers/base_solver.py:1124: error: "BaseModel" has no attribute "calculate_sensitivities"  [attr-defined]

that's why I added this


# Model is not initially discretised
self.is_discretised = False
self.y_slices = None
Expand Down
2 changes: 1 addition & 1 deletion src/pybamm/plotting/quick_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def __init__(
# Set colors, linestyles, figsize, axis limits
# call LoopList to make sure list index never runs out
if colors is None:
self.colors = LoopList(colors or ["r", "b", "k", "g", "m", "c"])
self.colors = LoopList(["r", "b", "k", "g", "m", "c"])
else:
self.colors = LoopList(colors)
self.linestyles = LoopList(linestyles or ["-", ":", "--", "-."])
Expand Down
10 changes: 5 additions & 5 deletions src/pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ def supports_parallel_solve(self):
def requires_explicit_sensitivities(self):
return True

@root_method.setter
def root_method(self, method):
@root_method.setter # type: ignore[attr-defined, no-redef]
def root_method(self, method) -> None:
if method == "casadi":
method = pybamm.CasadiAlgebraicSolver(self.root_tol)
elif isinstance(method, str):
Expand Down Expand Up @@ -1122,7 +1122,7 @@ def _set_sens_initial_conditions_from(
"""

ninputs = len(model.calculate_sensitivities)
initial_conditions = tuple([] for _ in range(ninputs))
initial_conditions: tuple = tuple([] for _ in range(ninputs))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need the explicit tuple?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because mypy was unable to infer it:

src/pybamm/solvers/base_solver.py:1125: error: Need type annotation for "initial_conditions"  [var-annotated]

I think mypy has difficulty inferring that the tuple contains lists of a particular type

solution = solution.last_state
for var in model.initial_conditions:
final_state = solution[var.name]
Expand All @@ -1143,10 +1143,10 @@ def _set_sens_initial_conditions_from(
slices = [y_slices[symbol][0] for symbol in model.initial_conditions.keys()]

# sort equations according to slices
concatenated_initial_conditions = [
concatenated_initial_conditions = tuple(
casadi.vertcat(*[eq for _, eq in sorted(zip(slices, init))])
for init in initial_conditions
]
)
return concatenated_initial_conditions

def process_t_interp(self, t_interp):
Expand Down
Loading
Loading