Reviewed revision of the patch that moves ci_scripts/check_api_label_cn.py to
pathlib and makes it log every label error before exiting.

Review edits (new-side lines only):
  * check_api_label_cn.sh: quote ${all_git_files} so sed prefixes every changed
    path; unquoted, the list collapses onto one line and only the first path
    receives the ${FLUIDDOCDIR}/ prefix.
  * Restore the named group (?P<api_label>...) that .group("api_label") reads.
  * Escape the literal dots in CN_API_LABEL_PATTERN.
  * Reuse CN_API_PREFIX in extract_api_labels.
  * Say "invalid" in the message that lists labels missing from the valid set.
  * Fix the misspelled locals except_line / need_uasge_check_files.
  * Iterate sets in sorted order so numbered error lists are reproducible.
  * Add one-line docstrings to the functions that had none.

Caveats:
  * This patch was recovered from a copy whose line breaks and indentation had
    been collapsed. Whitespace and blank context lines were restored so that
    each hunk matches its @@ line counts; confirm them against the target tree
    before applying.
  * The post-image blob ids on the "index" lines predate the review edits.

diff --git a/ci_scripts/check_api_label_cn.py b/ci_scripts/check_api_label_cn.py
index 6d8e0b71fde..6496f60bf34 100644
--- a/ci_scripts/check_api_label_cn.py
+++ b/ci_scripts/check_api_label_cn.py
@@ -1,10 +1,17 @@
+from __future__ import annotations
+
 import argparse
 import logging
-import os
 import re
 import sys
 from pathlib import Path
 
+# precompile regex patterns
+CN_API_LABEL_PATTERN = re.compile(r"\.\. _([a-zA-Z0-9_]+):")
+REF_PATTERN = re.compile(r":ref:`([^`]+)`")
+API_LABEL_EXTRACT_PATTERN = re.compile(r".+?<(?P<api_label>.+?)>")
+CN_API_PREFIX = "cn_api_paddle"
+
 logger = logging.getLogger()
 if logger.handlers:
     # we assume the first handler is the one we want to configure
@@ -20,87 +27,146 @@
 logger.setLevel(logging.INFO)
 
 
-# check file's api_label
-def check_api_label(rootdir, file):
-    real_file = Path(rootdir) / file
-    with open(real_file, "r", encoding="utf-8") as f:
+def check_api_label(file_path: Path, doc_root: Path) -> str | None:
+    """Return an error message if the first line is not the expected label, else None."""
+    with open(file_path, "r", encoding="utf-8") as f:
         first_line = f.readline().strip()
-    return first_line == generate_en_label_by_path(file)
-
-
-# path -> api_label (the first line's style)
-def generate_en_label_by_path(file):
-    result = file.removesuffix("_cn.rst")
-    result = "_".join(Path(result).parts)
-    result = f".. _cn_{result}:"
-    return result
-
-
-# traverse doc/api to append api_label in list
-def find_all_api_labels_in_dir(rootdir):
-    all_api_labels = []
-    for root, dirs, files in os.walk(rootdir + API):
-        for file in files:
-            real_path = Path(root) / file
-            path = str(real_path).removeprefix(rootdir)
-            if not should_test(path):
-                continue
-            for label in find_api_labels_in_one_file(real_path):
-                all_api_labels.append(label)
-    return all_api_labels
+    expected_line = generate_cn_label(file_path, doc_root)
+    if first_line == expected_line:
+        return None
+    return f"`{file_path}` first line is: `{first_line}`, but expected generated by path: `{expected_line}`."
+
+
+def generate_cn_label(file_path: Path, doc_root: Path) -> str:
+    """Generate the expected api_label format from file path."""
+    relative_path = file_path.relative_to(doc_root)
+    stem = relative_path.stem.removesuffix("_cn")
+    parts = relative_path.with_name(stem).parts
+    label = "_".join(parts)
+    return f".. _cn_{label}:"
+
+
+def collect_api_labels(api_root: Path) -> set[str]:
+    """Collect all valid api labels."""
+    labels = set()
+    for rst_file in api_root.rglob("*.rst"):
+        if not rst_file.is_file():
+            continue
+        if not need_check(rst_file, api_root):
+            continue
+        labels.update(extract_api_labels(rst_file))
+    return labels
 
 
-# api_labels in a file
-def find_api_labels_in_one_file(file_path):
-    api_labels_in_one_file = []
+def extract_api_labels(file_path: Path) -> set[str]:
+    """Collect the labels with the cn_api prefix declared in one rst file."""
+    labels = set()
     with open(file_path, "r", encoding="utf-8") as f:
         lines = f.readlines()
         for line in lines:
