Skip to content

Commit e93e5c5

Browse files
authored
Don't require jax[cuda12] as the dep, just jax (#2)
Signed-off-by: Fabrice Normandin <[email protected]>
1 parent 18357dc commit e93e5c5

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

poetry.lock

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ readme = "README.md"
88

99
[tool.poetry.dependencies]
1010
python = "^3.10"
11-
jax = {extras = ["cuda12"], version = "^0.4.28"}
11+
jax = "^0.4.28"
1212
torch = "^2.0.0"
1313
pytorch2jax = "^0.1.0"
1414
flax = "^0.8.4"
@@ -22,6 +22,7 @@ pre-commit = "^3.7.1"
2222
pytest-testmon = "^2.1.1"
2323
pytest-env = "^1.1.3"
2424
tensor-regression = "^0.0.4"
25+
jax = {extras = ["cuda12"], version = "^0.4.28"}
2526

2627

2728
[tool.poetry-dynamic-versioning]

0 commit comments

Comments
 (0)