Skip to content

Commit f1283eb

Browse files
SvenDS9facebook-github-bot
authored andcommitted
CI test: check if documentation is complete (#1070)
Summary: Fixes #1047. Basic, but should do the trick. Not sure about the GitHub Actions part (never used this before). ### Changes - Add CI test that checks whether all exported DataPipes (dps) have documentation. Pull Request resolved: #1070 Reviewed By: NivekT Differential Revision: D43876159 Pulled By: ejguan fbshipit-source-id: 3b77416ca8f082cff90abec981aa00ddc4232db8
1 parent 431ca6f commit f1283eb

File tree

4 files changed

+90
-1
lines changed

4 files changed

+90
-1
lines changed

.github/scripts/check_complete_doc.py

+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import os
8+
import sys
9+
10+
11+
def collect_init_dps(init_file_location):
    """Collect the names listed in ``__all__`` of a Python source file.

    The file is parsed with :mod:`ast` instead of being scanned line by line,
    so single-line lists, single-quoted names, and comments or blank lines
    inside the list are all handled.

    Args:
        init_file_location: Path of the ``__init__.py`` file to inspect.

    Returns:
        A set with the elements of the first module-level ``__all__``
        assignment, or an empty set when the file has no such assignment.

    Raises:
        OSError: If the file cannot be opened.
        SyntaxError: If the file is not valid Python.
        ValueError: If ``__all__`` is assigned something other than a literal.
    """
    # Function-scope import so the block does not depend on a file-level one.
    import ast

    with open(init_file_location, encoding="utf-8") as init_file:
        tree = ast.parse(init_file.read(), filename=str(init_file_location))

    for node in tree.body:
        if not isinstance(node, ast.Assign):
            continue
        assigns_all = any(
            isinstance(assign_target, ast.Name) and assign_target.id == "__all__"
            for assign_target in node.targets
        )
        if assigns_all:
            # Only the first assignment is considered, like the original scan.
            return set(ast.literal_eval(node.value))
    return set()
20+
21+
22+
def collect_rst_dps(rst_file_location):
    """Collect the entries listed after each template line of an rst file.

    A listing starts after any line that mentions ``class_template.rst`` or
    ``function.rst``. Option lines (starting with ``:``) and blank lines
    between that line and the first entry are skipped; the listing then runs
    until the next blank line or the end of the file.

    Args:
        rst_file_location: Path of the ``.rst`` file to inspect.

    Returns:
        A set with the whitespace-stripped entry lines of every listing.

    Raises:
        OSError: If the file cannot be opened.
    """
    rst_dps = set()
    with open(rst_file_location, encoding="utf-8") as rst_file:
        lines = iter(rst_file)
        for line in lines:
            if "class_template.rst" not in line and "function.rst" not in line:
                continue
            # The inner loop shares the iterator with the outer one, so lines
            # consumed as entries are never re-examined as template lines.
            entries_started = False
            for raw_entry in lines:
                entry = raw_entry.strip()
                if not entry:
                    if entries_started:
                        break  # a blank line closes the listing
                    continue  # blank line(s) separating options from entries
                if not entries_started and entry.startswith(":"):
                    continue  # another option line after the template option
                entries_started = True
                rst_dps.add(entry)
    return rst_dps
31+
32+
33+
def compare_sets(set_a, set_b, ignore_set=None):
    """Return the elements of ``set_a`` that are missing from ``set_b``.

    Args:
        set_a: Iterable of hashable elements to start from.
        set_b: Iterable of elements to remove from ``set_a``.
        ignore_set: Optional iterable of further elements to leave out of
            the result.

    Returns:
        A new ``set``; none of the arguments is modified.
    """
    # Copy into a plain set first: ``frozenset.difference`` returns a
    # frozenset, which has no ``difference_update``.
    res = set(set_a)
    res.difference_update(set_b)
    if ignore_set is not None:
        res.difference_update(ignore_set)
    return res
38+
39+
40+
def main():
    """Cross-check exported names against the rst files and exit.

    For each of the ``iter``, ``map`` and ``utils`` subfolders of
    ``torchdata/datapipes``, the names collected from ``__init__.py`` are
    compared with the entries collected from
    ``docs/source/torchdata.datapipes.<target>.rst``. Every mismatch is
    printed, and the process exits with status 1 if any was found, 0 otherwise.
    Paths are relative to the current working directory.
    """
    datapipes_folder = os.path.join("torchdata", "datapipes")
    init_file = "__init__.py"
    docs_source_folder = os.path.join("docs", "source")
    # Per target: exported names that are not required to appear in the rst.
    ignored_by_target = {
        "iter": {"IterDataPipe", "Extractor"},
        "map": {"MapDataPipe"},
        "utils": set(),
    }
    exit_code = 0

    for target, ignore_set in ignored_by_target.items():
        init_path = os.path.join(datapipes_folder, target, init_file)
        rst_path = os.path.join(docs_source_folder, "torchdata.datapipes." + target + ".rst")

        init_set = collect_init_dps(init_path)
        rst_set = collect_rst_dps(rst_path)

        dif_init = compare_sets(init_set, rst_set, ignore_set)
        dif_rst = compare_sets(rst_set, init_set)

        # Sorted so that the report is identical from run to run.
        for elem in sorted(dif_init):
            print(f"Please add {elem} to {rst_path}")
            exit_code = 1
        for elem in sorted(dif_rst):
            print(f"{elem} is present in {rst_path} but not in {init_path}")
            exit_code = 1

    sys.exit(exit_code)


if __name__ == "__main__":
    main()

.github/workflows/lint.yml

+13
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,16 @@ jobs:
7373
echo 'Please fix the above mypy warnings.'
7474
false
7575
fi
76+
77+
complete_documentation:
78+
if: ${{ github.repository_owner == 'pytorch' }}
79+
runs-on: ubuntu-latest
80+
steps:
81+
- name: Setup Python
82+
uses: actions/setup-python@v4
83+
with:
84+
python-version: "3.8"
85+
- name: Check out source repository
86+
uses: actions/checkout@v3
87+
- name: Check if documentation is complete
88+
run: python ./.github/scripts/check_complete_doc.py

docs/source/torchdata.datapipes.utils.rst

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ Common Utility Functions
2424
:toctree: generated/
2525
:template: function.rst
2626

27+
janitor
2728
pin_memory_fn
2829

2930

torchdata/datapipes/utils/__init__.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,12 @@
1010
from torchdata.datapipes.utils.janitor import janitor
1111
from torchdata.datapipes.utils.pin_memory import pin_memory_fn
1212

13-
__all__ = ["StreamWrapper", "janitor", "pin_memory_fn", "to_graph"]
13+
__all__ = [
14+
"StreamWrapper",
15+
"janitor",
16+
"pin_memory_fn",
17+
"to_graph",
18+
]
19+
20+
# Please keep this list sorted
21+
assert __all__ == sorted(__all__)

0 commit comments

Comments
 (0)