Skip to content

Commit 4cff9b4

Browse files
committed
Revert __getitem__
1 parent e8ff24b commit 4cff9b4

File tree

4 files changed

+29
-34
lines changed

4 files changed

+29
-34
lines changed

finat/physically_mapped.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,15 +270,15 @@ def basis_evaluation(self, order, ps, entity=None, coordinate_mapping=None):
270270
M = self.basis_transformation(coordinate_mapping)
271271
# we expect M to be sparse with O(1) nonzeros per row
272272
# for each row, get the column index of each nonzero entry
273-
csr = [[j for j in range(M.shape[1]) if not isinstance(M[i, j], gem.Zero)]
273+
csr = [[j for j in range(M.shape[1]) if not isinstance(M.array[i, j], gem.Zero)]
274274
for i in range(M.shape[0])]
275275

276276
def matvec(table):
277277
# basis recombination using hand-rolled sparse-dense matrix multiplication
278278
ii = gem.indices(len(table.shape)-1)
279279
phi = [gem.Indexed(table, (j, *ii)) for j in range(M.shape[1])]
280280
# the sum approach is faster than calling numpy.dot or gem.IndexSum
281-
expressions = [gem.ComponentTensor(sum(M[i, j] * phi[j] for j in js), ii)
281+
expressions = [gem.ComponentTensor(sum(M.array[i, j] * phi[j] for j in js), ii)
282282
for i, js in enumerate(csr)]
283283
val = gem.ListTensor(expressions)
284284
# val = M @ table

gem/gem.py

Lines changed: 19 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ def __call__(self, *args, **kwargs):
5454

5555
# Set free_indices if not set already
5656
if not hasattr(obj, 'free_indices'):
57-
obj.free_indices = unique(chain(*[c.free_indices
58-
for c in obj.children]))
57+
obj.free_indices = unique(chain.from_iterable(c.free_indices
58+
for c in obj.children))
5959
# Set dtype if not set already.
6060
if not hasattr(obj, 'dtype'):
6161
obj.dtype = obj.inherit_dtype_from_children(obj.children)
@@ -306,9 +306,6 @@ def value(self):
306306
def shape(self):
307307
return self.array.shape
308308

309-
def __getitem__(self, i):
310-
return self.array[i]
311-
312309

313310
class Variable(Terminal):
314311
"""Symbolic variable tensor"""
@@ -337,7 +334,7 @@ def __new__(cls, a, b):
337334
return a
338335

339336
if isinstance(a, Constant) and isinstance(b, Constant):
340-
return Literal(a.value + b.value, dtype=Node.inherit_dtype_from_children([a, b]))
337+
return Literal(a.value + b.value, dtype=Node.inherit_dtype_from_children((a, b)))
341338

342339
self = super(Sum, cls).__new__(cls)
343340
self.children = a, b
@@ -361,7 +358,7 @@ def __new__(cls, a, b):
361358
return a
362359

363360
if isinstance(a, Constant) and isinstance(b, Constant):
364-
return Literal(a.value * b.value, dtype=Node.inherit_dtype_from_children([a, b]))
361+
return Literal(a.value * b.value, dtype=Node.inherit_dtype_from_children((a, b)))
365362

366363
self = super(Product, cls).__new__(cls)
367364
self.children = a, b
@@ -385,7 +382,7 @@ def __new__(cls, a, b):
385382
return a
386383

387384
if isinstance(a, Constant) and isinstance(b, Constant):
388-
return Literal(a.value / b.value, dtype=Node.inherit_dtype_from_children([a, b]))
385+
return Literal(a.value / b.value, dtype=Node.inherit_dtype_from_children((a, b)))
389386

390387
self = super(Division, cls).__new__(cls)
391388
self.children = a, b
@@ -676,6 +673,19 @@ def __new__(cls, aggregate, multiindex):
676673
if isinstance(aggregate, Zero):
677674
return Zero(dtype=aggregate.dtype)
678675

676+
# Simplify Literal and ListTensor
677+
if isinstance(aggregate, (Constant, ListTensor)):
678+
if all(isinstance(i, int) for i in multiindex):
679+
# All indices fixed
680+
sub = aggregate.array[multiindex]
681+
return Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else sub
682+
elif any(isinstance(i, int) for i in multiindex) and all(isinstance(i, (int, Index)) for i in multiindex):
683+
# Some indices fixed
684+
slices = tuple(i if isinstance(i, int) else slice(None) for i in multiindex)
685+
sub = aggregate.array[slices]
686+
sub = Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else ListTensor(sub)
687+
return Indexed(sub, tuple(i for i in multiindex if not isinstance(i, int)))
688+
679689
# Simplify Indexed(ComponentTensor(Indexed(C, kk), jj), ii) -> Indexed(C, ll)
680690
if isinstance(aggregate, ComponentTensor):
681691
B, = aggregate.children
@@ -689,19 +699,6 @@ def __new__(cls, aggregate, multiindex):
689699
ll = tuple(rep.get(k, k) for k in kk)
690700
return Indexed(C, ll)
691701

692-
# Simplify Literal and ListTensor
693-
if isinstance(aggregate, (Constant, ListTensor)):
694-
if all(isinstance(i, int) for i in multiindex):
695-
# All indices fixed
696-
sub = aggregate[multiindex]
697-
return Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else sub
698-
elif any(isinstance(i, int) for i in multiindex) and all(isinstance(i, (int, Index)) for i in multiindex):
699-
# Some indices fixed
700-
slices = tuple(i if isinstance(i, int) else slice(None) for i in multiindex)
701-
sub = aggregate[slices]
702-
sub = Literal(sub, dtype=aggregate.dtype) if isinstance(aggregate, Constant) else ListTensor(sub)
703-
return Indexed(sub, tuple(i for i in multiindex if not isinstance(i, int)))
704-
705702
self = super(Indexed, cls).__new__(cls)
706703
self.children = (aggregate,)
707704
self.multiindex = multiindex
@@ -945,9 +942,6 @@ def shape(self):
945942
def __reduce__(self):
946943
return type(self), (self.array,)
947944

948-
def __getitem__(self, i):
949-
return self.array[i]
950-
951945
def reconstruct(self, *args):
952946
return ListTensor(asarray(args).reshape(self.array.shape))
953947

@@ -958,7 +952,7 @@ def is_equal(self, other):
958952
"""Common subexpression eliminating equality predicate."""
959953
if type(self) is not type(other):
960954
return False
961-
if (self.array == other.array).all():
955+
if numpy.array_equal(self.array, other.array):
962956
self.array = other.array
963957
return True
964958
return False

gem/optimise.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def _constant_fold_zero_listtensor(node, self):
188188
new_children = list(map(self, node.children))
189189
if all(isinstance(nc, Zero) for nc in new_children):
190190
return Zero(node.shape)
191-
elif all(nc == c for nc, c in zip(new_children, node.children)):
191+
elif new_children == node.children:
192192
return node
193193
else:
194194
return node.reconstruct(*new_children)
@@ -207,7 +207,7 @@ def constant_fold_zero(exprs):
207207
otherwise Literal `0`s would be reintroduced.
208208
"""
209209
mapper = Memoizer(_constant_fold_zero)
210-
return [mapper(e) for e in exprs]
210+
return list(map(mapper, exprs))
211211

212212

213213
def _select_expression(expressions, index):
@@ -252,9 +252,9 @@ def child(expression):
252252
assert all(len(e.children) == len(expr.children) for e in expressions)
253253
assert len(expr.children) > 0
254254

255-
return expr.reconstruct(*[_select_expression(nth_children, index)
256-
for nth_children in zip(*[e.children
257-
for e in expressions])])
255+
return expr.reconstruct(*(_select_expression(nth_children, index)
256+
for nth_children in zip(*(e.children
257+
for e in expressions))))
258258

259259
raise NotImplementedError("No rule for factorising expressions of this kind.")
260260

test/finat/test_zany_mapping.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44
import pytest
55
from gem.interpreter import evaluate
6+
from finat.physically_mapped import PhysicallyMappedElement
67

78

89
def make_unisolvent_points(element, interior=False):
@@ -65,11 +66,11 @@ def check_zany_mapping(element, ref_to_phys, *args, **kwargs):
6566
# Zany map the results
6667
num_bfs = phys_element.space_dimension()
6768
num_dofs = finat_element.space_dimension()
68-
try:
69+
if isinstance(finat_element, PhysicallyMappedElement):
6970
Mgem = finat_element.basis_transformation(ref_to_phys)
7071
M = evaluate([Mgem])[0].arr
7172
ref_vals_zany = np.tensordot(M, ref_vals_piola, (-1, 0))
72-
except AttributeError:
73+
else:
7374
M = np.eye(num_dofs, num_bfs)
7475
ref_vals_zany = ref_vals_piola
7576

0 commit comments

Comments
 (0)