Skip to content

Commit f3fac54

Browse files
committed
Various BC tests passing
1 parent 36d8f7f commit f3fac54

File tree

2 files changed

+17
-56
lines changed

2 files changed

+17
-56
lines changed

firedrake/bcs.py

+16-55
Original file line numberDiff line numberDiff line change
@@ -99,49 +99,18 @@ def _indices(self):
9999
break
100100
return tuple(reversed(indices))
101101

102-
@utils.cached_property
103-
def nodes(self):
104-
'''The list of nodes at which this boundary condition applies.'''
105-
106-
def hermite_stride(bcnodes):
107-
if isinstance(self._function_space.finat_element, finat.Hermite) and \
108-
self._function_space.mesh().topological_dimension() == 1:
109-
return bcnodes[::2] # every second dof is the vertex value
110-
else:
111-
return bcnodes
112-
113-
sub_d = (self.sub_domain, ) if isinstance(self.sub_domain, str) else as_tuple(self.sub_domain)
114-
sub_d = [s if isinstance(s, str) else as_tuple(s) for s in sub_d]
115-
bcnodes = []
116-
for s in sub_d:
117-
if isinstance(s, str):
118-
bcnodes.append(hermite_stride(self._function_space.boundary_nodes(s)))
119-
else:
120-
# s is of one of the following formats:
121-
# facet: (i, )
122-
# edge: (i, j)
123-
# vertex: (i, j, k)
124-
# take intersection of facet nodes, and add it to bcnodes
125-
# i, j, k can also be strings.
126-
bcnodes1 = []
127-
if len(s) > 1 and not isinstance(self._function_space.finat_element, (finat.Lagrange, finat.GaussLobattoLegendre)):
128-
raise TypeError("Currently, edge conditions have only been tested with CG Lagrange elements")
129-
for ss in s:
130-
# intersection of facets
131-
# Edge conditions have only been tested with Lagrange elements.
132-
# Need to expand the list.
133-
bcnodes1.append(hermite_stride(self._function_space.boundary_nodes(ss)))
134-
bcnodes1 = functools.reduce(np.intersect1d, bcnodes1)
135-
bcnodes.append(bcnodes1)
136-
return np.concatenate(bcnodes)
137-
138102
@utils.cached_property
139103
def constrained_points(self):
140104
"""Return the subset of mesh points constrained by the boundary condition."""
141105
# NOTE: This returns facets, whose closure is then used when applying the BC
142106
mesh = self._function_space.mesh().topology
143107
tdim = mesh.dimension
144108

109+
# 1D Hermite elements have strange vertex properties, we only want every
110+
# other entry
111+
if isinstance(self._function_space.finat_element, finat.Hermite) and tdim == 1:
112+
raise NotImplementedError("TODO, need to have inner slice with stride 2")
113+
145114
subset_data_per_dim = {
146115
dim: [] for dim in range(tdim + 1)
147116
}
@@ -166,7 +135,9 @@ def constrained_points(self):
166135
)
167136

168137
if len(subdomain_id) > 1:
169-
raise NotImplementedError("TODO pyop3")
138+
raise NotImplementedError(
139+
"TODO pyop3, need to intersect (see previous `nodes` method)"
140+
)
170141

171142
subsets = mesh.subdomain_points(subdomain_id)
172143
for dim, subset_data in subset_data_per_dim.items():
@@ -181,18 +152,11 @@ def constrained_points(self):
181152
for dim, data in flat_subset_data.items():
182153
point_label = str(dim)
183154
n, = data.shape
184-
array = op3.HierarchicalArray(op3.Axis(n), data=data, prefix="subset")
155+
array = op3.HierarchicalArray(op3.Axis(n), data=data, prefix="subset", dtype=utils.IntType)
185156
subset = op3.Subset(point_label, array)
186157
subsets.append(subset)
187158
return op3.Slice(mesh.points.label, subsets)
188159

189-
# @utils.cached_property
190-
# def node_set(self):
191-
# '''The subset corresponding to the nodes at which this
192-
# boundary condition applies.'''
193-
#
194-
# return self._function_space.axes[self.nodes]
195-
196160
@PETSc.Log.EventDecorator()
197161
def zero(self, r):
198162
r"""Zero the boundary condition nodes on ``r``.
@@ -210,12 +174,7 @@ def zero(self, r):
210174
# TODO raise an exception if spaces are not compatible
211175
# raise RuntimeError(f"{r} defined on incompatible FunctionSpace")
212176

213-
mesh = self._function_space.mesh().topology
214-
op3.do_loop(
215-
p := mesh.points[self.constrained_points].index(),
216-
r.dat[p].assign(0),
217-
)
218-
# r.dat.zero(subset=self.node_set)
177+
r.dat.eager_zero(subset=self.constrained_points)
219178

220179
@PETSc.Log.EventDecorator()
221180
def set(self, r, val):
@@ -226,10 +185,12 @@ def set(self, r, val):
226185

227186
for idx in self._indices:
228187
r = r.sub(idx)
229-
if not np.isscalar(val):
188+
if isinstance(val, firedrake.Cofunction):
230189
for idx in self._indices:
231190
val = val.sub(idx)
232-
r.assign(val, subset=self.node_set)
191+
else:
192+
assert np.isscalar(val)
193+
r.assign(val, subset=self.constrained_points)
233194

234195
def integrals(self):
235196
raise NotImplementedError("integrals() method has to be overwritten")
@@ -460,9 +421,9 @@ def apply(self, r, u=None):
460421
if u:
461422
u = u.sub(idx)
462423
if u:
463-
r.assign(u - self.function_arg, subset=self.node_set)
424+
r.assign(u - self.function_arg, subset=self.constrained_points)
464425
else:
465-
r.assign(self.function_arg, subset=self.node_set)
426+
r.assign(self.function_arg, subset=self.constrained_points)
466427

467428
def integrals(self):
468429
return []

firedrake/constant.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def _create_const(value, comm):
3232
for size in shape[1:]:
3333
axes = axes.add_subaxis(op3.Axis(size), *axes.leaf)
3434
axes = axes.set_up()
35-
dat = op3.HierarchicalArray(axes, data=data)
35+
dat = op3.HierarchicalArray(axes, data=data.flatten())
3636
return dat, rank, shape
3737

3838

0 commit comments

Comments
 (0)