Skip to content

Commit afa95cf

Browse files
authored
Merge pull request #1611 from CEED/jrwrigh/project_mixed_tensor
basis: Allow CreateProjection for mixed-tensor bases
2 parents 3ff7c56 + e104ad1 commit afa95cf

File tree

2 files changed

+40
-19
lines changed

2 files changed

+40
-19
lines changed

interface/ceed-basis.c

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ static int CeedScalarView(const char *name, const char *fp_fmt, CeedInt m, CeedI
195195
**/
196196
static int CeedBasisCreateProjectionMatrices(CeedBasis basis_from, CeedBasis basis_to, CeedScalar **interp_project, CeedScalar **grad_project) {
197197
Ceed ceed;
198-
bool is_tensor_to, is_tensor_from;
198+
bool are_both_tensor;
199199
CeedInt Q, Q_to, Q_from, P_to, P_from;
200200

201201
CeedCall(CeedBasisGetCeed(basis_to, &ceed));
@@ -207,10 +207,14 @@ static int CeedBasisCreateProjectionMatrices(CeedBasis basis_from, CeedBasis bas
207207
Q = Q_to;
208208

209209
// Check for matching tensor or non-tensor
210-
CeedCall(CeedBasisIsTensor(basis_to, &is_tensor_to));
211-
CeedCall(CeedBasisIsTensor(basis_from, &is_tensor_from));
212-
CeedCheck(is_tensor_to == is_tensor_from, ceed, CEED_ERROR_MINOR, "Bases must both be tensor or non-tensor");
213-
if (is_tensor_to) {
210+
{
211+
bool is_tensor_to, is_tensor_from;
212+
213+
CeedCall(CeedBasisIsTensor(basis_to, &is_tensor_to));
214+
CeedCall(CeedBasisIsTensor(basis_from, &is_tensor_from));
215+
are_both_tensor = is_tensor_to && is_tensor_from;
216+
}
217+
if (are_both_tensor) {
214218
CeedCall(CeedBasisGetNumNodes1D(basis_to, &P_to));
215219
CeedCall(CeedBasisGetNumNodes1D(basis_from, &P_from));
216220
CeedCall(CeedBasisGetNumQuadraturePoints1D(basis_from, &Q));
@@ -231,7 +235,7 @@ static int CeedBasisCreateProjectionMatrices(CeedBasis basis_from, CeedBasis bas
231235
const CeedScalar *interp_to_source = NULL, *interp_from_source = NULL, *grad_from_source = NULL;
232236

233237
CeedCall(CeedBasisGetDimension(basis_to, &dim));
234-
if (is_tensor_to) {
238+
if (are_both_tensor) {
235239
CeedCall(CeedBasisGetInterp1D(basis_to, &interp_to_source));
236240
CeedCall(CeedBasisGetInterp1D(basis_from, &interp_from_source));
237241
} else {
@@ -246,19 +250,19 @@ static int CeedBasisCreateProjectionMatrices(CeedBasis basis_from, CeedBasis bas
246250
// projection basis will have a gradient operation (allocated even if not H^1 for the
247251
// basis construction later on)
248252
if (fe_space_to == CEED_FE_SPACE_H1) {
249-
if (is_tensor_to) {
253+
if (are_both_tensor) {
250254
CeedCall(CeedBasisGetGrad1D(basis_from, &grad_from_source));
251255
} else {
252256
CeedCall(CeedBasisGetGrad(basis_from, &grad_from_source));
253257
}
254258
}
255-
CeedCall(CeedCalloc(P_to * P_from * (is_tensor_to ? 1 : dim), grad_project));
259+
CeedCall(CeedCalloc(P_to * P_from * (are_both_tensor ? 1 : dim), grad_project));
256260

257261
// Compute interp_to^+, pseudoinverse of interp_to
258262
CeedCall(CeedCalloc(Q * q_comp * P_to, &interp_to_inv));
259263
CeedCall(CeedMatrixPseudoinverse(ceed, interp_to_source, Q * q_comp, P_to, interp_to_inv));
260264
// Build matrices
261-
CeedInt num_matrices = 1 + (fe_space_to == CEED_FE_SPACE_H1) * (is_tensor_to ? 1 : dim);
265+
CeedInt num_matrices = 1 + (fe_space_to == CEED_FE_SPACE_H1) * (are_both_tensor ? 1 : dim);
262266
CeedScalar *input_from[num_matrices], *output_project[num_matrices];
263267

264268
input_from[0] = (CeedScalar *)interp_from_source;
@@ -1322,6 +1326,8 @@ int CeedBasisCreateHcurl(Ceed ceed, CeedElemTopology topo, CeedInt num_comp, Cee
13221326
Note: `basis_project` will have the same number of components as `basis_from`, regardless of the number of components that `basis_to` has.
13231327
If `basis_from` has 3 components and `basis_to` has 5 components, then `basis_project` will have 3 components.
13241328
1329+
Note: If either `basis_from` or `basis_to` are non-tensor, then `basis_project` will also be non-tensor
1330+
13251331
@param[in] basis_from `CeedBasis` to prolong from
13261332
@param[in] basis_to `CeedBasis` to prolong to
13271333
@param[out] basis_project Address of the variable where the newly created `CeedBasis` will be stored
@@ -1332,7 +1338,7 @@ int CeedBasisCreateHcurl(Ceed ceed, CeedElemTopology topo, CeedInt num_comp, Cee
13321338
**/
13331339
int CeedBasisCreateProjection(CeedBasis basis_from, CeedBasis basis_to, CeedBasis *basis_project) {
13341340
Ceed ceed;
1335-
bool is_tensor;
1341+
bool create_tensor;
13361342
CeedInt dim, num_comp;
13371343
CeedScalar *interp_project, *grad_project;
13381344

@@ -1342,10 +1348,16 @@ int CeedBasisCreateProjection(CeedBasis basis_from, CeedBasis basis_to, CeedBasi
13421348
CeedCall(CeedBasisCreateProjectionMatrices(basis_from, basis_to, &interp_project, &grad_project));
13431349

13441350
// Build basis
1345-
CeedCall(CeedBasisIsTensor(basis_to, &is_tensor));
1351+
{
1352+
bool is_tensor_to, is_tensor_from;
1353+
1354+
CeedCall(CeedBasisIsTensor(basis_to, &is_tensor_to));
1355+
CeedCall(CeedBasisIsTensor(basis_from, &is_tensor_from));
1356+
create_tensor = is_tensor_from && is_tensor_to;
1357+
}
13461358
CeedCall(CeedBasisGetDimension(basis_to, &dim));
13471359
CeedCall(CeedBasisGetNumComponents(basis_from, &num_comp));
1348-
if (is_tensor) {
1360+
if (create_tensor) {
13491361
CeedInt P_1d_to, P_1d_from;
13501362

13511363
CeedCall(CeedBasisGetNumNodes1D(basis_from, &P_1d_from));

tests/t319-basis.c

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,9 @@ int main(int argc, char **argv) {
150150

151151
VerifyProjectedBasis(basis_project, dim, p_to_dim, p_from_dim, x_to, x_from, u_to, u_from, du_to);
152152

153-
// Test projection on non-tensor bases
153+
// Create non-tensor bases
154+
CeedBasis basis_from_nontensor, basis_to_nontensor;
154155
{
155-
CeedBasis basis_from_nontensor, basis_to_nontensor;
156156
CeedElemTopology topo;
157157
CeedInt num_comp, num_nodes, nqpts;
158158
const CeedScalar *interp, *grad;
@@ -172,14 +172,21 @@ int main(int argc, char **argv) {
172172
CeedBasisGetInterp(basis_to, &interp);
173173
CeedBasisGetGrad(basis_to, &grad);
174174
CeedBasisCreateH1(ceed, topo, num_comp, num_nodes, nqpts, interp, grad, NULL, NULL, &basis_to_nontensor);
175+
}
175176

176-
CeedBasisDestroy(&basis_project);
177-
CeedBasisCreateProjection(basis_from_nontensor, basis_to_nontensor, &basis_project);
177+
// Test projection on non-tensor bases
178+
CeedBasisDestroy(&basis_project);
179+
CeedBasisCreateProjection(basis_from_nontensor, basis_to_nontensor, &basis_project);
180+
VerifyProjectedBasis(basis_project, dim, p_to_dim, p_from_dim, x_to, x_from, u_to, u_from, du_to);
178181

179-
CeedBasisDestroy(&basis_to_nontensor);
180-
CeedBasisDestroy(&basis_from_nontensor);
181-
}
182+
// Test projection from non-tensor to tensor
183+
CeedBasisDestroy(&basis_project);
184+
CeedBasisCreateProjection(basis_from_nontensor, basis_to, &basis_project);
185+
VerifyProjectedBasis(basis_project, dim, p_to_dim, p_from_dim, x_to, x_from, u_to, u_from, du_to);
182186

187+
// Test projection from tensor to non-tensor
188+
CeedBasisDestroy(&basis_project);
189+
CeedBasisCreateProjection(basis_from, basis_to_nontensor, &basis_project);
183190
VerifyProjectedBasis(basis_project, dim, p_to_dim, p_from_dim, x_to, x_from, u_to, u_from, du_to);
184191

185192
CeedVectorDestroy(&x_corners);
@@ -189,7 +196,9 @@ int main(int argc, char **argv) {
189196
CeedVectorDestroy(&u_to);
190197
CeedVectorDestroy(&du_to);
191198
CeedBasisDestroy(&basis_from);
199+
CeedBasisDestroy(&basis_from_nontensor);
192200
CeedBasisDestroy(&basis_to);
201+
CeedBasisDestroy(&basis_to_nontensor);
193202
CeedBasisDestroy(&basis_project);
194203
}
195204
CeedDestroy(&ceed);

0 commit comments

Comments
 (0)