Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions .github/workflows/pr-format.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@ jobs:
run: uv venv --python 3.11

- name: Install lint dependencies
run: uv pip install --python .venv/bin/python pre-commit black isort
run: uv pip install --python .venv/bin/python ruff

- name: Run lint
run: .venv/bin/pre-commit run --all-files
run: |
.venv/bin/python -m ruff check .
.venv/bin/python -m ruff format --check .

unit_tests:
needs: lint
Expand All @@ -47,6 +49,7 @@ jobs:
install-target: ".[cloud,vllm]"
env:
TMPDIR: /mnt/tmp
JAVA_TOOL_OPTIONS: --add-modules=jdk.incubator.vector
steps:
- uses: actions/checkout@v4

Expand Down Expand Up @@ -81,6 +84,7 @@ jobs:
runs-on: ubuntu-latest
env:
TMPDIR: /mnt/tmp
JAVA_TOOL_OPTIONS: --add-modules=jdk.incubator.vector
steps:
- uses: actions/checkout@v4

Expand Down Expand Up @@ -118,3 +122,35 @@ jobs:
test.test_cli_http \
test.test_cli_mcp \
test.test_cli_legacy_wrappers

mypy:
needs:
- unit_tests
- cli_smoke
runs-on: ubuntu-latest
env:
TMPDIR: /mnt/tmp
steps:
- uses: actions/checkout@v4

- name: Prepare temp directory on /mnt
run: |
sudo mkdir -p "$TMPDIR"
sudo chown "$USER":"$USER" "$TMPDIR"
df -h /
df -h /mnt

- uses: astral-sh/setup-uv@v6
with:
python-version: "3.11"

- name: Create mypy environment
run: uv venv --python 3.11

- name: Install mypy environment
run: |
uv pip install --python .venv/bin/python -e ".[server,cloud]"
uv pip install --python .venv/bin/python mypy types-PyYAML

- name: Run mypy
run: .venv/bin/python -m mypy
16 changes: 14 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,27 @@ repos:
stages: [pre-commit]
- repo: local
hooks:
- id: quality-gate-pre-commit
name: quality gate (pre-commit)
entry: .venv/bin/python scripts/quality_gate.py --skip-ruff
language: system
pass_filenames: false
stages: [pre-commit]
- id: ruff-check-push
name: ruff-check (push)
entry: uv run ruff check .
entry: .venv/bin/python -m ruff check .
language: system
pass_filenames: false
stages: [pre-push]
- id: ruff-format-check-push
name: ruff-format (push)
entry: uv run ruff format --check .
entry: .venv/bin/python -m ruff format --check .
language: system
pass_filenames: false
stages: [pre-push]
- id: quality-gate-push
name: quality gate (push)
entry: .venv/bin/python scripts/quality_gate.py --skip-ruff
language: system
pass_filenames: false
stages: [pre-push]
15 changes: 11 additions & 4 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,17 @@ RankLLM Contribution flow
5. Every PR should be formatted. Below are the instructions to do so:
- Bootstrap the repo-local development environment with `uv python install 3.11`, `uv venv --python 3.11`, `source .venv/bin/activate`, and `uv sync --group dev`
- Run the following command in the project root to set up pre-commit and pre-push hooks (all commits through git UI will automatically be formatted): `uv run pre-commit install --install-hooks --hook-type pre-commit --hook-type pre-push`
- To manually make sure your code is correctly formatted and lint-clean, run `uv run pre-commit run --all-files`
- To run Ruff directly, use `uv run ruff check .` and `uv run ruff format .`
6. Run from the root directory the unit tests with `uv run python -m unittest discover test`
7. Update the `pyproject.toml` if applicable
- Install the full local validation stack before running the ordered gate: `uv pip install --python .venv/bin/python -e '.[server,cloud]'`
- To run the full ordered gate manually, use `uv run python scripts/quality_gate.py`
- `pre-commit` and `pre-push` now enforce the same order: Ruff, then required offline tests, then MyPy.
- To run Ruff directly, use `uv run ruff check .` and `uv run ruff format --check .`
6. Run from the root directory the required offline tests before every push:
- `uv run python -m unittest discover -s test/analysis`
- `uv run python -m unittest discover -s test/evaluation`
- `uv run python -m unittest discover -s test/rerank`
- `uv run python -m unittest test.test_cli_packaging test.test_cli_scaffolding test.test_cli_rerank_command test.test_cli_validation test.test_cli_prompt test.test_cli_view test.test_cli_introspection test.test_cli_utilities test.test_cli_http test.test_cli_mcp test.test_cli_legacy_wrappers`
7. Run MyPy from the root directory with `uv run mypy`
8. Update the `pyproject.toml` if applicable

## Suggested PR Description

