Skip to content

Commit 80c49b4

Browse files
committed
Add plugin implementation for conservation eval during training
1 parent 1fa3a22 commit 80c49b4

File tree

8 files changed: +396 −190 lines changed

8 files changed: +396 −190 lines changed

experiments/defaults.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,7 @@ def default_train(
386386
data_seed=train_config.data_seed,
387387
eval_harness_steps=train_config.steps_per_task_eval or 10000,
388388
eval_harness=harness_config,
389+
eval_plugins=train_config.eval_plugins,
389390
)
390391

391392
# Create the pod config

experiments/plantcad/evaluation.py

Lines changed: 50 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import json
2323
import dataclasses
2424
from dataclasses import dataclass
25-
from typing import Any
2625
from collections.abc import Callable
2726
from datasets import Dataset
2827

@@ -38,7 +37,6 @@
3837
from huggingface_hub import HfApi
3938
from transformers import AutoModelForCausalLM
4039
from levanter.callbacks import StepInfo
41-
from levanter.utils.tree_utils import inference_mode
4240
from marin.utilities.json_encoder import CustomJsonEncoder
4341

4442
from experiments.plantcad.utils import get_available_gpus, get_nucleotide_token_ids, get_plantcad_tokenizer
@@ -48,21 +46,12 @@
4846

4947

5048
@dataclass
51-
class DnaEvalConfig:
52-
"""Configuration for DNA model evolutionary conservation evaluation"""
53-
54-
checkpoint_path: str | InputName
55-
"""Path to the model checkpoint directory"""
49+
class DnaEvalBaseConfig:
50+
"""Base configuration for DNA evaluation with fields needed for training callbacks"""
5651

5752
model_config: str
5853
"""Model configuration size (e.g., '300m', '100m', etc.)"""
5954

60-
device: str = "cuda"
61-
"""Device to use for model inference (e.g., 'cuda', 'cpu')"""
62-
63-
dtype: str | None = None
64-
"""Dtype to use for model inference (e.g., 'float32', 'float16', 'bfloat16' or any torch dtype)"""
65-
6655
dataset_path: str = "plantcad/evolutionary-constraint-example"
6756
"""Dataset repository path"""
6857

@@ -75,15 +64,29 @@ class DnaEvalConfig:
7564
batch_size: int = 32
7665
"""Batch size to use for inference"""
7766

78-
num_workers: int | None = None
79-
"""Number of workers to use for parallel evaluation (defaults to number of GPUs if None)"""
80-
8167
max_samples: int | None = None
8268
"""Maximum number of samples to evaluate (for quick testing)"""
8369

84-
random_seed: int = versioned(42)
70+
random_seed: int = 42
8571
"""Random seed for data shuffling prior to downsampling"""
8672

73+
74+
@dataclass
75+
class DnaEvalConfig(DnaEvalBaseConfig):
76+
"""Configuration for standalone DNA model evolutionary conservation evaluation"""
77+
78+
checkpoint_path: str | InputName | None = None
79+
"""Path to the model checkpoint directory (None for training callbacks)"""
80+
81+
device: str = "cuda"
82+
"""Device to use for model inference (e.g., 'cuda', 'cpu')"""
83+
84+
dtype: str | None = None
85+
"""Dtype to use for model inference (e.g., 'float32', 'float16', 'bfloat16' or any torch dtype)"""
86+
87+
num_workers: int | None = None
88+
"""Number of workers to use for parallel evaluation (defaults to number of GPUs if None)"""
89+
8790
revision: str = versioned("0.1")
8891
"""Revision number to force re-runs when needed"""
8992

@@ -406,6 +409,7 @@ def score_eval_dataset(
406409
eval_dataset: Dataset,
407410
logit_function: Callable[[TokenArray], LogitArray],
408411
batch_size: int = 32,
412+
log_progress: bool = True,
409413
) -> ConservationResult:
410414
"""Score evaluation dataset based on zero-shot conservation prediction."""
411415

