22
33PlantCAD1 reproduction experiments.
44
5- ## Setup
6-
75Original tutorial: https://gist.github.com/eric-czech/31e5b79689d322f7becb94a109ce0b75
86
7+ ## Setup
8+
99### Local
1010
1111``` bash
@@ -15,7 +15,7 @@ uv venv --python 3.11
1515uv sync
1616```
1717
18- ### SkyPilot
18+ ### Remote (SkyPilot)
1919
2020``` bash
2121sky api stop; [ -d ~/.sky ] && rm -rf ~/.sky
4343uv pip install "skypilot[lambda]==0.10.3"
4444sky check lambda
4545sky launch \
46- --cluster marin --infra lambda --num-nodes 1 --gpus "A100:8" --disk-size 100 \
46+ --cluster marin --infra lambda --num-nodes 1 --gpus "A10:1" --disk-size 100 \
4747 --env HUGGING_FACE_HUB_TOKEN --env WANDB_API_KEY \
4848 output/cluster.sky.yaml --retry-until-up --yes
49- rsync -rPz ./ marin:/home/ubuntu/sky_workdir --exclude '.venv' --exclude '.git' --exclude src/marin/markdown
49+ REMOTE_USER=ubuntu
5050```
5151
5252#### GCP
@@ -59,7 +59,7 @@ sky launch \
5959 --instance-type a2-highgpu-1g --region us-east-1 \
6060 --env HUGGING_FACE_HUB_TOKEN --env WANDB_API_KEY \
6161 output/cluster.sky.yaml
62- rsync -rPz ./ marin:/home/gcpuser/sky_workdir --exclude '.venv' --exclude '.git' --exclude src/marin/markdown
62+ REMOTE_USER=gcpuser
6363```
6464
6565#### CoreWeave
@@ -72,7 +72,7 @@ sky launch \
7272 --cpus 124 --memory 2008 \
7373 --env HUGGING_FACE_HUB_TOKEN --env WANDB_API_KEY \
7474 output/cluster.sky.yaml
75- rsync -rPz ./ marin:/home/sky/sky_workdir --exclude '.venv' --exclude '.git' --exclude src/marin/markdown --exclude '__pycache__'
75+ REMOTE_USER=sky
7676
7777# For transformer-engine-jax:
7878sudo apt update
@@ -81,39 +81,56 @@ sudo apt install build-essential g++ cmake ninja-build
8181# hint: This error likely indicates that you need to install a library that provides "cuda_runtime_api.h" for `transformer-engine-jax@2.6.0.post1`
8282```
8383
84- #### Run
84+ ## Execution
8585
8686``` bash
8787ssh marin
8888cd sky_workdir && conda deactivate && source .venv/bin/activate
8989export RAY_DEBUG=legacy
9090
91- python -m experiments.plantcad.exp_pc1_tutorial --prefix local_store --force_run_failed true
92- python -m experiments.plantcad.exp_pc1_batch_tune --prefix local_store --force_run_failed true
91+ # Code sync
92+ rsync -rPz ./ marin:/home/$REMOTE_USER/sky_workdir \
93+ --exclude '.venv' --exclude '.git' \
94+ --exclude src/marin/markdown --exclude '__pycache__'
9395
94- python -m experiments.plantcad.exp_pc1_lr_tune --prefix local_store --force_run_failed true
96+ # Experiments and tuning
97+ python -m experiments.plantcad.scripts.exp_pc1_tutorial --prefix local_store --force_run_failed true
98+ python -m experiments.plantcad.scripts.exp_pc1_batch_tune --prefix local_store --force_run_failed true
99+ python -m experiments.plantcad.scripts.exp_pc1_lr_tune --prefix local_store --force_run_failed true
95100find local_store | grep -E 'step-668$' | xargs -I {} echo "hf upload plantcad/_dev_marin_plantcad1_v1_lr_tune {} {} --repo-type model"
96101
102+ # Training
97103mkdir -p logs
98104screen -S train
99- python -m experiments.plantcad.exp_pc1_train \
105+ python -m experiments.plantcad.scripts.exp_pc1_train \
100106 --prefix local_store --force_run_failed true 2>&1 | tee logs/exp_pc1_train.log
101- # https://wandb.ai/eric-czech/marin/runs/plantcad-train-300m-r01-2aa671
102- #
103-
104- python -m experiments.plantcad.exp_pc1_eval --prefix local_store --force_run_failed true
105107
106- sky exec -c marin output/task.sky.yaml
108+ # Evaluation
109+ rm -rf local_store/evaluation/dna-conservation* ; python -m experiments.plantcad.scripts.exp_pc1_eval --prefix local_store --force_run_failed true
107110```
108111
109- ## EDA
112+ ``` bash
113+ > python -m experiments.plantcad.misc.agg_eval_results
114+ roc_auc step checkpoint_path
115+ 0.535217 1673 hf://plantcad/_dev_marin_plantcad1_v1_train/local_store/checkpoints/plantcad-train-300m-r02-432442/hf/step-1673
116+ 0.546725 3346 hf://plantcad/_dev_marin_plantcad1_v1_train/local_store/checkpoints/plantcad-train-300m-r02-432442/hf/step-3346
117+ 0.549917 5019 hf://plantcad/_dev_marin_plantcad1_v1_train/local_store/checkpoints/plantcad-train-300m-r02-432442/hf/step-5019
118+ 0.558042 6692 hf://plantcad/_dev_marin_plantcad1_v1_train/local_store/checkpoints/plantcad-train-300m-r02-432442/hf/step-6692
119+ 0.560290 8365 hf://plantcad/_dev_marin_plantcad1_v1_train/local_store/checkpoints/plantcad-train-300m-r02-432442/hf/step-8365
120+ 0.565785 10038 hf://plantcad/_dev_marin_plantcad1_v1_train/local_store/checkpoints/plantcad-train-300m-r02-432442/hf/step-10038
121+ 0.570048 11711 hf://plantcad/_dev_marin_plantcad1_v1_train/local_store/checkpoints/plantcad-train-300m-r02-432442/hf/step-11711
122+ 0.576358 13384 hf://plantcad/_dev_marin_plantcad1_v1_train/local_store/checkpoints/plantcad-train-300m-r02-432442/hf/step-13384
123+ 0.583593 15057 hf://plantcad/_dev_marin_plantcad1_v1_train/local_store/checkpoints/plantcad-train-300m-r02-432442/hf/step-15057
124+ 0.585834 16730 hf://plantcad/_dev_marin_plantcad1_v1_train/local_store/checkpoints/plantcad-train-300m-r02-432442/hf/step-16730
125+ 0.589215 18403 hf://plantcad/_dev_marin_plantcad1_v1_train/local_store/checkpoints/plantcad-train-300m-r02-432442/hf/step-18403
126+ 0.588738 20076 hf://plantcad/_dev_marin_plantcad1_v1_train/local_store/checkpoints/plantcad-train-300m-r02-432442/hf/step-20076
127+ 0.593178 21749 hf://plantcad/_dev_marin_plantcad1_v1_train/local_store/checkpoints/plantcad-train-300m-r02-432442/hf/step-21749
128+ ```
110129
111- ### Tokenizer stats
130+ ## EDA
112131
113- From https://huggingface.co/kuleshov-group/PlantCaduceus_l20 , e.g.:
114- PlantCaduceus vocab size: 7
132+ Stats on kuleshov-group/Angiosperm_16_genomes:
115133
116- ### Dataset stats
117134```
118135> python count_dataset.py
119136Number of examples: 5,485,282
@@ -137,61 +154,3 @@ Most common tokens:
137154Total unique tokens: 5
138155Token ID range: 2 - 6
139156```
140-
141- This means 2,808,464,384 / 20 ==> ~ 140.4M params is Chinchilla optimal for text.
142-
143- ## TODO
144-
145- - Look for prefetch config
146- - Debug: "Your setup doesn't support bf16/gpu." in eval with ` bf16_full_eval `
147-
148- ```
149- # cat /tmp/ray/session_2025-09-20_04-11-22_232072_15326/runtime_resources/pip/f20b7e798eeb2fc9320b1a708aaeee4e0130ee14/virtualenv/lib/python3.11/site-packages/transformers/training_args.py | grep -i "doesn't support" -C 100
150- if self.bf16 or self.bf16_full_eval:
151- if self.use_cpu and not is_torch_available() and not is_torch_xla_available():
152- # cpu
153- raise ValueError("Your setup doesn't support bf16/(cpu, tpu, neuroncore). You need torch>=1.10")
154- elif not self.use_cpu:
155- if not is_torch_bf16_gpu_available() and not is_torch_xla_available(): # added for tpu support
156- error_message = "Your setup doesn't support bf16/gpu."
157- if is_torch_cuda_available():
158- error_message += " You need Ampere+ GPU with cuda>=11.0"
159- # gpu
160- raise ValueError(error_message)
161- ```
162-
163- - Discuss: ` levanter.data.loader - loader.py:258 - INFO :: Prefetch wasn't fast enough: 33.836. `
164- - Discuss this:
165-
166- ```
167- # TODO: discuss https://github.com/jax-ml/jax/issues/24909
168- # (train_lm_task pid=31054) /tmp/ray/session_2025-09-16_11-16-05_933535_22116/runtime_resources/pip/96e8d2e31c1b75b4d19a0ea2c755a672438fdca3/virtualenv/lib/python3.11/site-packages/levanter/layers/attention.py:428: UserWarning: transformer_engine is not installed. Please install it to use NVIDIA's optimized fused attention.. Falling back to the reference implementation.
169- # (train_lm_task pid=31054) warnings.warn(f"{msg}. Falling back to the reference implementation.")
170- # (train_lm_task pid=31054) E0916 11:23:11.594742 31054 buffer_comparator.cc:150] Difference at 10780: 16.375, expected 14.5
171- # (train_lm_task pid=31054) E0916 11:23:11.594787 31054 buffer_comparator.cc:150] Difference at 10942: 17.25, expected 15.25
172- # (train_lm_task pid=31054) E0916 11:23:11.594791 31054 buffer_comparator.cc:150] Difference at 11042: 17, expected 15.1875
173- # (train_lm_task pid=31054) E0916 11:23:11.594795 31054 buffer_comparator.cc:150] Difference at 11132: 16.875, expected 14.8125
174- # (train_lm_task pid=31054) E0916 11:23:11.594801 31054 buffer_comparator.cc:150] Difference at 12211: 15, expected 16.875
175- # (train_lm_task pid=31054) E0916 11:23:11.594804 31054 buffer_comparator.cc:150] Difference at 12212: 14.625, expected 16.625
176- # (train_lm_task pid=31054) E0916 11:23:11.594807 31054 buffer_comparator.cc:150] Difference at 12235: 14.75, expected 16.625
177- # (train_lm_task pid=31054) E0916 11:23:11.594809 31054 buffer_comparator.cc:150] Difference at 12276: 15.0625, expected 16.875
178- # (train_lm_task pid=31054) E0916 11:23:11.594812 31054 buffer_comparator.cc:150] Difference at 12327: 14.5, expected 16.25
179- # (train_lm_task pid=31054) E0916 11:23:11.594815 31054 buffer_comparator.cc:150] Difference at 12336: 15.5625, expected 17.5
180- # (train_lm_task pid=31054) 2025-09-16 11:23:11.594824: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1070] Results do not match the reference. This is likely a bug/unexpected loss of precision.
181- ```
182-
183- - Discuss these constant errors in deleting checkpoints:
184-
185- ```
186- train_lm_task pid=295292) 2025-09-26T19:19:35 - 0 - levanter.checkpoint - checkpoint.py:383 - INFO :: Saved checkpoint to local_store/checkpoints/plantcad-train-300m-r02-432442/checkpoints/step-20076 for step 20076
187- (train_lm_task pid=295292) 2025-09-26T19:19:35 - 0 - levanter.checkpoint - checkpoint.py:230 - INFO :: Deleting old temporary checkpoint local_store/checkpoints/plantcad-train-300m-r02-432442/checkpoints/step-20061 after saving new checkpoint.
188- (train_lm_task pid=295292) 2025-09-26T19:19:35 - 0 - levanter.checkpoint - checkpoint.py:262 - INFO :: Removing checkpoint local_store/checkpoints/plantcad-train-300m-r02-432442/checkpoints/step-20061
189- (train_lm_task pid=295292) 2025-09-26T19:19:35 - 0 - levanter.checkpoint - checkpoint.py:270 - INFO :: Deleting old checkpoint local_store/checkpoints/plantcad-train-300m-r02-432442/checkpoints/step-20061 from /home/sky/sky_workdir/local_store/checkpoints/plantcad-train-300m-r02-432442/checkpoints/local_store/checkpoints/plantcad-train-300m-r02-432442/checkpoints/step-20061
190- (train_lm_task pid=295292) 2025-09-26T19:19:35 - 0 - levanter.checkpoint - checkpoint.py:276 - ERROR :: Failed to delete checkpoint local_store/checkpoints/plantcad-train-300m-r02-432442/checkpoints/step-20061
191- (train_lm_task pid=295292) Traceback (most recent call last):
192- (train_lm_task pid=295292) File "/tmp/ray/session_2025-09-26_10-17-29_045489_276828/runtime_resources/pip/96e8d2e31c1b75b4d19a0ea2c755a672438fdca3/virtualenv/lib/python3.11/site-packages/levanter/checkpoint.py", line 272, in _do_rm_checkpoint
193- (train_lm_task pid=295292) fs.rm(cp_path, recursive=True)
194- (train_lm_task pid=295292) File "/tmp/ray/session_2025-09-26_10-17-29_045489_276828/runtime_resources/pip/96e8d2e31c1b75b4d19a0ea2c755a672438fdca3/virtualenv/lib/python3.11/site-packages/fsspec/implementations/local.py", line 191, in rm
195- (train_lm_task pid=295292) os.remove(p)
196- (train_lm_task pid=295292) FileNotFoundError: [Errno 2] No such file or directory: '/home/sky/sky_workdir/local_store/checkpoints/plantcad-train-300m-r02-432442/checkpoints/local_store/checkpoints/plantcad-train-300m-r02-432442/checkpoints/step-20061'
197- ```
0 commit comments