@@ -7,7 +7,7 @@ authors = [
77]
88dependencies = [
99
10- "torch==2.4 .1",
10+ "torch>=2.9 .1",
1111 "hydra-core>=1.3.2",
1212 "hydra-submitit-launcher>=1.2.0",
1313 "wandb>=0.17.6",
@@ -23,9 +23,9 @@ dependencies = [
2323 "transformers>=4.44.0",
2424 "datasets>=2.21.0",
2525 # Jax-related dependencies:
26- "jax==0.4.33 ",
27- "jaxlib==0.4.33 ",
28- "torch-jax-interop>=0.0.7 ",
26+ "jax",
27+ "flax ",
28+ "torch-jax-interop>=0.0.8 ",
2929 "gymnax @ git+https://www.github.com/lebrice/gymnax@fix-classic-control-rendering",
3030 "rejax>=0.1.0",
3131 "xtils[jitpp] @ git+https://github.com/jessefarebro/xtils",
@@ -34,10 +34,10 @@ dependencies = [
3434 "hydra-colorlog>=1.2.0",
3535 "remote-slurm-executor",
3636 "hydra-auto-schema>=0.0.7",
37- "hydra-orion-sweeper>=1.6.4 ; python_full_version < '3.11' ",
37+ "hydra-orion-sweeper>=1.6.4",
3838]
3939readme = "README.md"
40- requires-python = ">= {{python_version}}"
40+ requires-python = ">= {{python_version}},< 3 .14 "
4141
4242[dependency-groups]
4343dev = [
@@ -53,11 +53,11 @@ dev = [
5353 "pytest-xdist>=3.6.1",
5454 "pytest>=8.3.2",
5555 "ruff>=0.6.0",
56- "tensor-regression>=0.0.8 ",
56+ "tensor-regression>=0.1.2 ",
5757]
5858
5959[project.optional-dependencies]
60- gpu = ["jax[cuda12]>=0.4.31 ; sys_platform == 'linux'"]
60+ gpu = ["jax[cuda13] ; sys_platform == 'linux'"]
6161
6262[tool.pytest.ini_options]
6363testpaths = ["{{module_name}}", "tests"]
@@ -82,9 +82,6 @@ lint.select = ["E4", "E7", "E9", "F", "I", "UP"]
8282[tool.uv]
8383managed = true
8484
85- [tool.uv.sources]
86- remote-slurm-executor = { git = "https://github.com/lebrice/remote-slurm-executor", branch = "master" }
87-
8885
8986[build-system]
9087requires = ["hatchling", "uv-dynamic-versioning"]
@@ -98,3 +95,17 @@ packages = ["{{module_name}}"]
9895
9996[tool.hatch.version]
10097source = "uv-dynamic-versioning"
98+
99+ [[tool.uv.index]]
100+ name = "pytorch-cu130"
101+ url = "https://download.pytorch.org/whl/cu130"
102+ explicit = true
103+
104+ [tool.uv.sources]
105+ remote-slurm-executor = { git = "https://github.com/lebrice/remote-slurm-executor", branch = "master" }
106+ torch = [
107+ { index = "pytorch-cu130", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
108+ ]
109+ torchvision = [
110+ { index = "pytorch-cu130", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
111+ ]
0 commit comments