Skip to content

Commit bea6e67

Browse files
CompRhysfacebook-github-bot
authored andcommitted
Add pre-commit hook to Ax (#3082)
Summary: Mirrors PR to botorch: pytorch/botorch#2632 applying the same solution to ensure that there is eventual consistency to the source of truth in `requirements-fmt.txt`. Fixes flake8 errors that occured when running ```python (base) ➜ Ax git:(pre-commit) ✗ pre-commit run --all-files [INFO] Installing environment for https://github.com/pycqa/flake8. [INFO] Once installed this environment will be reused. [INFO] This may take a few minutes... Check pre-commit formatting versions.....................................Passed Format files with µfmt...................................................Passed flake8...................................................................Failed - hook id: flake8 - exit code: 1 ax/benchmark/tests/problems/test_mixed_integer_problems.py:60:23: E226 missing whitespace around arithmetic operator ax/benchmark/tests/problems/test_mixed_integer_problems.py:70:23: E226 missing whitespace around arithmetic operator ax/benchmark/tests/problems/test_mixed_integer_problems.py:76:29: E226 missing whitespace around arithmetic operator ax/benchmark/tests/problems/test_mixed_integer_problems.py:77:29: E226 missing whitespace around arithmetic operator ax/benchmark/tests/problems/test_mixed_integer_problems.py:78:29: E226 missing whitespace around arithmetic operator ax/benchmark/tests/problems/test_mixed_integer_problems.py:84:23: E226 missing whitespace around arithmetic operator ax/benchmark/tests/problems/test_mixed_integer_problems.py:90:29: E226 missing whitespace around arithmetic operator ax/benchmark/tests/problems/test_mixed_integer_problems.py:91:29: E226 missing whitespace around arithmetic operator ax/modelbridge/tests/test_prediction_utils.py:165:39: E226 missing whitespace around arithmetic operator ax/service/tests/test_global_stopping.py:50:43: E226 missing whitespace around arithmetic operator scripts/insert_api_refs.py:47:35: E226 missing whitespace around arithmetic operator ``` Pull Request resolved: #3082 Reviewed By: sdaulton Differential Revision: D66190310 Pulled By: Balandat fbshipit-source-id: 05ce2786cf15baf3554f7d5f8e835e17e1258ed4
1 parent 4127865 commit bea6e67

File tree

9 files changed

+194
-38
lines changed

9 files changed

+194
-38
lines changed

.github/workflows/build-and-test.yml

-25
Original file line numberDiff line numberDiff line change
@@ -15,32 +15,7 @@ jobs:
1515
pinned_botorch: false
1616
secrets: inherit
1717

18-
lint:
19-
20-
runs-on: ubuntu-latest
21-
22-
steps:
23-
- uses: actions/checkout@v4
24-
- name: Set up Python
25-
uses: actions/setup-python@v5
26-
with:
27-
python-version: "3.10"
28-
- name: Install dependencies
29-
# Pin ufmt deps so they match intermal pyfmt.
30-
run: |
31-
pip install -r requirements-fmt.txt
32-
pip install flake8
33-
- name: ufmt
34-
run: |
35-
ufmt diff .
36-
- name: Flake8
37-
# run even if previous step (ufmt) failed
38-
if: ${{ always() }}
39-
run: |
40-
flake8
41-
4218
docs:
43-
4419
runs-on: ubuntu-latest
4520

4621
steps:

.github/workflows/lint.yml

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
name: Lint
2+
3+
on:
4+
push:
5+
branches: [ main ]
6+
pull_request:
7+
branches: [ main ]
8+
workflow_dispatch:
9+
10+
11+
jobs:
12+
13+
lint:
14+
runs-on: ubuntu-latest
15+
steps:
16+
- uses: actions/checkout@v4
17+
18+
- name: Set up Python
19+
uses: actions/setup-python@v5
20+
with:
21+
python-version: "3.10"
22+
23+
- name: Install dependencies
24+
run: pip install pre-commit
25+
26+
- name: Run pre-commit
27+
run: pre-commit run --all-files --show-diff-on-failure

.pre-commit-config.yaml

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
repos:
2+
- repo: local
3+
hooks:
4+
- id: check-requirements-versions
5+
name: Check pre-commit formatting versions
6+
entry: python scripts/check_pre_commit_reqs.py
7+
language: python
8+
always_run: true
9+
pass_filenames: false
10+
additional_dependencies:
11+
- PyYAML
12+
13+
- repo: https://github.com/omnilib/ufmt
14+
rev: v2.8.0
15+
hooks:
16+
- id: ufmt
17+
additional_dependencies:
18+
- black==24.4.2
19+
- usort==1.0.8.post1
20+
- ruff-api==0.1.0
21+
- stdlibs==2024.1.28
22+
args: [format]
23+
24+
- repo: https://github.com/pycqa/flake8
25+
rev: 7.0.0
26+
hooks:
27+
- id: flake8

CONTRIBUTING.md

+30-2
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,36 @@ We use the [`Ruff` code formatter](https://docs.astral.sh/ruff/formatter/) for a
3535
### Unit Tests
3636
The majority of our code is covered by unit tests and we are working to get to 100% code coverage. Please ensure that new code is covered by unit tests. To run all unit tests, we recommend installing pytest using `pip install pytest` and running `pytest -ra` from the root of the Ax repo. To get coverage, `pip install pytest-cov` and run `pytest -ra --cov=ax`.
3737

38-
### Linting
39-
Run the linter via `flake8` (`pip install flake8`) from the root of the Ax repository. Note that we have a [custom flake8 configuration](https://github.com/facebook/Ax/blob/main/.flake8).
38+
#### Code Style
39+
40+
Ax uses [ufmt](https://github.com/omnilib/ufmt) to enforce consistent code
41+
formatting (based on [black](https://github.com/ambv/black)) and import sorting
42+
(based on [µsort](https://github.com/facebook/usort)) across the code base.
43+
Install via `pip install ufmt`, and auto-format and auto-sort by running
44+
45+
```bash
46+
ufmt format .
47+
```
48+
49+
from the repository root.
50+
51+
#### Flake8 linting
52+
53+
Ax uses `flake8` for linting. To run the linter locally, install `flake8`
54+
via `pip install flake8`, and then run
55+
56+
```bash
57+
flake8 .
58+
```
59+
60+
from the repository root.
61+
62+
#### Pre-commit hooks
63+
64+
Contributors can use [pre-commit](https://pre-commit.com/) to run `ufmt` and
65+
`flake8` as part of the commit process. To install the hooks, install `pre-commit`
66+
via `pip install pre-commit` and run `pre-commit install` from the repository
67+
root.
4068

4169
### Static Type Checking
4270
We use [Pyre](https://pyre-check.org/) for static type checking and require code to be fully type annotated. At the moment, static type checking is not supported within Travis.

ax/benchmark/tests/problems/test_mixed_integer_problems.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def test_problems(self) -> None:
5757
cases: list[tuple[BenchmarkProblem, dict[str, float], torch.Tensor]] = [
5858
(
5959
get_discrete_hartmann(),
60-
{f"x{i+1}": 0.0 for i in range(6)},
60+
{f"x{i + 1}": 0.0 for i in range(6)},
6161
torch.zeros(6, dtype=torch.double),
6262
),
6363
(
@@ -67,28 +67,28 @@ def test_problems(self) -> None:
6767
),
6868
(
6969
get_discrete_ackley(),
70-
{f"x{i+1}": 0.0 for i in range(13)},
70+
{f"x{i + 1}": 0.0 for i in range(13)},
7171
torch.zeros(13, dtype=torch.double),
7272
),
7373
(
7474
get_discrete_ackley(),
7575
{
76-
**{f"x{i+1}": 2 for i in range(0, 5)},
77-
**{f"x{i+1}": 4 for i in range(5, 10)},
78-
**{f"x{i+1}": 1.0 for i in range(10, 13)},
76+
**{f"x{i + 1}": 2 for i in range(0, 5)},
77+
**{f"x{i + 1}": 4 for i in range(5, 10)},
78+
**{f"x{i + 1}": 1.0 for i in range(10, 13)},
7979
},
8080
torch.ones(13, dtype=torch.double),
8181
),
8282
(
8383
get_discrete_rosenbrock(),
84-
{f"x{i+1}": 0.0 for i in range(10)},
84+
{f"x{i + 1}": 0.0 for i in range(10)},
8585
torch.full((10,), -5.0, dtype=torch.double),
8686
),
8787
(
8888
get_discrete_rosenbrock(),
8989
{
90-
**{f"x{i+1}": 3 for i in range(0, 6)},
91-
**{f"x{i+1}": 1.0 for i in range(6, 10)},
90+
**{f"x{i + 1}": 3 for i in range(0, 6)},
91+
**{f"x{i + 1}": 1.0 for i in range(6, 10)},
9292
},
9393
torch.full((10,), 10.0, dtype=torch.double),
9494
),

ax/modelbridge/tests/test_prediction_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -162,5 +162,5 @@ def _attach_completed_trials(ax_client: AxClient) -> None:
162162

163163
# Test metric evaluation method
164164
def _evaluate_test_metrics(parameters: TParameterization) -> TEvaluationOutcome:
165-
x = np.array([parameters.get(f"x{i+1}") for i in range(2)])
165+
x = np.array([parameters.get(f"x{i + 1}") for i in range(2)])
166166
return {"test_metric1": (x[0] / x[1], 0.0), "test_metric2": (x[0] + x[1], 0.0)}

ax/service/tests/test_global_stopping.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def get_ax_client_for_branin(
4747

4848
def evaluate(self, parameters: TParameterization) -> dict[str, tuple[float, float]]:
4949
"""Evaluates the parameters for branin experiment."""
50-
x = np.array([parameters.get(f"x{i+1}") for i in range(2)])
50+
x = np.array([parameters.get(f"x{i + 1}") for i in range(2)])
5151
# pyre-fixme[7]: Expected `Dict[str, Tuple[float, float]]` but got
5252
# `Dict[str, Tuple[Union[float, ndarray], float]]`.
5353
return {"branin": (branin(x), 0.0)}

scripts/check_pre_commit_reqs.py

+99
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import sys
8+
from pathlib import Path
9+
10+
import yaml
11+
12+
13+
def parse_requirements(filepath):
14+
"""Parse requirements file and return a dict of package versions."""
15+
versions = {}
16+
with open(filepath) as f:
17+
for line in f:
18+
line = line.strip()
19+
if line and not line.startswith("#"):
20+
# Handle different requirement formats
21+
if "==" in line:
22+
pkg, version = line.split("==")
23+
versions[pkg.strip().lower()] = version.strip()
24+
return versions
25+
26+
27+
def parse_precommit_config(filepath):
28+
"""Parse pre-commit config and extract ufmt repo rev and hook dependencies."""
29+
with open(filepath) as f:
30+
config = yaml.safe_load(f)
31+
32+
versions = {}
33+
for repo in config["repos"]:
34+
if "https://github.com/omnilib/ufmt" in repo.get("repo", ""):
35+
# Get ufmt version from rev - assumes fixed format: vX.Y.Z
36+
versions["ufmt"] = repo.get("rev", "").replace("v", "")
37+
38+
# Get dependency versions
39+
for hook in repo["hooks"]:
40+
if hook["id"] == "ufmt":
41+
for dep in hook.get("additional_dependencies", []):
42+
if "==" in dep:
43+
pkg, version = dep.split("==")
44+
versions[pkg.strip().lower()] = version.strip()
45+
break
46+
return versions
47+
48+
49+
def main():
50+
# Find the pre-commit config and requirements files
51+
config_file = Path(".pre-commit-config.yaml")
52+
requirements_file = Path("requirements-fmt.txt")
53+
54+
if not config_file.exists():
55+
print(f"Error: Could not find {config_file}")
56+
sys.exit(1)
57+
58+
if not requirements_file.exists():
59+
print(f"Error: Could not find {requirements_file}")
60+
sys.exit(1)
61+
62+
# Parse both files
63+
req_versions = parse_requirements(requirements_file)
64+
config_versions = parse_precommit_config(config_file)
65+
66+
# Check versions
67+
mismatches = []
68+
for pkg, req_ver in req_versions.items():
69+
req_ver = req_versions.get(pkg, None)
70+
config_ver = config_versions.get(pkg, None)
71+
72+
if req_ver != config_ver:
73+
found_version_str = f"{pkg}: {requirements_file} has {req_ver},"
74+
if pkg == "ufmt":
75+
mismatches.append(
76+
f"{found_version_str} pre-commit config rev has v{config_ver}"
77+
)
78+
else:
79+
mismatches.append(
80+
f"{found_version_str} pre-commit config has {config_ver}"
81+
)
82+
83+
# Report results
84+
if mismatches:
85+
msg_str = "".join("\n\t" + msg for msg in mismatches)
86+
print(
87+
f"Version mismatches found:{msg_str}"
88+
"\nPlease update the versions in `.pre-commit-config.yaml` to be "
89+
"consistent with those in `requirements-fmt.txt` (source of truth)."
90+
"\nNote: all versions must be pinned exactly ('==X.Y.Z') in both files."
91+
)
92+
sys.exit(1)
93+
else:
94+
print("All versions match!")
95+
sys.exit(0)
96+
97+
98+
if __name__ == "__main__":
99+
main()

scripts/insert_api_refs.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def replace_backticks(source_path, docs_path):
4444
for i, l in enumerate(lines):
4545
match = re.search(pattern, l)
4646
if match:
47-
print(f"{f}:{i+1} s/{match.group(0)}/{link}")
47+
print(f"{f}:{i + 1} s/{match.group(0)}/{link}")
4848
lines[i] = re.sub(pattern, link, l)
4949
open(f, "w").writelines(lines)
5050

0 commit comments

Comments
 (0)