Skip to content

Commit c3592a1

Browse files
committed
add finetune flag
1 parent e472d28 commit c3592a1

3 files changed

Lines changed: 18 additions & 5 deletions

File tree

olmoearth_pretrain/internal/all_evals.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def load_user_module(path: str) -> Any:
8080
num_workers=4,
8181
pooling_type=PoolingType.MEAN,
8282
norm_stats_from_pretrained=False,
83+
norm_method=NormMethod.NORM_NO_CLIP_2_STD,
8384
eval_interval=Duration.epochs(5),
8485
eval_mode=EvalMode.KNN,
8586
),
@@ -117,6 +118,7 @@ def load_user_module(path: str) -> Any:
117118
num_workers=2,
118119
pooling_type=PoolingType.MEAN,
119120
norm_stats_from_pretrained=False,
121+
norm_method=NormMethod.NORM_NO_CLIP_2_STD,
120122
probe_lr=0.1,
121123
eval_interval=Duration.epochs(10),
122124
eval_mode=EvalMode.LINEAR_PROBE,
@@ -127,7 +129,8 @@ def load_user_module(path: str) -> Any:
127129
probe_batch_size=8,
128130
num_workers=2,
129131
pooling_type=PoolingType.MEAN,
130-
norm_stats_from_pretrained=True,
132+
norm_stats_from_pretrained=False,
133+
norm_method=NormMethod.NORM_NO_CLIP_2_STD,
131134
probe_lr=0.1,
132135
eval_interval=Duration.epochs(10),
133136
eval_mode=EvalMode.LINEAR_PROBE,
@@ -139,6 +142,7 @@ def load_user_module(path: str) -> Any:
139142
num_workers=8,
140143
pooling_type=PoolingType.MEAN,
141144
norm_stats_from_pretrained=False,
145+
norm_method=NormMethod.NORM_NO_CLIP_2_STD,
142146
probe_lr=0.01,
143147
eval_interval=Duration.epochs(10),
144148
eval_mode=EvalMode.LINEAR_PROBE,
@@ -149,7 +153,7 @@ def load_user_module(path: str) -> Any:
149153
probe_batch_size=128,
150154
num_workers=4,
151155
pooling_type=PoolingType.MEAN,
152-
norm_stats_from_pretrained=False,
156+
norm_stats_from_pretrained=True,
153157
probe_lr=0.1,
154158
eval_interval=Duration.epochs(10),
155159
eval_mode=EvalMode.LINEAR_PROBE,

olmoearth_pretrain/internal/full_eval_sweep_finetune.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Launch fine-tune evaluation sweeps for OlmoEarth Pretrain checkpoints.
1+
"""Launch fine-tune evaluation sweeps for OlmoEarth and other models.
22
33
Example run:
44
python olmoearth_pretrain/internal/full_eval_sweep_finetune.py --project_name 2025_10_08_phase2_finetune --module_path olmoearth_pretrain/evals/models/clay/clay_launch.py --cluster ai2/titan --model_name clay --clay --defaults_only

scripts/get_max_eval_metrics_from_wandb.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212
MODELS_WITH_MULTIPLE_SIZES,
1313
BaselineModelName,
1414
)
15-
from olmoearth_pretrain.internal.all_evals import EVAL_TASKS
15+
from olmoearth_pretrain.internal.all_evals import EVAL_TASKS, FT_EVAL_TASKS
1616
from olmoearth_pretrain.train.callbacks.evaluator_callback import EvalMode
1717

1818
WANDB_ENTITY = "eai-ai2"
19-
METRICS = EVAL_TASKS.keys()
19+
METRICS = list(EVAL_TASKS.keys())
2020

2121
# Dataset partitions to consider (excluding default)
2222
PARTITIONS = [
@@ -379,9 +379,18 @@ def save_metrics_to_csv(metrics_dict: dict[str, dict[str, float]], filename: str
379379
action="store_true",
380380
help="Report test metrics based on the configuration of the validation results with the highest score",
381381
)
382+
parser.add_argument(
383+
"--finetune",
384+
action="store_true",
385+
help="Use finetune evaluation tasks when determining metrics",
386+
)
382387

383388
args = parser.parse_args()
384389

390+
global METRICS
391+
selected_tasks = FT_EVAL_TASKS if args.finetune else EVAL_TASKS
392+
METRICS = list(selected_tasks.keys())
393+
385394
if args.per_partition:
386395
if not args.run_prefix:
387396
parser.error("--per-partition requires run_prefix to be specified")

0 commit comments

Comments (0)