forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path wrap_headers.py
More file actions
61 lines (45 loc) · 1.76 KB
/
wrap_headers.py
File metadata and controls
61 lines (45 loc) · 1.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""Wrap installed headers with TORCH_STABLE_ONLY / TORCH_TARGET_VERSION guards.
Headers under the include directory are wrapped so they emit a compile error
when included with TORCH_STABLE_ONLY or TORCH_TARGET_VERSION defined.
Certain directories (stable API, headeronly, AOTI shims) are excluded.
Called at install time by cmake/PostBuildSteps.cmake.
"""
import argparse
import pathlib
# File suffixes (as reported by pathlib.Path.suffix) that are treated as headers.
HEADER_EXTENSIONS = (".h", ".hpp", ".cuh")
# Include-dir-relative POSIX path prefixes that are never wrapped; headers under
# these directories are meant to stay usable with TORCH_STABLE_ONLY /
# TORCH_TARGET_VERSION defined.
EXCLUDE_PATTERNS = (
    "torch/headeronly/",
    "torch/csrc/stable/",
    "torch/csrc/inductor/aoti_torch/c/",
    "torch/csrc/inductor/aoti_torch/generated/",
)
# First line written into every wrapped header. Also used to recognise headers
# that were already wrapped, so re-running the script leaves them untouched.
WRAP_MARKER = "#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)"

# Text appended after the original header contents to close WRAP_MARKER's #if.
_WRAP_FOOTER = (
    "#else\n"
    '#error "This file should not be included when either '
    'TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."\n'
    "#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n"
)


def _wrap_header(header: pathlib.Path) -> None:
    """Wrap one header file in place unless it already starts with WRAP_MARKER."""
    # surrogateescape lets headers that are not valid UTF-8 round-trip
    # byte-for-byte instead of aborting the install with UnicodeDecodeError;
    # valid UTF-8 content is read and written exactly as before.
    content = header.read_text(encoding="utf-8", errors="surrogateescape")
    if content.startswith(WRAP_MARKER):
        return
    # The newline after `content` guarantees `#else` starts on its own line
    # even when the header has no trailing newline.
    header.write_text(
        f"{WRAP_MARKER}\n{content}\n{_WRAP_FOOTER}",
        encoding="utf-8",
        errors="surrogateescape",
    )


def main() -> None:
    """Wrap every non-excluded header under the include dir given on argv.

    Takes one positional command-line argument, the installed include
    directory. Returns without doing anything if that directory does not
    exist. Files whose suffix is in HEADER_EXTENSIONS and whose relative path
    does not start with one of EXCLUDE_PATTERNS are rewritten in place (see
    _wrap_header). Running the script again skips headers that are already
    wrapped.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "include_dir", type=pathlib.Path, help="Installed include directory"
    )
    args = parser.parse_args()
    include_dir = args.include_dir
    if not include_dir.is_dir():
        return
    # Materialise (and sort, for a deterministic order) the listing before
    # rewriting any files.
    for header in sorted(include_dir.rglob("*")):
        # rglob also yields directories and dangling symlinks; only real files
        # can be read and rewritten.
        if header.suffix not in HEADER_EXTENSIONS or not header.is_file():
            continue
        rel = header.relative_to(include_dir).as_posix()
        if rel.startswith(EXCLUDE_PATTERNS):
            continue
        _wrap_header(header)
# Run only when executed as a script (the module docstring says it is invoked
# at install time by cmake/PostBuildSteps.cmake), not when imported.
if __name__ == "__main__":
    main()