Expand Down
24 changes: 24 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,11 @@ all = [

[dependency-groups]
dev = [
"mypy>=1.11.2",
"pre-commit>=3.8.0",
"pytest>=8.4.2",
"ruff>=0.12.0",
"types-PyYAML>=6.0.12.20250516",
]

[project.urls]
Expand Down Expand Up @@ -145,6 +147,28 @@ ignore = ["E501"]
"src/rank_llm/demo/*.py" = ["E402"]
"src/rank_llm/scripts/*.py" = ["E402"]

[tool.mypy]
files = [
"src/rank_llm/api",
"src/rank_llm/cli",
"src/rank_llm/data.py",
"src/rank_llm/_optional.py",
]
explicit_package_bases = true
mypy_path = "src"
check_untyped_defs = true
disallow_untyped_defs = true
disallow_incomplete_defs = true
no_implicit_optional = true
strict_optional = true
warn_redundant_casts = true
warn_unused_ignores = true
warn_return_any = true
warn_unused_configs = true
follow_imports = "silent"
cache_dir = ".mypy_cache"
incremental = true

[tool.setuptools.packages.find]
where = ["src"]
include = [
Expand Down
91 changes: 91 additions & 0 deletions scripts/quality_gate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from __future__ import annotations

import argparse
import os
import shutil
import subprocess
import sys
from pathlib import Path

# Repository root: this script lives in <repo>/scripts/, one level below the top.
REPO_ROOT = Path(__file__).resolve().parents[1]
# bin/ directory of the interpreter's environment (POSIX venv layout assumed;
# NOTE(review): on Windows venvs this would be "Scripts" — confirm POSIX-only CI).
VENV_BIN = Path(sys.prefix) / "bin"
# CLI smoke-test modules, run as a single unittest invocation to mirror the
# cli_smoke job in .github/workflows/pr-format.yml.
CLI_SMOKE_MODULES = [
    "test.test_cli_packaging",
    "test.test_cli_scaffolding",
    "test.test_cli_rerank_command",
    "test.test_cli_validation",
    "test.test_cli_prompt",
    "test.test_cli_view",
    "test.test_cli_introspection",
    "test.test_cli_utilities",
    "test.test_cli_http",
    "test.test_cli_mcp",
    "test.test_cli_legacy_wrappers",
]


def _java_env() -> dict[str, str]:
    """Return a copy of the process environment prepared for Java-backed tests.

    Prepends the virtualenv's bin directory to ``PATH``, derives ``JAVA_HOME``
    from the resolved ``java`` executable when it is not already set, and
    ensures ``jdk.incubator.vector`` is enabled through ``JAVA_TOOL_OPTIONS``
    (mirroring the JAVA_TOOL_OPTIONS set in the CI workflow).
    """
    env = os.environ.copy()
    env["PATH"] = f"{VENV_BIN}{os.pathsep}{env.get('PATH', '')}"
    # Resolve java against the augmented PATH (not the process PATH) so a
    # venv-local JDK takes precedence, consistent with the PATH edit above.
    java_executable = shutil.which("java", path=env["PATH"])
    if java_executable is not None:
        # JAVA_HOME is two levels above bin/java under a standard JDK layout.
        env.setdefault("JAVA_HOME", str(Path(java_executable).resolve().parents[1]))
    add_modules = "--add-modules=jdk.incubator.vector"
    current_options = env.get("JAVA_TOOL_OPTIONS", "")
    # Prepend rather than overwrite so caller-provided options survive; the
    # split() membership test avoids duplicating the flag on repeated calls.
    if add_modules not in current_options.split():
        env["JAVA_TOOL_OPTIONS"] = f"{add_modules} {current_options}".strip()
    return env


def _run_step(
    name: str,
    command: list[str],
    *,
    env: dict[str, str] | None = None,
) -> None:
    """Announce and execute a single gate step, failing fast on nonzero exit.

    Runs *command* from the repository root; ``check=True`` raises
    ``subprocess.CalledProcessError`` if the step fails, aborting the gate.
    """
    rendered = " ".join(command)
    print(f"[quality-gate] {name}: {rendered}", flush=True)
    subprocess.run(command, cwd=REPO_ROOT, check=True, env=env)


def main() -> int:
    """Run RankLLM's ordered quality gate: Ruff, offline tests, then MyPy.

    Returns 0 on success. Any failing step raises
    ``subprocess.CalledProcessError`` out of ``_run_step`` (``check=True``),
    which terminates the gate with a traceback and nonzero exit.
    """
    parser = argparse.ArgumentParser(
        description="Run RankLLM's ordered quality gate: Ruff, tests, then mypy."
    )
    parser.add_argument(
        "--skip-ruff",
        action="store_true",
        help="Skip Ruff commands because they already ran in the current hook stage.",
    )
    args = parser.parse_args()

    python = sys.executable
    if not args.skip_ruff:
        _run_step("ruff-check", [python, "-m", "ruff", "check", "."])
        _run_step("ruff-format", [python, "-m", "ruff", "format", "--check", "."])

    # Build the Java-aware environment once: it is invariant across steps, and
    # each _java_env() call copies the environment and re-resolves java.
    java_env = _java_env()
    test_steps: list[tuple[str, list[str]]] = [
        ("analysis-tests", [python, "-m", "unittest", "discover", "-s", "test/analysis"]),
        ("evaluation-tests", [python, "-m", "unittest", "discover", "-s", "test/evaluation"]),
        ("rerank-tests", [python, "-m", "unittest", "discover", "-s", "test/rerank"]),
        ("cli-smoke", [python, "-m", "unittest", *CLI_SMOKE_MODULES]),
    ]
    for step_name, command in test_steps:
        _run_step(step_name, command, env=java_env)

    # MyPy runs last, in the plain environment, matching the documented order
    # (Ruff -> tests -> MyPy) enforced by the pre-commit/pre-push hooks.
    _run_step("mypy", [python, "-m", "mypy"])
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value to the shell as the process exit code.
    raise SystemExit(main())
2 changes: 1 addition & 1 deletion src/rank_llm/cli/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def make_file_artifact(name: str, path: str) -> dict[str, Any]:


def serialize_data(value: Any) -> Any:
if dataclasses.is_dataclass(value):
if dataclasses.is_dataclass(value) and not isinstance(value, type):
return {
key: serialize_data(item) for key, item in dataclasses.asdict(value).items()
}
Expand Down
4 changes: 2 additions & 2 deletions src/rank_llm/cli/introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def doctor_report() -> dict[str, Any]:
serve_mcp_ready = (
optional_dependencies["fastmcp"] and optional_dependencies["pyserini"]
)
command_readiness = {
command_readiness: dict[str, dict[str, Any]] = {
"rerank": {"ready": True},
"evaluate": {"ready": True},
"analyze": {"ready": True},
Expand All @@ -238,7 +238,7 @@ def doctor_report() -> dict[str, Any]:
}
overall_status = (
"ready"
if python_ok and all(item["ready"] for item in command_readiness.values())
if python_ok and all(bool(item["ready"]) for item in command_readiness.values())
else "degraded"
)
return {
Expand Down
51 changes: 46 additions & 5 deletions src/rank_llm/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
import sys
from collections.abc import Sequence
from typing import Any, NoReturn
from typing import Any, NoReturn, TypeVar, overload

from rank_llm.cli.adapters import make_data_artifact, serialize_data
from rank_llm.cli.config import load_config
Expand Down Expand Up @@ -38,6 +38,8 @@
from rank_llm.cli.view import ViewError, build_view_summary, render_view_summary
from rank_llm.retrieve.retrieval_method import RetrievalMethod

_NamespaceT = TypeVar("_NamespaceT")


class CLIError(Exception):
def __init__(
Expand All @@ -64,11 +66,32 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._current_argv: list[str] = []

@overload
def parse_args(
self,
args: Sequence[str] | None = None,
namespace: None = None,
) -> argparse.Namespace: ...

@overload
def parse_args(
self,
args: Sequence[str] | None,
namespace: _NamespaceT,
) -> _NamespaceT: ...

@overload
def parse_args(
self,
*,
namespace: _NamespaceT,
) -> _NamespaceT: ...

def parse_args(
self,
args: Sequence[str] | None = None,
namespace: argparse.Namespace | None = None,
) -> argparse.Namespace:
namespace: _NamespaceT | None = None,
) -> argparse.Namespace | _NamespaceT:
self._current_argv = list(args) if args is not None else list(sys.argv[1:])
return super().parse_args(args, namespace)

Expand Down Expand Up @@ -525,9 +548,27 @@ def _build_error_response(error: CLIError) -> CommandResponse:
def _read_direct_payload(args: argparse.Namespace) -> dict[str, Any]:
try:
if args.stdin:
return json.loads(sys.stdin.read())
payload = json.loads(sys.stdin.read())
if not isinstance(payload, dict):
raise CLIError(
"Direct rerank payload must deserialize to a JSON object.",
exit_code=EXIT_CODES["invalid_arguments"],
status="validation_error",
error_code="invalid_json",
command="rerank",
)
return payload
if args.input_json:
return json.loads(args.input_json)
payload = json.loads(args.input_json)
if not isinstance(payload, dict):
raise CLIError(
"Direct rerank payload must deserialize to a JSON object.",
exit_code=EXIT_CODES["invalid_arguments"],
status="validation_error",
error_code="invalid_json",
command="rerank",
)
return payload
except json.JSONDecodeError as exc:
source = "stdin" if args.stdin else "--input-json"
raise CLIError(
Expand Down
10 changes: 8 additions & 2 deletions src/rank_llm/cli/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from typing import Any, cast

from rank_llm.data import Candidate, Query, Request, Result
from rank_llm.rerank import IdentityReranker, Reranker
Expand Down Expand Up @@ -339,6 +339,9 @@ def run_evaluate_aggregate(
if runner is None:
from argparse import Namespace

from rank_llm.evaluation.trec_eval import EvalFunction

runner = EvalFunction.eval
args = Namespace(
model_name=model_name,
context_size=context_size,
Expand Down Expand Up @@ -380,7 +383,10 @@ def run_response_analysis_files(
capture_stdout, analyzer.count_errors, verbose
),
}
return _run_with_captured_stdout(capture_stdout, runner, files, verbose)
return cast(
dict[str, Any],
_run_with_captured_stdout(capture_stdout, runner, files, verbose),
)


def run_retrieve_cache_generation(
Expand Down
Loading
Loading