Skip to content

Commit 5f18075

Browse files
composed map: add permute method (#723)
--------- Co-authored-by: Connor Ward <[email protected]>
1 parent af813e9 commit 5f18075

File tree

1 file changed

+10
-12
lines changed

1 file changed

+10
-12
lines changed

pyop2/codegen/builder.py

+10-12
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,10 @@ def shape(self):
7575
def dtype(self):
7676
return self.values.dtype
7777

78-
def indexed(self, multiindex, layer=None, permute=lambda x: x):
78+
def _permute(self, x):
79+
return x
80+
81+
def indexed(self, multiindex, layer=None):
7982
n, i, f = multiindex
8083
if layer is not None and self.offset is not None:
8184
# For extruded mesh, prefetch the indirections for each map, so that they don't
@@ -84,7 +87,7 @@ def indexed(self, multiindex, layer=None, permute=lambda x: x):
8487
base_key = None
8588
if base_key not in self.prefetch:
8689
j = Index()
87-
base = Indexed(self.values, (n, permute(j)))
90+
base = Indexed(self.values, (n, self._permute(j)))
8891
self.prefetch[base_key] = Materialise(PackInst(), base, MultiIndex(j))
8992

9093
base = self.prefetch[base_key]
@@ -122,17 +125,17 @@ def indexed(self, multiindex, layer=None, permute=lambda x: x):
122125
return Indexed(self.prefetch[key], (f, i)), (f, i)
123126
else:
124127
assert f.extent == 1 or f.extent is None
125-
base = Indexed(self.values, (n, permute(i)))
128+
base = Indexed(self.values, (n, self._permute(i)))
126129
return base, (f, i)
127130

128-
def indexed_vector(self, n, shape, layer=None, permute=lambda x: x):
131+
def indexed_vector(self, n, shape, layer=None):
129132
shape = self.shape[1:] + shape
130133
if self.interior_horizontal:
131134
shape = (2, ) + shape
132135
else:
133136
shape = (1, ) + shape
134137
f, i, j = (Index(e) for e in shape)
135-
base, (f, i) = self.indexed((n, i, f), layer=layer, permute=permute)
138+
base, (f, i) = self.indexed((n, i, f), layer=layer)
136139
init = Sum(Product(base, Literal(numpy.int32(j.extent))), j)
137140
pack = Materialise(PackInst(), init, MultiIndex(f, i, j))
138141
multiindex = tuple(Index(e) for e in pack.shape)
@@ -168,13 +171,8 @@ def __init__(self, map_, permutation):
168171
self.offset_quotient = map_.offset_quotient
169172
self.permutation = NamedLiteral(permutation, parent=self.values, suffix=f"permutation{count}")
170173

171-
def indexed(self, multiindex, layer=None):
172-
permute = lambda x: Indexed(self.permutation, (x,))
173-
return super().indexed(multiindex, layer=layer, permute=permute)
174-
175-
def indexed_vector(self, n, shape, layer=None):
176-
permute = lambda x: Indexed(self.permutation, (x,))
177-
return super().indexed_vector(n, shape, layer=layer, permute=permute)
174+
def _permute(self, x):
175+
return Indexed(self.permutation, (x,))
178176

179177

180178
class CMap(Map):

0 commit comments

Comments
 (0)