Skip to content

Commit ceb3088

Browse files
committed
Remove support for JAX 0.4 and 0.5. Fix for 0.6.
- Fix installation of jax in CI - Display JAX version in CI
1 parent e371094 commit ceb3088

4 files changed

Lines changed: 15 additions & 10 deletions

File tree

.github/workflows/ci.yml

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,15 @@ jobs:
2929
matrix:
3030
python:
3131
- version: '3.10'
32-
jax: '>=0.4,<0.5'
32+
jax: '>=0.6,<0.7'
3333
- version: '3.11'
34-
jax: '>=0.4,<0.5'
34+
jax: '>=0.7,<0.8'
3535
- version: '3.12'
36-
jax: '>=0.5,<0.6'
36+
jax: '>=0.7,<0.8'
3737
- version: '3.13'
38-
jax: '>=0.5,<0.6'
38+
jax: '>=0.8,<0.9'
3939
- version: '3.14'
40-
jax: '>=0.8'
40+
jax: '>=0.9'
4141

4242
steps:
4343
- name: Checkout code
@@ -63,7 +63,10 @@ jobs:
6363
uv pip install "jax${{ matrix.python.jax }}"
6464
6565
- name: Run unit tests
66-
run: uv run pytest -v --cov=zeropybench --cov-report=xml tests
66+
run: |
67+
source .venv/bin/activate
68+
python -c "import jax; print(f'JAX version: {jax.__version__}')"
69+
pytest -v --cov=zeropybench --cov-report=xml tests
6770
6871
- name: Install system dependencies
6972
run: |
@@ -72,7 +75,9 @@ jobs:
7275
7376
- name: Run notebook tests
7477
if: matrix.python.version == '3.14'
75-
run: uv run pytest -v docs
78+
run: |
79+
source .venv/bin/activate
80+
pytest -v docs
7681
7782
- name: Upload coverage to Codecov
7883
if: matrix.python.version == '3.12'

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ docs = [
4949
'pillow>=9.0.0',
5050
'visu-hlo>=0.1.0',
5151
]
52-
cuda12 = ['jax[cuda12]>=0.4']
52+
cuda12 = ['jax[cuda12]>=0.6']
5353
cuda13 = ['jax[cuda13]>=0.7; python_version>="3.11"']
5454

5555
[tool.coverage.report]

src/zeropybench/_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def _compile_jax(self, globals: dict[str, Any]) -> tuple[str, float, bool]:
213213
compilation_time = time.perf_counter() - start_time
214214

215215
hlo = compiled.as_text()
216-
is_single_array = isinstance(compiled.out_info, jax.ShapeDtypeStruct)
216+
is_single_array = isinstance(lowered.out_info, jax.ShapeDtypeStruct)
217217
return hlo, compilation_time, is_single_array
218218

219219
def _run_many_times(

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)