@@ -99,49 +99,18 @@ def _indices(self):
99
99
break
100
100
return tuple (reversed (indices ))
101
101
102
- @utils .cached_property
103
- def nodes (self ):
104
- '''The list of nodes at which this boundary condition applies.'''
105
-
106
- def hermite_stride (bcnodes ):
107
- if isinstance (self ._function_space .finat_element , finat .Hermite ) and \
108
- self ._function_space .mesh ().topological_dimension () == 1 :
109
- return bcnodes [::2 ] # every second dof is the vertex value
110
- else :
111
- return bcnodes
112
-
113
- sub_d = (self .sub_domain , ) if isinstance (self .sub_domain , str ) else as_tuple (self .sub_domain )
114
- sub_d = [s if isinstance (s , str ) else as_tuple (s ) for s in sub_d ]
115
- bcnodes = []
116
- for s in sub_d :
117
- if isinstance (s , str ):
118
- bcnodes .append (hermite_stride (self ._function_space .boundary_nodes (s )))
119
- else :
120
- # s is of one of the following formats:
121
- # facet: (i, )
122
- # edge: (i, j)
123
- # vertex: (i, j, k)
124
- # take intersection of facet nodes, and add it to bcnodes
125
- # i, j, k can also be strings.
126
- bcnodes1 = []
127
- if len (s ) > 1 and not isinstance (self ._function_space .finat_element , (finat .Lagrange , finat .GaussLobattoLegendre )):
128
- raise TypeError ("Currently, edge conditions have only been tested with CG Lagrange elements" )
129
- for ss in s :
130
- # intersection of facets
131
- # Edge conditions have only been tested with Lagrange elements.
132
- # Need to expand the list.
133
- bcnodes1 .append (hermite_stride (self ._function_space .boundary_nodes (ss )))
134
- bcnodes1 = functools .reduce (np .intersect1d , bcnodes1 )
135
- bcnodes .append (bcnodes1 )
136
- return np .concatenate (bcnodes )
137
-
138
102
@utils .cached_property
139
103
def constrained_points (self ):
140
104
"""Return the subset of mesh points constrained by the boundary condition."""
141
105
# NOTE: This returns facets, whose closure is then used when applying the BC
142
106
mesh = self ._function_space .mesh ().topology
143
107
tdim = mesh .dimension
144
108
109
+ # 1D Hermite elements have strange vertex properties, we only want every
110
+ # other entry
111
+ if isinstance (self ._function_space .finat_element , finat .Hermite ) and tdim == 1 :
112
+ raise NotImplementedError ("TODO, need to have inner slice with stride 2" )
113
+
145
114
subset_data_per_dim = {
146
115
dim : [] for dim in range (tdim + 1 )
147
116
}
@@ -166,7 +135,9 @@ def constrained_points(self):
166
135
)
167
136
168
137
if len (subdomain_id ) > 1 :
169
- raise NotImplementedError ("TODO pyop3" )
138
+ raise NotImplementedError (
139
+ "TODO pyop3, need to intersect (see previous `nodes` method)"
140
+ )
170
141
171
142
subsets = mesh .subdomain_points (subdomain_id )
172
143
for dim , subset_data in subset_data_per_dim .items ():
@@ -181,18 +152,11 @@ def constrained_points(self):
181
152
for dim , data in flat_subset_data .items ():
182
153
point_label = str (dim )
183
154
n , = data .shape
184
- array = op3 .HierarchicalArray (op3 .Axis (n ), data = data , prefix = "subset" )
155
+ array = op3 .HierarchicalArray (op3 .Axis (n ), data = data , prefix = "subset" , dtype = utils . IntType )
185
156
subset = op3 .Subset (point_label , array )
186
157
subsets .append (subset )
187
158
return op3 .Slice (mesh .points .label , subsets )
188
159
189
- # @utils.cached_property
190
- # def node_set(self):
191
- # '''The subset corresponding to the nodes at which this
192
- # boundary condition applies.'''
193
- #
194
- # return self._function_space.axes[self.nodes]
195
-
196
160
@PETSc .Log .EventDecorator ()
197
161
def zero (self , r ):
198
162
r"""Zero the boundary condition nodes on ``r``.
@@ -210,12 +174,7 @@ def zero(self, r):
210
174
# TODO raise an exception if spaces are not compatible
211
175
# raise RuntimeError(f"{r} defined on incompatible FunctionSpace")
212
176
213
- mesh = self ._function_space .mesh ().topology
214
- op3 .do_loop (
215
- p := mesh .points [self .constrained_points ].index (),
216
- r .dat [p ].assign (0 ),
217
- )
218
- # r.dat.zero(subset=self.node_set)
177
+ r .dat .eager_zero (subset = self .constrained_points )
219
178
220
179
@PETSc .Log .EventDecorator ()
221
180
def set (self , r , val ):
@@ -226,10 +185,12 @@ def set(self, r, val):
226
185
227
186
for idx in self ._indices :
228
187
r = r .sub (idx )
229
- if not np . isscalar (val ):
188
+ if isinstance (val , firedrake . Cofunction ):
230
189
for idx in self ._indices :
231
190
val = val .sub (idx )
232
- r .assign (val , subset = self .node_set )
191
+ else :
192
+ assert np .isscalar (val )
193
+ r .assign (val , subset = self .constrained_points )
233
194
234
195
def integrals (self ):
235
196
raise NotImplementedError ("integrals() method has to be overwritten" )
@@ -460,9 +421,9 @@ def apply(self, r, u=None):
460
421
if u :
461
422
u = u .sub (idx )
462
423
if u :
463
- r .assign (u - self .function_arg , subset = self .node_set )
424
+ r .assign (u - self .function_arg , subset = self .constrained_points )
464
425
else :
465
- r .assign (self .function_arg , subset = self .node_set )
426
+ r .assign (self .function_arg , subset = self .constrained_points )
466
427
467
428
def integrals (self ):
468
429
return []
0 commit comments