-            line = re.search(".. _([a-zA-Z0-9_]+)", line)
-            if not line:
+            match = CN_API_LABEL_PATTERN.search(line)
+            if not match:
                 continue
-            api_labels_in_one_file.append(line.group(1))
-    return api_labels_in_one_file
+            label = match.group(1)
+            if not label.startswith(CN_API_PREFIX):
+                continue
+            labels.add(label)
+    return labels
 
 
 # api doc for checking
-def should_test(file):
+def need_check(file_path: Path, api_root: Path) -> bool:
+    """Whether file is a `*_cn.rst` page under api_root (not Overview/index)."""
     return (
-        file.endswith("_cn.rst")
-        and not Path(file).name == "Overview_cn.rst"
-        and not Path(file).name == "index_cn.rst"
-        and file.startswith(API)
+        file_path.name.endswith("_cn.rst")
+        and file_path.name not in {"Overview_cn.rst", "index_cn.rst"}
+        and file_path.is_relative_to(api_root)
     )
 
 
-def run_cn_api_label_checking(rootdir, files):
-    for file in files:
-        if should_test(file) and not check_api_label(rootdir, file):
-            logger.error(
-                f"The first line in {rootdir}/{file} is not avaiable, please re-check it!"
-            )
-            sys.exit(1)
-    valid_api_labels = find_all_api_labels_in_dir(rootdir)
+def validate_api_label_references(
+    files: set[Path], valid_api_labels: set[str]
+) -> list[str]:
+    """List :ref: targets with the cn_api prefix that are not valid labels."""
+    errors = []
-    for file in files:
+    for file in sorted(files):
-        if not file.endswith(".rst"):
-            continue
-        with open(Path(rootdir) / file, "r", encoding="utf-8") as f:
-            pattern = f.read()
-        matches = re.findall(r":ref:`([^`]+)`", pattern)
+        with open(file, "r", encoding="utf-8") as f:
+            content = f.read()
+        matches = REF_PATTERN.findall(content)
         for match in matches:
             api_label = match
-            if api_label_match := re.match(
-                r".+<(?P<api_label>.+?)>", api_label
-            ):
+            if api_label_match := API_LABEL_EXTRACT_PATTERN.match(api_label):
                 api_label = api_label_match.group("api_label")
