diff --git a/.github/workflows/pr-format.yml b/.github/workflows/pr-format.yml index 1895c22a..1d255ecf 100644 --- a/.github/workflows/pr-format.yml +++ b/.github/workflows/pr-format.yml @@ -27,10 +27,12 @@ jobs: run: uv venv --python 3.11 - name: Install lint dependencies - run: uv pip install --python .venv/bin/python pre-commit black isort + run: uv pip install --python .venv/bin/python ruff - name: Run lint - run: .venv/bin/pre-commit run --all-files + run: | + .venv/bin/python -m ruff check . + .venv/bin/python -m ruff format --check . unit_tests: needs: lint @@ -47,6 +49,7 @@ jobs: install-target: ".[cloud,vllm]" env: TMPDIR: /mnt/tmp + JAVA_TOOL_OPTIONS: --add-modules=jdk.incubator.vector steps: - uses: actions/checkout@v4 @@ -81,6 +84,7 @@ jobs: runs-on: ubuntu-latest env: TMPDIR: /mnt/tmp + JAVA_TOOL_OPTIONS: --add-modules=jdk.incubator.vector steps: - uses: actions/checkout@v4 @@ -118,3 +122,35 @@ jobs: test.test_cli_http \ test.test_cli_mcp \ test.test_cli_legacy_wrappers + + mypy: + needs: + - unit_tests + - cli_smoke + runs-on: ubuntu-latest + env: + TMPDIR: /mnt/tmp + steps: + - uses: actions/checkout@v4 + + - name: Prepare temp directory on /mnt + run: | + sudo mkdir -p "$TMPDIR" + sudo chown "$USER":"$USER" "$TMPDIR" + df -h / + df -h /mnt + + - uses: astral-sh/setup-uv@v6 + with: + python-version: "3.11" + + - name: Create mypy environment + run: uv venv --python 3.11 + + - name: Install mypy environment + run: | + uv pip install --python .venv/bin/python -e ".[server,cloud]" + uv pip install --python .venv/bin/python mypy types-PyYAML + + - name: Run mypy + run: .venv/bin/python -m mypy diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 44c8eac5..3a745724 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,15 +13,27 @@ repos: stages: [pre-commit] - repo: local hooks: + - id: quality-gate-pre-commit + name: quality gate (pre-commit) + entry: .venv/bin/python scripts/quality_gate.py --skip-ruff + language: system + pass_filenames: false + stages: [pre-commit] - id: ruff-check-push name: ruff-check (push) - entry: uv run ruff check . + entry: .venv/bin/python -m ruff check . language: system pass_filenames: false stages: [pre-push] - id: ruff-format-check-push name: ruff-format (push) - entry: uv run ruff format --check . + entry: .venv/bin/python -m ruff format --check . + language: system + pass_filenames: false + stages: [pre-push] + - id: quality-gate-push + name: quality gate (push) + entry: .venv/bin/python scripts/quality_gate.py --skip-ruff language: system pass_filenames: false stages: [pre-push] diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3b8a6a78..7043542c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -11,10 +11,17 @@ RankLLM Contribution flow 5. Every PR should be formatted. Below are the instructions to do so: - Bootstrap the repo-local development environment with `uv python install 3.11`, `uv venv --python 3.11`, `source .venv/bin/activate`, and `uv sync --group dev` - Run the following command in the project root to set up pre-commit and pre-push hooks (all commits through git UI will automatically be formatted): `uv run pre-commit install --install-hooks --hook-type pre-commit --hook-type pre-push` - - To manually make sure your code is correctly formatted and lint-clean, run `uv run pre-commit run --all-files` - - To run Ruff directly, use `uv run ruff check .` and `uv run ruff format .` -6. Run from the root directory the unit tests with `uv run python -m unittest discover test` -7. Update the `pyproject.toml` if applicable + - Install the full local validation stack before running the ordered gate: `uv pip install --python .venv/bin/python -e '.[server,cloud]'` + - To run the full ordered gate manually, use `uv run python scripts/quality_gate.py` + - `pre-commit` and `pre-push` now enforce the same order: Ruff, then required offline tests, then MyPy. + - To run Ruff directly, use `uv run ruff check .` and `uv run ruff format --check .` +6. Run from the root directory the required offline tests before every push: + - `uv run python -m unittest discover -s test/analysis` + - `uv run python -m unittest discover -s test/evaluation` + - `uv run python -m unittest discover -s test/rerank` + - `uv run python -m unittest test.test_cli_packaging test.test_cli_scaffolding test.test_cli_rerank_command test.test_cli_validation test.test_cli_prompt test.test_cli_view test.test_cli_introspection test.test_cli_utilities test.test_cli_http test.test_cli_mcp test.test_cli_legacy_wrappers` +7. Run MyPy from the root directory with `uv run mypy` +8. Update the `pyproject.toml` if applicable ## Suggested PR Description diff --git a/pyproject.toml b/pyproject.toml index 3bbca038..4e90d038 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,9 +105,11 @@ all = [ [dependency-groups] dev = [ + "mypy>=1.11.2", "pre-commit>=3.8.0", "pytest>=8.4.2", "ruff>=0.12.0", + "types-PyYAML>=6.0.12.20250516", ] [project.urls] @@ -145,6 +147,28 @@ ignore = ["E501"] "src/rank_llm/demo/*.py" = ["E402"] "src/rank_llm/scripts/*.py" = ["E402"] +[tool.mypy] +files = [ + "src/rank_llm/api", + "src/rank_llm/cli", + "src/rank_llm/data.py", + "src/rank_llm/_optional.py", +] +explicit_package_bases = true +mypy_path = "src" +check_untyped_defs = true +disallow_untyped_defs = true +disallow_incomplete_defs = true +no_implicit_optional = true +strict_optional = true +warn_redundant_casts = true +warn_unused_ignores = true +warn_return_any = true +warn_unused_configs = true +follow_imports = "silent" +cache_dir = ".mypy_cache" +incremental = true + [tool.setuptools.packages.find] where = ["src"] include = [ diff --git a/scripts/quality_gate.py b/scripts/quality_gate.py new file mode 100644 index 00000000..3d820468 --- /dev/null +++ b/scripts/quality_gate.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +import argparse +import os +import shutil +import subprocess +import sys +from pathlib import Path + +REPO_ROOT = Path(__file__).resolve().parents[1] +VENV_BIN = Path(sys.prefix) / "bin" +CLI_SMOKE_MODULES = [ + "test.test_cli_packaging", + "test.test_cli_scaffolding", + "test.test_cli_rerank_command", + "test.test_cli_validation", + "test.test_cli_prompt", + "test.test_cli_view", + "test.test_cli_introspection", + "test.test_cli_utilities", + "test.test_cli_http", + "test.test_cli_mcp", + "test.test_cli_legacy_wrappers", +] + + +def _java_env() -> dict[str, str]: + env = os.environ.copy() + env["PATH"] = f"{VENV_BIN}{os.pathsep}{env.get('PATH', '')}" + java_executable = shutil.which("java") + if java_executable is not None: + env.setdefault("JAVA_HOME", str(Path(java_executable).resolve().parents[1])) + add_modules = "--add-modules=jdk.incubator.vector" + current_options = env.get("JAVA_TOOL_OPTIONS", "") + if add_modules not in current_options.split(): + env["JAVA_TOOL_OPTIONS"] = f"{add_modules} {current_options}".strip() + return env + + +def _run_step( + name: str, + command: list[str], + *, + env: dict[str, str] | None = None, +) -> None: + print(f"[quality-gate] {name}: {' '.join(command)}", flush=True) + subprocess.run(command, cwd=REPO_ROOT, check=True, env=env) + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Run RankLLM's ordered quality gate: Ruff, tests, then mypy." + ) + parser.add_argument( + "--skip-ruff", + action="store_true", + help="Skip Ruff commands because they already ran in the current hook stage.", + ) + args = parser.parse_args() + + python = sys.executable + if not args.skip_ruff: + _run_step("ruff-check", [python, "-m", "ruff", "check", "."]) + _run_step("ruff-format", [python, "-m", "ruff", "format", "--check", "."]) + + _run_step( + "analysis-tests", + [python, "-m", "unittest", "discover", "-s", "test/analysis"], + env=_java_env(), + ) + _run_step( + "evaluation-tests", + [python, "-m", "unittest", "discover", "-s", "test/evaluation"], + env=_java_env(), + ) + _run_step( + "rerank-tests", + [python, "-m", "unittest", "discover", "-s", "test/rerank"], + env=_java_env(), + ) + _run_step( + "cli-smoke", + [python, "-m", "unittest", *CLI_SMOKE_MODULES], + env=_java_env(), + ) + _run_step("mypy", [python, "-m", "mypy"]) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/rank_llm/cli/adapters.py b/src/rank_llm/cli/adapters.py index d234985f..f8ca90ae 100644 --- a/src/rank_llm/cli/adapters.py +++ b/src/rank_llm/cli/adapters.py @@ -21,7 +21,7 @@ def make_file_artifact(name: str, path: str) -> dict[str, Any]: def serialize_data(value: Any) -> Any: - if dataclasses.is_dataclass(value): + if dataclasses.is_dataclass(value) and not isinstance(value, type): return { key: serialize_data(item) for key, item in dataclasses.asdict(value).items() } diff --git a/src/rank_llm/cli/introspection.py b/src/rank_llm/cli/introspection.py index 342158f0..b919276a 100644 --- a/src/rank_llm/cli/introspection.py +++ b/src/rank_llm/cli/introspection.py @@ -218,7 +218,7 @@ def doctor_report() -> dict[str, Any]: serve_mcp_ready = ( optional_dependencies["fastmcp"] and optional_dependencies["pyserini"] ) - command_readiness = { + command_readiness: dict[str, dict[str, Any]] = { "rerank": {"ready": True}, "evaluate": {"ready": True}, "analyze": {"ready": True}, @@ -238,7 +238,7 @@ def doctor_report() -> dict[str, Any]: } overall_status = ( "ready" - if python_ok and all(item["ready"] for item in command_readiness.values()) + if python_ok and all(bool(item["ready"]) for item in command_readiness.values()) else "degraded" ) return { diff --git a/src/rank_llm/cli/main.py b/src/rank_llm/cli/main.py index bb13b06c..4a1fd645 100644 --- a/src/rank_llm/cli/main.py +++ b/src/rank_llm/cli/main.py @@ -4,7 +4,7 @@ import json import sys from collections.abc import Sequence -from typing import Any, NoReturn +from typing import Any, NoReturn, TypeVar, overload from rank_llm.cli.adapters import make_data_artifact, serialize_data from rank_llm.cli.config import load_config @@ -38,6 +38,8 @@ from rank_llm.cli.view import ViewError, build_view_summary, render_view_summary from rank_llm.retrieve.retrieval_method import RetrievalMethod +_NamespaceT = TypeVar("_NamespaceT") + class CLIError(Exception): def __init__( @@ -64,11 +66,32 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self._current_argv: list[str] = [] + @overload + def parse_args( + self, + args: Sequence[str] | None = None, + namespace: None = None, + ) -> argparse.Namespace: ... + + @overload + def parse_args( + self, + args: Sequence[str] | None, + namespace: _NamespaceT, + ) -> _NamespaceT: ... + + @overload + def parse_args( + self, + *, + namespace: _NamespaceT, + ) -> _NamespaceT: ... + def parse_args( self, args: Sequence[str] | None = None, - namespace: argparse.Namespace | None = None, - ) -> argparse.Namespace: + namespace: _NamespaceT | None = None, + ) -> argparse.Namespace | _NamespaceT: self._current_argv = list(args) if args is not None else list(sys.argv[1:]) return super().parse_args(args, namespace) @@ -525,9 +548,27 @@ def _build_error_response(error: CLIError) -> CommandResponse: def _read_direct_payload(args: argparse.Namespace) -> dict[str, Any]: try: if args.stdin: - return json.loads(sys.stdin.read()) + payload = json.loads(sys.stdin.read()) + if not isinstance(payload, dict): + raise CLIError( + "Direct rerank payload must deserialize to a JSON object.", + exit_code=EXIT_CODES["invalid_arguments"], + status="validation_error", + error_code="invalid_json", + command="rerank", + ) + return payload if args.input_json: - return json.loads(args.input_json) + payload = json.loads(args.input_json) + if not isinstance(payload, dict): + raise CLIError( + "Direct rerank payload must deserialize to a JSON object.", + exit_code=EXIT_CODES["invalid_arguments"], + status="validation_error", + error_code="invalid_json", + command="rerank", + ) + return payload except json.JSONDecodeError as exc: source = "stdin" if args.stdin else "--input-json" raise CLIError( diff --git a/src/rank_llm/cli/operations.py b/src/rank_llm/cli/operations.py index 6a189e00..9bf2dd6f 100644 --- a/src/rank_llm/cli/operations.py +++ b/src/rank_llm/cli/operations.py @@ -6,7 +6,7 @@ from collections.abc import Callable from dataclasses import dataclass from pathlib import Path -from typing import Any +from typing import Any, cast from rank_llm.data import Candidate, Query, Request, Result from rank_llm.rerank import IdentityReranker, Reranker @@ -339,6 +339,9 @@ def run_evaluate_aggregate( if runner is None: from argparse import Namespace + from rank_llm.evaluation.trec_eval import EvalFunction + + runner = EvalFunction.eval args = Namespace( model_name=model_name, context_size=context_size, @@ -380,7 +383,10 @@ def run_response_analysis_files( capture_stdout, analyzer.count_errors, verbose ), } - return _run_with_captured_stdout(capture_stdout, runner, files, verbose) + return cast( + dict[str, Any], + _run_with_captured_stdout(capture_stdout, runner, files, verbose), + ) def run_retrieve_cache_generation( diff --git a/src/rank_llm/cli/prompt_view.py b/src/rank_llm/cli/prompt_view.py index 39f8572b..646e0e6b 100644 --- a/src/rank_llm/cli/prompt_view.py +++ b/src/rank_llm/cli/prompt_view.py @@ -2,7 +2,7 @@ from importlib.resources import files from pathlib import Path -from typing import Any +from typing import Any, cast import yaml @@ -40,7 +40,12 @@ def load_prompt_template(name_or_path: str) -> dict[str, Any]: if not path.exists(): raise PromptTemplateError(f"Unknown prompt template: {name_or_path}") with path.open("r", encoding="utf-8") as handle: - return yaml.safe_load(handle) + loaded = yaml.safe_load(handle) + if not isinstance(loaded, dict): + raise PromptTemplateError( + f"Prompt template must deserialize to an object: {path}" + ) + return cast(dict[str, Any], loaded) def build_prompt_template_view(name_or_path: str) -> dict[str, Any]: diff --git a/src/rank_llm/cli/view.py b/src/rank_llm/cli/view.py index 45ecbdcc..f8de17a7 100644 --- a/src/rank_llm/cli/view.py +++ b/src/rank_llm/cli/view.py @@ -2,7 +2,7 @@ import json from pathlib import Path -from typing import Any +from typing import Any, cast class ViewError(Exception): @@ -60,7 +60,7 @@ def detect_artifact_type(path: Path) -> str: def load_records(path: Path, artifact_type: str) -> list[Any]: if artifact_type == "invocations-history": - return json.loads(path.read_text(encoding="utf-8")) + return cast(list[Any], json.loads(path.read_text(encoding="utf-8"))) return [ json.loads(line) for line in path.read_text(encoding="utf-8").splitlines() @@ -96,7 +96,7 @@ def summarize_records(records: list[Any], artifact_type: str) -> dict[str, Any]: def _first_jsonl_record(path: Path) -> dict[str, Any]: for line in path.read_text(encoding="utf-8").splitlines(): if line.strip(): - return json.loads(line) + return cast(dict[str, Any], json.loads(line)) raise ViewError(f"empty jsonl file: {path}") diff --git a/src/rank_llm/data.py b/src/rank_llm/data.py index 1c563b33..c309e192 100644 --- a/src/rank_llm/data.py +++ b/src/rank_llm/data.py @@ -40,7 +40,7 @@ class InferenceInvocation: class Result: query: Query candidates: list[Candidate] = field(default_factory=list) - invocations_history: list[InferenceInvocation] = (field(default_factory=list),) + invocations_history: list[InferenceInvocation] = field(default_factory=list) @dataclass @@ -76,16 +76,20 @@ def __init__( self, data: Request | Result | list[Result] | list[Request], append: bool = False, - ): + ) -> None: if isinstance(data, list): - self._data = data + self._data: list[Request | Result] = list(data) else: self._data = [data] self._append = append - def write_inference_invocations_history(self, filename: str): + def write_inference_invocations_history(self, filename: str) -> None: aggregated_history = [] for d in self._data: + if not isinstance(d, Result): + raise TypeError( + "write_inference_invocations_history expects Result objects" + ) values = [] for info in d.invocations_history: values.append(info.__dict__) @@ -96,7 +100,7 @@ def write_inference_invocations_history(self, filename: str): output = json.dumps(aggregated_history, indent=2, ensure_ascii=False) f.write(output) - def write_in_json_format(self, filename: str): + def write_in_json_format(self, filename: str) -> None: results = [] for d in self._data: candidates = [candidate.__dict__ for candidate in d.candidates] @@ -105,7 +109,7 @@ def write_in_json_format(self, filename: str): output = json.dumps(results, indent=2, ensure_ascii=False) f.write(output) - def write_in_jsonl_format(self, filename: str): + def write_in_jsonl_format(self, filename: str) -> None: with open(filename, "a" if self._append else "w") as f: for d in self._data: candidates = [candidate.__dict__ for candidate in d.candidates] @@ -116,7 +120,7 @@ def write_in_jsonl_format(self, filename: str): f.write(output) f.write("\n") - def write_in_trec_eval_format(self, filename: str): + def write_in_trec_eval_format(self, filename: str) -> None: with open(filename, "a" if self._append else "w") as f: for d in self._data: qid = d.query.qid diff --git a/src/rank_llm/rerank/reranker.py b/src/rank_llm/rerank/reranker.py index dc0b42df..6f6cf74c 100644 --- a/src/rank_llm/rerank/reranker.py +++ b/src/rank_llm/rerank/reranker.py @@ -195,7 +195,7 @@ def get_model_coordinator(self) -> RankLLM: def create_model_coordinator( model_path: str, - default_model_coordinator: RankLLM, + default_model_coordinator: RankLLM | None, interactive: bool, **kwargs: Any, ) -> RankLLM: diff --git a/test/rerank/listwise/test_RankListwiseOSLLM.py b/test/rerank/listwise/test_RankListwiseOSLLM.py index f0fc5cd0..47243223 100644 --- a/test/rerank/listwise/test_RankListwiseOSLLM.py +++ b/test/rerank/listwise/test_RankListwiseOSLLM.py @@ -238,8 +238,8 @@ def setUp(self): self.mock_tokenizer.apply_chat_template.side_effect = ( lambda messages, **kwargs: str(messages) ) - self.mock_tokenizer.encode.side_effect = lambda x, **kwargs: [0] * ( - len(x) // 4 + 1 + self.mock_tokenizer.encode.side_effect = lambda x, **kwargs: ( + [0] * (len(x) // 4 + 1) ) # Patch the handlers used by RankListwiseOSLLM (lazy-imported inside __init__)