Skip to content

Commit 0626555

Browse files
fix python and format
1 parent 2d49fec commit 0626555

File tree

3 files changed

+31
-15
lines changed

3 files changed

+31
-15
lines changed

include/ceed/jit-source/cuda/cuda-jit.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,8 @@
1313
#define CeedPragmaSIMD
1414
#define CEED_Q_VLA 1
1515

16-
#define CEED_QFUNCTION_RUST(name) \
16+
#define CEED_QFUNCTION_RUST(name) \
1717
extern "C" __device__ int name##_rs(void *ctx, const CeedInt Q, const CeedScalar *const *in, CeedScalar *const *out); \
18-
static __device__ int name(void *ctx, const CeedInt Q, const CeedScalar *const *in, CeedScalar *const *out) { \
19-
return name##_rs(ctx, Q, in, out); \
20-
}
18+
static __device__ int name(void *ctx, const CeedInt Q, const CeedScalar *const *in, CeedScalar *const *out) { return name##_rs(ctx, Q, in, out); }
2119

2220
#include "cuda-types.h"

include/ceed/types.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,11 @@ pub unsafe extern "C" fn build_mass_rs(
7373
static const char name##_loc[] = __FILE__ ":" #name; \
7474
CEED_QFUNCTION_ATTR int name##_rs(void *ctx, const CeedInt Q, const CeedScalar *const *in, CeedScalar *const *out); \
7575
CEED_QFUNCTION_ATTR static int name(void *ctx, const CeedInt Q, const CeedScalar *const *in, CeedScalar *const *out) { \
76-
return name##_rs(ctx, Q, in, out);}
76+
return name##_rs(ctx, Q, in, out); \
77+
} \
78+
7779
#endif
80+
// Note: the empty line at the end of the macro is required because python cffi will exclude the previous line (the }) based on the backslash at the end of it, which is required for our python build script to exclude macros
7881

7982
/**
8083
@ingroup CeedQFunction

python/build_ceed_cffi.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,28 @@
1313
ceed_version_ge = re.compile(r'\s+\(!?CEED_VERSION.*')
1414

1515

16+
# Checks to see if a c line is part of the lines we have to exclude (macros)
17+
def is_valid_line(line):
18+
if (line.startswith("#") and not line.startswith("#include")):
19+
return False
20+
if (line.startswith(" static")):
21+
return False
22+
if (line.startswith(" CEED_QFUNCTION_ATTR")):
23+
return False
24+
if (line.startswith(" return name##_rs")):
25+
return False
26+
if (line.endswith('\\\n')):
27+
return False
28+
if ("CeedErrorImpl" in line):
29+
return False
30+
if (r'const char *, ...);' in line):
31+
return False
32+
if (line.startswith("CEED_EXTERN const char *const")):
33+
return False
34+
if (ceed_version_ge.match(line)):
35+
return False
36+
return True
37+
1638
def get_ceed_dirs():
1739
here = os.path.dirname(os.path.abspath(__file__))
1840
prefix = os.path.dirname(here)
@@ -31,17 +53,10 @@ def get_ceed_dirs():
3153
lines = []
3254
for header_path in ["include/ceed/types.h", "include/ceed/ceed.h"]:
3355
with open(os.path.abspath(header_path)) as f:
34-
lines += [line.strip() for line in f if
35-
not (line.startswith("#") and not line.startswith("#include")) and
36-
not line.startswith(" static") and
37-
not line.startswith(" CEED_QFUNCTION_ATTR") and
38-
not line.startswith(" return name##_rs") and
39-
"CeedErrorImpl" not in line and
40-
"const char *, ...);" not in line and
41-
not line.startswith("CEED_EXTERN const char *const") and
42-
not ceed_version_ge.match(line)]
56+
lines += [line.strip() for line in f if is_valid_line(line)]
4357
lines = [line.replace("CEED_EXTERN", "extern") for line in lines]
44-
58+
print(lines)
59+
#breakpoint()
4560
# Find scalar type inclusion line and insert definitions
4661
for line in lines:
4762
if re.search("ceed-f32.h", line) is not None:

0 commit comments

Comments
 (0)