@@ -420,7 +424,8 @@ def score_eval_dataset(
420424
batches = eval_dataset.with_format(None).batch(batch_size=batch_size)
421425
total_batches = len(batches)
422426
progress_interval = max(1, total_batches // 20) # Every 5%
423-
logger.info(f"Processing {len(eval_dataset)} samples in {total_batches} batches (batch_size={batch_size})")
427+
if log_progress:
428+
logger.info(f"Processing {len(eval_dataset)} samples in {total_batches} batches (batch_size={batch_size})")
424429

425430
for batch_index, batch_data in enumerate(batches):
426431
# Tokenize sequences
@@ -451,7 +456,7 @@ def score_eval_dataset(
451456
total_processed += len(sequences)
452457

453458
# Log progress every 5% of batches
454-
if batch_index % progress_interval == 0 or batch_index == total_batches - 1:
459+
if log_progress and (batch_index % progress_interval == 0 or batch_index == total_batches - 1):
455460
progress_pct = ((batch_index + 1) / total_batches) * 100
456461
logger.info(
457462
f"Progress: {batch_index + 1}/{total_batches} batches ({progress_pct:.1f}%) - "
@@ -466,44 +471,7 @@ def score_eval_dataset(
466471
# ------------------------------------------------------------------------------------------------
467472

468473

469-
def evaluate_dna_conservation(
470-
tokenizer: AutoTokenizer,
471-
logit_function: Callable[[Any], Any],
472-
eval_dataset: Dataset,
473-
batch_size: int = 32,
474-
step: int | None = None,
475-
) -> dict[str, float]:
476-
"""
477-
Core evaluation logic - works for both training callbacks and standalone evaluation.
478-
479-
Args:
480-
logit_function: Function that takes tokens and returns logits
481-
eval_dataset: HuggingFace dataset with 'seq' field and binary 'label' field
482-
batch_size: Batch size for evaluation
483-
step: Training step (for logging), None for standalone
484-
485-
Returns:
486-
Dictionary with evaluation metrics including ROC AUC
487-
"""
488-
# Collect scores and labels using shared function
489-
result = score_eval_dataset(
490-
tokenizer=tokenizer, logit_function=logit_function, eval_dataset=eval_dataset, batch_size=batch_size
491-
)
492-
493-
# Calculate metrics using shared function
494-
results = evaluate_conservation_scores(result)
495-
496-
# Log during training, log for standalone
497-
if step is not None:
498-
levanter.tracker.log({"eval/dna_conservation/roc": results["roc_auc"]}, step=step)
499-
logger.info(f"Step {step}: ROC AUC = {results['roc_auc']:.3f}")
500-
else:
501-
logger.info(f"ROC AUC = {results['roc_auc']:.4f} ({results['n_total']} valid nucleotides)")
502-
503-
return results
504-
505-
506-
def create_dna_eval_callback(config: DnaEvalConfig) -> Callable[[StepInfo], None]:
474+
def create_dna_eval_callback(config: DnaEvalBaseConfig) -> Callable[[StepInfo], None]:
507475
"""Create a training callback for DNA evaluation."""
508476

509477
# Load tokenizer
@@ -514,25 +482,39 @@ def create_dna_eval_callback(config: DnaEvalConfig) -> Callable[[StepInfo], None
514482
dataset = load_eval_dataset(config)
515483

516484
def dna_conservation_callback(step_info: StepInfo) -> None:
517-
# Put model in inference mode
518-
eval_model = inference_mode(step_info.state.model, True)
485+
logger.info(f"Running PlantCAD DNA conservation evaluation (step={step_info.step})")
486+
eval_model = step_info.state.eval_model
519487

520488
# Create logit function for Levanter model
521489
def logit_function(
522490
tokens: ht.Int[ht.NamedArray, "batch position"],
523491
) -> ht.Float[ht.NamedArray, "batch position vocab"]:
524-
# TODO: validate input / output types
525492
return eval_model(tokens)
526493

527-
# Run evaluation
528-
evaluate_dna_conservation(
494+
# Compute scores with binary labels
495+
scores = score_eval_dataset(
529496
tokenizer=tokenizer,
530497
logit_function=logit_function,
531-
eval_dataset=dataset, # Use the loaded dataset
498+
eval_dataset=dataset,
532499
batch_size=config.batch_size,
500+
log_progress=False,
501+
)
502+
503+
# Evaluate scores and labels
504+
metrics = evaluate_conservation_scores(scores)
505+
506+
# Log results
507+
levanter.tracker.log(
508+
{
509+
"eval/dna_conservation_roc": metrics["roc_auc"],
510+
},
533511
step=step_info.step,
534512
)
535513

514+
logger.info(
515+
f"PlantCAD evaluation complete: ROC AUC = {metrics['roc_auc']:.4f}, " f"n_samples = {metrics['n_total']}"
516+
)
517+
536518
return dna_conservation_callback
537519

538520

@@ -614,7 +596,11 @@ def logit_function(
614596

615597
# Generate raw conservation scores and labels
616598
result = score_eval_dataset(
617-
tokenizer=tokenizer, logit_function=logit_function, eval_dataset=dataset, batch_size=config.batch_size
599+
tokenizer=tokenizer,
600+
logit_function=logit_function,
601+
eval_dataset=dataset,
602+
batch_size=config.batch_size,
603+
log_progress=True,
618604
)
619605

620606
logger.info(f"Generated {len(result.scores)} conservation scores")
@@ -640,10 +626,7 @@ def evaluate_conservation_scores(scores: ConservationResult) -> dict[str, float]
640626
if len(scores.scores) == 0:
641627
raise ValueError("No valid conservation scores found")
642628

643-
# Log total before filtering and filter out NaN scores
644629
n_unmasked_total = len(scores.scores)
645-
logger.info(f"n_unmasked_total: {n_unmasked_total}")
646-
647630
valid_mask = ~np.isnan(scores.scores)
648631
filtered_scores = np.array(scores.scores)[valid_mask]
649632
filtered_labels = np.array(scores.labels)[valid_mask]

0 commit comments

Comments (0)