-            if (
-                api_label.startswith("cn_api_paddle")
-                and api_label not in valid_api_labels
-            ):
-                logger.error(
-                    f"Found api label {api_label} in {rootdir}/{file}, but it is not a valid api label, please re-check it!"
-                )
-                sys.exit(1)
+            if not api_label.startswith(CN_API_PREFIX):
+                continue
+            if api_label in valid_api_labels:
+                continue
+            errors.append(f"api label `{api_label}` in `{file}`")
+    return errors
+
+
+def get_custom_files_for_checking_usage(api_root: Path) -> set[Path]:
+    """Collect all rst files under the root except the excluded names."""
+    # TODO: add more dir for checking
+    custom_files = set()
+    for rst_file in api_root.rglob("*.rst"):
+        if not rst_file.is_file():
+            continue
+        if rst_file.name in {"set_global_initializer_cn.rst"}:
+            # TODO: how to deal with `api_paddle_Tensor_create_tensor`?
+            continue
+        custom_files.add(rst_file)
+    return custom_files
+
+
+def run_cn_api_label_checking(
+    doc_root: Path, api_root: Path, files: list[Path]
+) -> None:
+    """Check first-line labels, then :ref: label usages; exit 1 on errors."""
+    # get real path for changed files
+    real_path_files_set = set(files)
+
+    errors = []
+    # check the api_label in the first line for increased files
+    for file_path in sorted(real_path_files_set):
+        if not need_check(file_path, api_root):
+            continue
+        if error := check_api_label(file_path, doc_root):
+            errors.append(error)
+    if errors:
+        logger.error(
+            "Found first line is not available as follows, please re-check it!"
+        )
+        for i, error in enumerate(errors, 1):
+            logger.error(f"{i}: {error}")
+        sys.exit(1)
+
+    # collect all api_labels in api_root
+    valid_api_labels = collect_api_labels(api_root)
+
+    # check the usage of api_label in custom files
+    need_usage_check_files = set()
+    for file_path in real_path_files_set:
+        if not file_path.is_relative_to(doc_root):
+            continue
+        if file_path.suffix != ".rst":
+            continue
+        need_usage_check_files.add(file_path)
+
+    api_label_usage_file_set = (
+        need_usage_check_files | get_custom_files_for_checking_usage(doc_root)
+    )
+
+    if errors := validate_api_label_references(
+        api_label_usage_file_set, valid_api_labels
+    ):
+        logger.error(
+            "Found invalid api labels usage as follows, please re-check it!"
+        )
+        for i, error in enumerate(errors, 1):
+            logger.error(f"{i}: {error}")
+        sys.exit(1)
+
     print("All api_label check success in PR !")
 
 
@@ -110,29 +176,27 @@ def parse_args():
     """
     parser = argparse.ArgumentParser(description="cn api_label checking")
     parser.add_argument(
-        "rootdir",
+        "doc_root",
+        type=Path,
         help="the dir DOCROOT",
-        type=str,
-        default="/FluidDoc/docs/",
+        default=Path("/FluidDoc/docs"),
     )
 
     parser.add_argument(
-        "apiroot",
-        type=str,
-        help="the dir APIROOT",
-        default="/FluidDoc/docs/api/",
+        "api_root",
+        type=Path,
+        help="the dir api_root",
+        default=Path("/FluidDoc/docs/api"),
     )
     parser.add_argument(
         "all_git_files",
-        type=str,
+        type=Path,
         nargs="*",
         help="files need to check",
     )
-    args = parser.parse_args()
-    return args
+    return parser.parse_args()
 
 
 if __name__ == "__main__":
     args = parse_args()
-    API = args.apiroot.removeprefix(args.rootdir + "/")
-    run_cn_api_label_checking(args.rootdir, args.all_git_files)
+    run_cn_api_label_checking(args.doc_root, args.api_root, args.all_git_files)
diff --git a/ci_scripts/check_api_label_cn.sh b/ci_scripts/check_api_label_cn.sh
index 33d206420f8..c20d2626e18 100644
--- a/ci_scripts/check_api_label_cn.sh
+++ b/ci_scripts/check_api_label_cn.sh
@@ -3,8 +3,8 @@ set -x
 
 FLUIDDOCDIR=${FLUIDDOCDIR:=/FluidDoc}
 
-DOCROOT=${FLUIDDOCDIR}/docs/
-APIROOT=${DOCROOT}/api/
+DOCROOT=${FLUIDDOCDIR}/docs
+APIROOT=${DOCROOT}/api
 
 SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )"
 source ${SCRIPT_DIR}/utils.sh
@@ -13,7 +13,10 @@ if [ -z ${BRANCH} ]; then
     BRANCH="develop"
 fi
 
-all_git_files=`git diff --name-only --diff-filter=ACMR upstream/${BRANCH} | sed 's#docs/##g'`
+all_git_files=$(git diff --name-only --diff-filter=ACMR upstream/${BRANCH})
+
+real_path_git_files=$(echo "${all_git_files}" | sed -e '/^$/d' -e "s|^|${FLUIDDOCDIR}/|")
+
 echo $all_git_files
 echo "Run API_LABEL Checking"
-python check_api_label_cn.py ${DOCROOT} ${APIROOT} $all_git_files
+python check_api_label_cn.py ${DOCROOT} ${APIROOT} $real_path_git_files