Skip to content

Commit 1fa3a22

Browse files
committed
Rewrite conservation evals using jax/haliax
1 parent db28f28 commit 1fa3a22

19 files changed: +2,065 −549 lines changed

experiments/defaults.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,8 @@ def default_train(
260260
tags: Any additional tags to add to the Wandb tracker.
261261
use_default_validation: Whether to use the default validation sets (currently Paloma).
262262
eval_harness_tasks: List of evaluation harness tasks. Defaults to the CORE set of tasks. Use () or [] to disable
263-
shuffle: Whether to shuffle the training data. True=full shuffle, False=no shuffle, int=era shuffle with that length.
263+
shuffle: Whether to shuffle the training data. True=full shuffle, False=no shuffle,
264+
int=era shuffle with that length.
264265
"""
265266

266267
pretraining_data = _prepare_data_config(tokenized, use_default_validation, shuffle=shuffle)

experiments/plantcad/README.md

Lines changed: 39 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
PlantCAD1 reproduction experiments.
44

5-
## Setup
6-
75
Original tutorial: https://gist.github.com/eric-czech/31e5b79689d322f7becb94a109ce0b75
86

7+
## Setup
8+
99
### Local
1010

1111
```bash
@@ -15,7 +15,7 @@ uv venv --python 3.11
1515
uv sync
1616
```
1717

18-
### SkyPilot
18+
### Remote (SkyPilot)
1919

2020
```bash
2121
sky api stop; [ -d ~/.sky ] && rm -rf ~/.sky
@@ -43,10 +43,10 @@ EOF
4343
uv pip install "skypilot[lambda]==0.10.3"
4444
sky check lambda
4545
sky launch \
46-
--cluster marin --infra lambda --num-nodes 1 --gpus "A100:8" --disk-size 100 \
46+
--cluster marin --infra lambda --num-nodes 1 --gpus "A10:1" --disk-size 100 \
4747
--env HUGGING_FACE_HUB_TOKEN --env WANDB_API_KEY \
4848
output/cluster.sky.yaml --retry-until-up --yes
49-
rsync -rPz ./ marin:/home/ubuntu/sky_workdir --exclude '.venv' --exclude '.git' --exclude src/marin/markdown
49+
REMOTE_USER=ubuntu
5050
```
5151

5252
#### GCP
@@ -59,7 +59,7 @@ sky launch \
5959
--instance-type a2-highgpu-1g --region us-east-1 \
6060
--env HUGGING_FACE_HUB_TOKEN --env WANDB_API_KEY \
6161
output/cluster.sky.yaml
62-
rsync -rPz ./ marin:/home/gcpuser/sky_workdir --exclude '.venv' --exclude '.git' --exclude src/marin/markdown
62+
REMOTE_USER=gcpuser
6363
```
6464

6565
#### CoreWeave
@@ -72,7 +72,7 @@ sky launch \
7272
--cpus 124 --memory 2008 \
7373
--env HUGGING_FACE_HUB_TOKEN --env WANDB_API_KEY \
7474
output/cluster.sky.yaml
75-
rsync -rPz ./ marin:/home/sky/sky_workdir --exclude '.venv' --exclude '.git' --exclude src/marin/markdown --exclude '__pycache__'
75+
REMOTE_USER=sky
7676

7777
# For transformer-engine-jax:
7878
sudo apt update
@@ -81,39 +81,56 @@ sudo apt install build-essential g++ cmake ninja-build
8181
# hint: This error likely indicates that you need to install a library that provides "cuda_runtime_api.h" for `transformer-engine-jax@2.6.0.post1`
8282
```
8383

84-
#### Run
84+
## Execution
8585

8686
```bash
8787
ssh marin
8888
cd sky_workdir && conda deactivate && source .venv/bin/activate
8989
export RAY_DEBUG=legacy
9090

91-
python -m experiments.plantcad.exp_pc1_tutorial --prefix local_store --force_run_failed true
92-
python -m experiments.plantcad.exp_pc1_batch_tune --prefix local_store --force_run_failed true
91+
# Code sync
92+
rsync -rPz ./ marin:/home/$REMOTE_USER/sky_workdir \
93+
--exclude '.venv' --exclude '.git' \
94+
--exclude src/marin/markdown --exclude '__pycache__'
9395

94-
python -m experiments.plantcad.exp_pc1_lr_tune --prefix local_store --force_run_failed true
96+
# Experiments and tuning
97+
python -m experiments.plantcad.scripts.exp_pc1_tutorial --prefix local_store --force_run_failed true
98+
python -m experiments.plantcad.scripts.exp_pc1_batch_tune --prefix local_store --force_run_failed true
99+
python -m experiments.plantcad.scripts.exp_pc1_lr_tune --prefix local_store --force_run_failed true
95100
find local_store | grep -E 'step-668$' | xargs -I {} echo "hf upload plantcad/_dev_marin_plantcad1_v1_lr_tune {} {} --repo-type model"
96101

102+
# Training
97103
mkdir -p logs
98104
screen -S train
99-
python -m experiments.plantcad.exp_pc1_train \
105+
python -m experiments.plantcad.scripts.exp_pc1_train \
100106
--prefix local_store --force_run_failed true 2>&1 | tee logs/exp_pc1_train.log
101-
# https://wandb.ai/eric-czech/marin/runs/plantcad-train-300m-r01-2aa671
102-
#
103-
104-
python -m experiments.plantcad.exp_pc1_eval --prefix local_store --force_run_failed true
105107

106-
sky exec -c marin output/task.sky.yaml
108+
# Evaluation
109+
rm -rf local_store/evaluation/dna-conservation*; python -m experiments.plantcad.scripts.exp_pc1_eval --prefix local_store --force_run_failed true
107110
```
108111

109-
## EDA
112+
```bash
113+
> python -m experiments.plantcad.misc.agg_eval_results
114+
roc_auc step checkpoint_path
115+
0.535217 1673 hf://plantcad/_dev_marin_plantcad1_v1_train/local_store/checkpoints/plantcad-train-300m-r02-432442/hf/step-1673
116+
0.546725 3346 hf://plantcad/_dev_marin_plantcad1_v1_train/local_store/checkpoints/plantcad-train-300m-r02-432442/hf/step-3346
117+
0.549917 5019 hf://plantcad/_dev_marin_plantcad1_v1_train/local_store/checkpoints/plantcad-train-300m-r02-432442/hf/step-5019
118+
0.558042 6692 hf://plantcad/_dev_marin_plantcad1_v1_train/local_store/checkpoints/plantcad-train-300m-r02-432442/hf/step-6692
119+
0.560290 8365 hf://plantcad/_dev_marin_plantcad1_v1_train/local_store/checkpoints/plantcad-train-300m-r02-432442/hf/step-8365
120+
0.565785 10038 hf://plantcad/_dev_marin_plantcad1_v1_train/local_store/checkpoints/plantcad-train-300m-r02-432442/hf/step-10038
121+
0.570048 11711 hf://plantcad/_dev_marin_plantcad1_v1_train/local_store/checkpoints/plantcad-train-300m-r02-432442/hf/step-11711
122+
0.576358 13384 hf://plantcad/_dev_marin_plantcad1_v1_train/local_store/checkpoints/plantcad-train-300m-r02-432442/hf/step-13384
123+
0.583593 15057 hf://plantcad/_dev_marin_plantcad1_v1_train/local_store/checkpoints/plantcad-train-300m-r02-432442/hf/step-15057
124+
0.585834 16730 hf://plantcad/_dev_marin_plantcad1_v1_train/local_store/checkpoints/plantcad-train-300m-r02-432442/hf/step-16730
125+
0.589215 18403 hf://plantcad/_dev_marin_plantcad1_v1_train/local_store/checkpoints/plantcad-train-300m-r02-432442/hf/step-18403
126+
0.588738 20076 hf://plantcad/_dev_marin_plantcad1_v1_train/local_store/checkpoints/plantcad-train-300m-r02-432442/hf/step-20076
127+
0.593178 21749 hf://plantcad/_dev_marin_plantcad1_v1_train/local_store/checkpoints/plantcad-train-300m-r02-432442/hf/step-21749
128+
```
110129

111-
### Tokenizer stats
130+
## EDA
112131

113-
From https://huggingface.co/kuleshov-group/PlantCaduceus_l20, e.g.:
114-
PlantCaduceus vocab size: 7
132+
Stats on kuleshov-group/Angiosperm_16_genomes:
115133

116-
### Dataset stats
117134
```
118135
> python count_dataset.py
119136
Number of examples: 5,485,282
@@ -137,61 +154,3 @@ Most common tokens:
137154
Total unique tokens: 5
138155
Token ID range: 2 - 6
139156
```
140-
141-
This means 2,808,464,384 / 20 ==> ~140.4M params is Chinchilla optimal for text.
142-
143-
## TODO
144-
145-
- Look for prefetch config
146-
- Debug: "Your setup doesn't support bf16/gpu." in eval with `bf16_full_eval`
147-
148-
```
149-
# cat /tmp/ray/session_2025-09-20_04-11-22_232072_15326/runtime_resources/pip/f20b7e798eeb2fc9320b1a708aaeee4e0130ee14/virtualenv/lib/python3.11/site-packages/transformers/training_args.py | grep -i "doesn't support" -C 100
150-
if self.bf16 or self.bf16_full_eval:
151-
if self.use_cpu and not is_torch_available() and not is_torch_xla_available():
152-
# cpu
153-
raise ValueError("Your setup doesn't support bf16/(cpu, tpu, neuroncore). You need torch>=1.10")
154-
elif not self.use_cpu:
155-
if not is_torch_bf16_gpu_available() and not is_torch_xla_available(): # added for tpu support
156-
error_message = "Your setup doesn't support bf16/gpu."
157-
if is_torch_cuda_available():
158-
error_message += " You need Ampere+ GPU with cuda>=11.0"
159-
# gpu
160-
raise ValueError(error_message)
161-
```
162-
163-
- Discuss: `levanter.data.loader - loader.py:258 - INFO :: Prefetch wasn't fast enough: 33.836.`
164-
- Discuss this:
165-
166-
```
167-
# TODO: discuss https://github.com/jax-ml/jax/issues/24909
168-
# (train_lm_task pid=31054) /tmp/ray/session_2025-09-16_11-16-05_933535_22116/runtime_resources/pip/96e8d2e31c1b75b4d19a0ea2c755a672438fdca3/virtualenv/lib/python3.11/site-packages/levanter/layers/attention.py:428: UserWarning: transformer_engine is not installed. Please install it to use NVIDIA's optimized fused attention.. Falling back to the reference implementation.
169-
# (train_lm_task pid=31054) warnings.warn(f"{msg}. Falling back to the reference implementation.")
170-
# (train_lm_task pid=31054) E0916 11:23:11.594742 31054 buffer_comparator.cc:150] Difference at 10780: 16.375, expected 14.5
171-
# (train_lm_task pid=31054) E0916 11:23:11.594787 31054 buffer_comparator.cc:150] Difference at 10942: 17.25, expected 15.25
172-
# (train_lm_task pid=31054) E0916 11:23:11.594791 31054 buffer_comparator.cc:150] Difference at 11042: 17, expected 15.1875
173-
# (train_lm_task pid=31054) E0916 11:23:11.594795 31054 buffer_comparator.cc:150] Difference at 11132: 16.875, expected 14.8125
174-
# (train_lm_task pid=31054) E0916 11:23:11.594801 31054 buffer_comparator.cc:150] Difference at 12211: 15, expected 16.875
175-
# (train_lm_task pid=31054) E0916 11:23:11.594804 31054 buffer_comparator.cc:150] Difference at 12212: 14.625, expected 16.625
176-
# (train_lm_task pid=31054) E0916 11:23:11.594807 31054 buffer_comparator.cc:150] Difference at 12235: 14.75, expected 16.625
177-
# (train_lm_task pid=31054) E0916 11:23:11.594809 31054 buffer_comparator.cc:150] Difference at 12276: 15.0625, expected 16.875
178-
# (train_lm_task pid=31054) E0916 11:23:11.594812 31054 buffer_comparator.cc:150] Difference at 12327: 14.5, expected 16.25
179-
# (train_lm_task pid=31054) E0916 11:23:11.594815 31054 buffer_comparator.cc:150] Difference at 12336: 15.5625, expected 17.5
180-
# (train_lm_task pid=31054) 2025-09-16 11:23:11.594824: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1070] Results do not match the reference. This is likely a bug/unexpected loss of precision.
181-
```
182-
183-
- Discuss these constant errors in deleting checkpoints:
184-
185-
```
186-
train_lm_task pid=295292) 2025-09-26T19:19:35 - 0 - levanter.checkpoint - checkpoint.py:383 - INFO :: Saved checkpoint to local_store/checkpoints/plantcad-train-300m-r02-432442/checkpoints/step-20076 for step 20076
187-
(train_lm_task pid=295292) 2025-09-26T19:19:35 - 0 - levanter.checkpoint - checkpoint.py:230 - INFO :: Deleting old temporary checkpoint local_store/checkpoints/plantcad-train-300m-r02-432442/checkpoints/step-20061 after saving new checkpoint.
188-
(train_lm_task pid=295292) 2025-09-26T19:19:35 - 0 - levanter.checkpoint - checkpoint.py:262 - INFO :: Removing checkpoint local_store/checkpoints/plantcad-train-300m-r02-432442/checkpoints/step-20061
189-
(train_lm_task pid=295292) 2025-09-26T19:19:35 - 0 - levanter.checkpoint - checkpoint.py:270 - INFO :: Deleting old checkpoint local_store/checkpoints/plantcad-train-300m-r02-432442/checkpoints/step-20061 from /home/sky/sky_workdir/local_store/checkpoints/plantcad-train-300m-r02-432442/checkpoints/local_store/checkpoints/plantcad-train-300m-r02-432442/checkpoints/step-20061
190-
(train_lm_task pid=295292) 2025-09-26T19:19:35 - 0 - levanter.checkpoint - checkpoint.py:276 - ERROR :: Failed to delete checkpoint local_store/checkpoints/plantcad-train-300m-r02-432442/checkpoints/step-20061
191-
(train_lm_task pid=295292) Traceback (most recent call last):
192-
(train_lm_task pid=295292) File "/tmp/ray/session_2025-09-26_10-17-29_045489_276828/runtime_resources/pip/96e8d2e31c1b75b4d19a0ea2c755a672438fdca3/virtualenv/lib/python3.11/site-packages/levanter/checkpoint.py", line 272, in _do_rm_checkpoint
193-
(train_lm_task pid=295292) fs.rm(cp_path, recursive=True)
194-
(train_lm_task pid=295292) File "/tmp/ray/session_2025-09-26_10-17-29_045489_276828/runtime_resources/pip/96e8d2e31c1b75b4d19a0ea2c755a672438fdca3/virtualenv/lib/python3.11/site-packages/fsspec/implementations/local.py", line 191, in rm
195-
(train_lm_task pid=295292) os.remove(p)
196-
(train_lm_task pid=295292) FileNotFoundError: [Errno 2] No such file or directory: '/home/sky/sky_workdir/local_store/checkpoints/plantcad-train-300m-r02-432442/checkpoints/local_store/checkpoints/plantcad-train-300m-r02-432442/checkpoints/step-20061'
197-
```

0 commit comments

Comments (0)