|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import os |
| 8 | +import sys |
| 9 | + |
| 10 | + |
def collect_init_dps(init_file_location):
    """Collect the names listed in the ``__all__`` block of a Python file.

    The file is scanned for the first line starting with ``__all__ ``. The
    quoted names on the lines that follow are gathered until a line that does
    not start with a quote (for example the closing bracket) is reached.
    Names written on the ``__all__`` line itself are not picked up.

    Args:
        init_file_location: Path of the file to scan.

    Returns:
        A set with the listed names, stripped of quotes and commas. The set
        is empty when the file contains no ``__all__ `` line.
    """
    init_dps = set()
    # Python source is UTF-8 by default, so do not depend on the locale.
    with open(init_file_location, encoding="utf-8") as init_file:
        for line in init_file:
            if not line.startswith("__all__ "):
                continue
            # Continue on the same file iterator: these are the entry lines.
            for entry_line in init_file:
                # Drop a trailing comment so it cannot leak into a name.
                entry = entry_line.split("#", 1)[0].strip()
                if not entry.startswith(('"', "'")):
                    break
                # One line may carry several comma-separated names.
                for token in entry.split(","):
                    name = token.strip().strip("\"'")
                    if name:
                        init_dps.add(name)
            # Only the first ``__all__`` block is considered.
            break
    return init_dps
| 20 | + |
| 21 | + |
def collect_rst_dps(rst_file_location):
    """Collect the entry names listed under template directives in an RST file.

    Every line mentioning ``class_template.rst`` or ``function.rst`` starts a
    listing: the line right after it is skipped, and each following line is
    recorded (stripped of surrounding whitespace) until the end of the file
    or a line whose stripped text is at most one character long.

    Args:
        rst_file_location: Path of the RST file to scan.

    Returns:
        A set with the stripped text of every listed line.
    """
    templates = ("class_template.rst", "function.rst")
    rst_dps = set()
    # Read as UTF-8 explicitly so the result does not depend on the locale.
    with open(rst_file_location, encoding="utf-8") as rst_file:
        for line in rst_file:
            if not any(template in line for template in templates):
                continue
            # The line after the template line is skipped unconditionally
            # (presumably the blank separator before the entries).
            next(rst_file, None)
            for entry_line in rst_file:
                name = entry_line.strip()
                # A blank (or single-character) line closes the listing.
                if len(name) <= 1:
                    break
                rst_dps.add(name)
    return rst_dps
| 31 | + |
| 32 | + |
def compare_sets(set_a, set_b, ignore_set=None):
    """Return the elements of ``set_a`` that are absent from ``set_b``.

    Args:
        set_a: Set whose elements are checked.
        set_b: Collection of elements to exclude.
        ignore_set: Optional extra collection of elements to drop from the
            result; skipped entirely when ``None``.

    Returns:
        A new set; ``set_a`` itself is left unmodified.
    """
    missing = set_a.difference(set_b)
    if ignore_set is None:
        return missing
    missing.difference_update(ignore_set)
    return missing
| 38 | + |
| 39 | + |
def main():
    """Check that each datapipes package and its RST docs page list the same names.

    For the ``iter``, ``map`` and ``utils`` packages, the names found in the
    package ``__init__.py`` are compared with those found in the matching
    ``torchdata.datapipes.<target>.rst`` file. One message is printed per
    mismatch, and the process exits with status 1 if any mismatch was found,
    0 otherwise. Paths are relative to the current working directory.
    """
    datapipes_folder = os.path.join("torchdata", "datapipes")
    init_file = "__init__.py"
    docs_source_folder = os.path.join("docs", "source")
    # Per target: names that may be exported without appearing in the docs.
    targets = (
        ("iter", {"IterDataPipe", "Extractor"}),
        ("map", {"MapDataPipe"}),
        ("utils", set()),
    )
    exit_code = 0

    for target, ignore_set in targets:
        init_path = os.path.join(datapipes_folder, target, init_file)
        rst_path = os.path.join(docs_source_folder, "torchdata.datapipes." + target + ".rst")

        init_set = collect_init_dps(init_path)
        rst_set = collect_rst_dps(rst_path)

        # The ignore set only applies to names missing from the docs.
        dif_init = compare_sets(init_set, rst_set, ignore_set)
        dif_rst = compare_sets(rst_set, init_set)

        # Sort so the report is identical from run to run.
        for elem in sorted(dif_init):
            print(f"Please add {elem} to {rst_path}")
            exit_code = 1
        for elem in sorted(dif_rst):
            print(f"{elem} is present in {rst_path} but not in {init_path}")
            exit_code = 1

    sys.exit(exit_code)
| 64 | + |
| 65 | + |
# Run the consistency check only when executed as a script, not on import.
if __name__ == "__main__":
    main()
0 commit comments