Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ To install dependencies with uv, run:
```bash
git clone git@github.com:allenai/olmoearth_pretrain.git
cd olmoearth_pretrain
uv sync --locked --all-groups --python 3.12
uv sync --locked --extra torch-cu128 --all-groups --python 3.12
# or uv sync --locked --extra torch-cpu --all-groups --python 3.12
# only necessary for development
uv tool install pre-commit --with pre-commit-uv --force-reinstall
```
Expand Down
32 changes: 26 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,17 @@ dependencies = [
"wandb>=0.19.0",
]

optional-dependencies.torch-cpu = [
"torch>=2.7,<2.10",
"torchvision>=0.22.1,<1",
]
optional-dependencies.torch-cu128 = [
"pytorch-triton",
"torch>=2.7,<2.10",
"torchvision>=0.22.1,<1",
"flash-attn>=2.4.0",
]

[build-system]
requires = ["setuptools>=61"]
build-backend = "setuptools.build_meta"
Expand Down Expand Up @@ -56,13 +67,26 @@ addopts = [
"--import-mode=importlib",
]

[tool.uv]
conflicts = [
[
{ extra = "torch-cpu" },
{ extra = "torch-cu128" },
],
]

[tool.uv.sources]
claymodel = { git = "https://github.com/Clay-foundation/model.git" }
geobench = { git = "https://github.com/ServiceNow/geo-bench.git" }
torch = [
{ index = "pytorch-cpu", marker = "sys_platform != 'linux'" },
{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" },
{ index = "pytorch-cpu", extra="torch-cpu"},
{ index = "pytorch-cu128", extra="torch-cu128"},
]
torchvision = [
{ index = "pytorch-cpu", extra = "torch-cpu" },
{ index = "pytorch-cu128", extra = "torch-cu128" },
]
pytorch-triton = [ { index = "pytorch-cu128", extra = "torch-cu128" } ]

[[tool.uv.index]]
name = "pytorch-cpu"
Expand Down Expand Up @@ -128,10 +152,6 @@ eval = [
"xarray>=2025.10.1",
]

flash_attn = [
"flash-attn>=2.4.0",
]

[tool.uv.extra-build-dependencies]
flash-attn = [{ requirement = "torch", match-runtime = true }]

Expand Down
Loading