Skip to content

Commit b928795

Browse files
committed
Remove the cuda dependency for CI tests
1 parent 0a4dbb2 commit b928795

2 files changed

Lines changed: 10 additions & 2 deletions

File tree

.github/workflows/ci.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ jobs:
2626
- name: Check Python version
2727
run: python --version
2828

29+
- name: Install CPU torch before test deps
30+
run: pip install torch --index-url https://download.pytorch.org/whl/cpu
31+
2932
- name: Install dependencies
3033
run: python -m pip install ".[lint, test]"
3134

pyproject.toml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ dependencies = [
1515
"tensorflow-cpu==2.18.0",
1616
"tensorflow-probability",
1717
"getdist",
18-
"jax[cuda12]",
18+
"jax",
1919
"jax_dataloader",
2020
"datasets==3.5.0",
2121
"optax",
@@ -31,6 +31,10 @@ requires-python = ">=3.10, <3.13"
3131
urls = {Repository = "https://github.com/sachaguer/jaxili"}
3232

3333
[project.optional-dependencies]
34+
cuda = [
35+
"jax[cuda12]"
36+
]
37+
3438
docs = [
3539
"myst-parser",
3640
"numpydoc",
@@ -48,13 +52,14 @@ test = [
4852
"tf-keras",
4953
"sbibm",
5054
"sbi==0.23.3",
55+
"torch"
5156
]
5257

5358
lint = ["black", "isort"]
5459
release = ["build", "twine"]
5560

5661
#Install for development
57-
dev = ["jaxili[docs,lint,release,test]"]
62+
dev = ["jaxili[docs,lint,release,test,cuda]"]
5863

5964
[tool.pydocstyle]
6065
convention = "numpy"

0 commit comments

Comments
 (0)