diff --git a/FIAT/discontinuous_lagrange.py b/FIAT/discontinuous_lagrange.py index e086421e5..e2f986dad 100644 --- a/FIAT/discontinuous_lagrange.py +++ b/FIAT/discontinuous_lagrange.py @@ -6,8 +6,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import itertools -import math import numpy as np +from math import factorial from FIAT import finite_element, polynomial_set, dual_set, functional, P0 from FIAT.reference_element import LINE, make_lattice from FIAT.orientation_utils import make_entity_permutations_simplex @@ -18,7 +18,7 @@ def make_entity_permutations(dim, npoints): if npoints <= 0: - return {o: [] for o in range(math.factorial(dim + 1))} + return {o: [] for o in range(factorial(dim + 1))} # DG nodes are numbered, in order of significance, # - by g0: entity dim (vertices first, then edges, then ...) # - by g1: entity ids (DoFs on entities of smaller ids first) diff --git a/FIAT/dual_set.py b/FIAT/dual_set.py index 3d9cacacc..22f112e8a 100644 --- a/FIAT/dual_set.py +++ b/FIAT/dual_set.py @@ -177,7 +177,7 @@ def to_riesz(self, poly_set): for pt, wac_list in ell.deriv_dict.items(): j = dpts.index(pt) for (w, alpha, c) in wac_list: - dwts[alpha][i][c][j] = w + dwts[alpha][(i, *c, j)] = w for alpha in dwts: mat[ells] += numpy.dot(dwts[alpha], dexpansion_values[alpha].T) return mat diff --git a/FIAT/orientation_utils.py b/FIAT/orientation_utils.py index 92c678991..5354df8c1 100644 --- a/FIAT/orientation_utils.py +++ b/FIAT/orientation_utils.py @@ -1,7 +1,7 @@ import itertools -import math from collections.abc import Sequence import numpy as np +from math import factorial def make_entity_permutations_simplex(dim, npoints): @@ -54,7 +54,7 @@ def make_entity_permutations_simplex(dim, npoints): from FIAT.polynomial_set import mis if npoints <= 0: - return {o: [] for o in range(math.factorial(dim + 1))} + return {o: [] for o in range(factorial(dim + 1))} a = np.array(sorted(mis(dim + 1, npoints - 1)), dtype=int)[:, ::-1] index_perms = sorted(itertools.permutations(range(dim + 1))) perms = {} diff --git a/FIAT/reference_element.py b/FIAT/reference_element.py index 379b2958a..c7a679d4c 100644 --- a/FIAT/reference_element.py +++ b/FIAT/reference_element.py @@ -26,7 +26,7 @@ from math import factorial import numpy -from recursivenodes.nodes import _decode_family, _recursive +from recursivenodes import recursive_nodes from FIAT.orientation_utils import ( make_cell_orientation_reflection_map_simplex, @@ -41,20 +41,6 @@ TENSORPRODUCT = 99 -def multiindex_equal(d, isum, imin=0): - """A generator for d-tuple multi-indices whose sum is isum and minimum is imin. - """ - if d <= 0: - return - imax = isum - (d - 1) * imin - if imax < imin: - return - for i in range(imin, imax): - for a in multiindex_equal(d - 1, isum - i, imin=imin): - yield a + (i,) - yield (imin,) * (d - 1) + (imax,) - - def lattice_iter(start, finish, depth): """Generator iterating over the depth-dimensional lattice of integers between start and (finish-1). This works on simplices in @@ -85,11 +71,11 @@ def make_lattice(verts, n, interior=0, variant=None): "equispaced_interior": "equi_interior", "gll": "lgl"} family = recursivenodes_families.get(variant, variant) - family = _decode_family(family) - D = len(verts) - X = numpy.array(verts) - get_point = lambda alpha: tuple(numpy.dot(_recursive(D - 1, n, alpha, family), X)) - return list(map(get_point, multiindex_equal(D, n, interior))) + D = len(verts) - 1 + X = numpy.asarray(verts[::-1]) + bary = recursive_nodes(D, n, family=family, interior=interior) + pts = numpy.dot(bary, X) + return list(map(tuple, pts)) def linalg_subspace_intersection(A, B): diff --git a/test/unit/test_kong_mulder_veldhuizen.py b/test/unit/test_kong_mulder_veldhuizen.py index 323080f76..40d09e7bf 100644 --- a/test/unit/test_kong_mulder_veldhuizen.py +++ b/test/unit/test_kong_mulder_veldhuizen.py @@ -1,6 +1,6 @@ -import math import numpy as np import pytest +from math import factorial as fct from FIAT.reference_element import UFCInterval, UFCTriangle, UFCTetrahedron from FIAT import create_quadrature, make_quadrature, polynomial_set @@ -13,7 +13,6 @@ @pytest.mark.parametrize("p_d", [(1, 1), (2, 3), (3, 4)]) def test_kmv_quad_tet_schemes(p_d): # noqa: W503 - fct = math.factorial p, d = p_d q = create_quadrature(Te, p, "KMV") for i in range(d + 1): @@ -31,7 +30,6 @@ def test_kmv_quad_tet_schemes(p_d): # noqa: W503 @pytest.mark.parametrize("p_d", [(1, 1), (2, 3), (3, 5), (4, 7), (5, 9)]) def test_kmv_quad_tri_schemes(p_d): - fct = math.factorial p, d = p_d q = create_quadrature(T, p, "KMV") for i in range(d + 1):