Skip to content

Commit 76b01ea

Browse files
committed
Mypy fixes ones typing has been activated
1 parent f4b3543 commit 76b01ea

5 files changed

Lines changed: 39 additions & 26 deletions

File tree

ufl/algorithms/apply_restrictions.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from ufl.corealg.map_dag import map_expr_dag
2020
from ufl.corealg.multifunction import MultiFunction
2121
from ufl.domain import Mesh, extract_unique_domain
22+
from ufl.integral import Integral
2223
from ufl.sobolevspace import H1
2324

2425
default_restriction_map = {
@@ -291,7 +292,9 @@ def facet_normal(self, o):
291292
return self._require_restriction(o)
292293

293294

294-
def apply_restrictions(expression: Expr, default_restrictions: dict | None = None) -> Expr:
295+
def apply_restrictions(
296+
expression: Expr | Integral, default_restrictions: dict | None = None
297+
) -> Expr:
295298
"""Propagate restriction nodes to wrap differential terminals directly.
296299
297300
Args:

ufl/algorithms/domain_analysis.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010

1111
import numbers
1212
from collections import defaultdict
13-
from typing import Literal
13+
from typing import TYPE_CHECKING, Literal
14+
15+
if TYPE_CHECKING:
16+
from ufl.classes import Coefficient
1417

1518
import ufl
1619
from ufl.algorithms.coordinate_derivative_helpers import (
@@ -82,8 +85,8 @@ def __init__(
8285

8386
# This is populated in preprocess using data not available at
8487
# this stage:
85-
self.integral_coefficients = None
86-
self.enabled_coefficients = None
88+
self.integral_coefficients: set[Coefficient] | None = None
89+
self.enabled_coefficients: list[bool] | None = None
8790
non_primal_domains = tuple(domain_integral_type_map.keys())[1:]
8891
if sort_domains(non_primal_domains) != non_primal_domains:
8992
raise ValueError("domain_integral_type_map must have been sorted by domains")
@@ -336,14 +339,14 @@ def keyfunc(item):
336339
)
337340

338341
integral_datas = []
339-
for (d, itype, sid, extra_d_itype_tuple), integrals in sorted(itgs.items(), key=keyfunc):
342+
for (d, itype, sid, extra_d_itype_tuple), _itgs in sorted(itgs.items(), key=keyfunc):
340343
d_itype_tuple = ((d, itype),) + extra_d_itype_tuple
341344
integral_datas.append(
342345
IntegralData(
343346
d,
344347
itype,
345348
sid,
346-
integrals,
349+
_itgs,
347350
{},
348351
dict(d_itype_tuple),
349352
)

ufl/algorithms/formdata.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,8 @@
2020
from ufl.algorithms.domain_analysis import IntegralData, reconstruct_form_from_integral_data
2121
from ufl.algorithms.replace import replace
2222
from ufl.classes import Argument, Coefficient, FunctionSpace, GeometricFacetQuantity
23-
from ufl.coefficient import BaseCoefficient
2423
from ufl.corealg.traversal import traverse_unique_terminals
25-
from ufl.domain import MeshSequence, extract_domains, extract_unique_domain
24+
from ufl.domain import AbstractDomain, MeshSequence, extract_domains, extract_unique_domain
2625
from ufl.form import Form, Zero
2726
from ufl.utils.formatting import estr, lstr, tstr
2827
from ufl.utils.sequences import max_degree
@@ -99,7 +98,7 @@ def _compute_element_mapping(form: Form):
9998

10099
def _compute_max_subdomain_ids(integral_data: list[IntegralData]) -> dict[str, int]:
101100
"""Compute the maximum subdomain ids."""
102-
max_subdomain_ids = {}
101+
max_subdomain_ids: dict[str, int] = {}
103102
for itg_data in integral_data:
104103
it = itg_data.integral_type
105104
for integral in itg_data.integrals:
@@ -141,8 +140,8 @@ def _check_facet_geometry(integral_data):
141140

142141

143142
def _build_coefficient_replace_map(
144-
coefficients: list[BaseCoefficient], element_mapping=None
145-
) -> tuple[list[BaseCoefficient], dict[BaseCoefficient, BaseCoefficient]]:
143+
coefficients: list[Coefficient], element_mapping=None
144+
) -> tuple[list[Coefficient], dict[Coefficient, Coefficient]]:
146145
"""Create new Coefficient objects with count starting at 0.
147146
148147
Returns:
@@ -175,10 +174,10 @@ class FormData:
175174

176175
_original_form: Form
177176
_integral_data: list[IntegralData]
178-
_reduced_coefficients: list[BaseCoefficient]
179-
_original_coefficient_positions: list[tuple[int, BaseCoefficient]]
180-
_function_replace_map: dict[BaseCoefficient, BaseCoefficient]
181-
_coefficient_elements: list[Any]
177+
_reduced_coefficients: list[Coefficient]
178+
_original_coefficient_positions: list[int]
179+
_function_replace_map: dict[Coefficient, Coefficient]
180+
_coefficient_elements: tuple[Any, ...]
182181
_coefficient_split: dict[Coefficient, list[Coefficient]]
183182

184183
def __init__(
@@ -188,7 +187,7 @@ def __init__(
188187
do_apply_default_restrictions: bool = True,
189188
do_apply_restrictions: bool = True,
190189
do_replace_functions: bool = False,
191-
coefficients_to_split: tuple[BaseCoefficient, ...] | None = None,
190+
coefficients_to_split: tuple[Coefficient, ...] | None = None,
192191
complex_mode: bool = False,
193192
):
194193
"""Create form-data for a form that has been processed.
@@ -223,8 +222,9 @@ def __init__(
223222
# Figure out which coefficients from the original form are
224223
# actually used in any integral (Differentiation may reduce the
225224
# set of coefficients w.r.t. the original form)
226-
reduced_coefficients_set = set()
225+
reduced_coefficients_set: set[Coefficient] = set()
227226
for itg_data in self.integral_data:
227+
assert itg_data.integral_coefficients is not None
228228
reduced_coefficients_set.update(itg_data.integral_coefficients)
229229
self._reduced_coefficients = sorted(reduced_coefficients_set, key=lambda c: c.count())
230230
self._original_coefficient_positions = [
@@ -236,6 +236,7 @@ def __init__(
236236
# Store back into integral data which form coefficients are used
237237
# by each integral
238238
for itg_data in self.integral_data:
239+
assert itg_data.integral_coefficients is not None
239240
itg_data.enabled_coefficients = [
240241
bool(coeff in itg_data.integral_coefficients) for coeff in self.reduced_coefficients
241242
]
@@ -285,7 +286,7 @@ def __init__(
285286
elem = c.ufl_element()
286287
coefficient_split[c] = [
287288
Coefficient(FunctionSpace(m, e))
288-
for m, e in zip(mesh.iterable_like(elem), elem.sub_elements)
289+
for m, e in zip(mesh.iterable_like(elem), elem.sub_elements) # type: ignore
289290
]
290291
self._coefficient_split = coefficient_split
291292
coeff_splitter = CoefficientSplitter(self.coefficient_split)
@@ -309,6 +310,7 @@ def __init__(
309310
for _, integral_type in itg_data.domain_integral_type_map.items()
310311
):
311312
continue
313+
default_restrictions: dict[AbstractDomain, str | None] | None
312314
if do_apply_default_restrictions:
313315
default_restrictions = {
314316
domain: default_restriction_map[integral_type]
@@ -412,12 +414,12 @@ def max_subdomain_ids(self) -> dict[str, int]:
412414
return _compute_max_subdomain_ids(self.integral_data)
413415

414416
@cached_property
415-
def argument_elements(self) -> list[Any]:
417+
def argument_elements(self) -> tuple[Any]:
416418
"""The set of elements the arguments in the form."""
417419
return tuple(f.ufl_element() for f in self.original_form.arguments())
418420

419421
@property
420-
def coefficient_elements(self) -> list[Any]:
422+
def coefficient_elements(self) -> tuple[Any]:
421423
"""The set of elements used for coefficients in the form."""
422424
return self._coefficient_elements
423425

ufl/domain.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from ufl.core.expr import Expr
1717
from ufl.finiteelement import AbstractFiniteElement # To avoid cyclic import when type-hinting.
1818
from ufl.form import Form
19+
from ufl.integral import Integral
20+
1921
from ufl.cell import AbstractCell, CellSequence
2022
from ufl.core.ufl_id import attach_ufl_id
2123
from ufl.core.ufl_type import UFLObject
@@ -339,7 +341,7 @@ def as_domain(domain):
339341
return domain
340342

341343

342-
def sort_domains(domains: Sequence[AbstractDomain]) -> tuple[AbstractDomain, ...]:
344+
def sort_domains(domains: Iterable[AbstractDomain]) -> tuple[AbstractDomain, ...]:
343345
"""Sort domains in a canonical ordering.
344346
345347
Args:
@@ -388,7 +390,7 @@ def join_domains(domains: Sequence[AbstractDomain], expand_mesh_sequence: bool =
388390

389391

390392
def extract_domains(
391-
expr: Expr | Form,
393+
expr: Expr | Form | Integral,
392394
expand_mesh_sequence: bool = True,
393395
) -> tuple[AbstractDomain, ...]:
394396
"""Return all domains expression is defined on.

ufl/form.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,12 @@ def _sorted_integrals(integrals: typing.Iterable[Integral]) -> tuple[Integral, .
4343
"""
4444
# Group integrals in multilevel dict by keys
4545
# [domain][integral_type][subdomain_id]
46-
integrals_dict = defaultdict(
47-
lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
48-
)
46+
from ufl.classes import AbstractDomain
47+
48+
integrals_dict: dict[
49+
AbstractDomain,
50+
dict[str, dict[tuple[typing.Any, ...], dict[int | tuple[int, ...], list[Integral]]]],
51+
] = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))))
4952
for integral in integrals:
5053
d = integral.ufl_domain()
5154
if d is None:
@@ -314,7 +317,7 @@ def __init__(self, integrals: list[Integral]):
314317
self._signature = None
315318

316319
# Never use this internally in ufl!
317-
self._cache = {}
320+
self._cache: dict[typing.Any, typing.Any] = {}
318321

319322
# --- Accessor interface ---
320323

0 commit comments

Comments
 (0)