2222import json
2323import dataclasses
2424from dataclasses import dataclass
25- from typing import Any
2625from collections .abc import Callable
2726from datasets import Dataset
2827
3837from huggingface_hub import HfApi
3938from transformers import AutoModelForCausalLM
4039from levanter .callbacks import StepInfo
41- from levanter .utils .tree_utils import inference_mode
4240from marin .utilities .json_encoder import CustomJsonEncoder
4341
4442from experiments .plantcad .utils import get_available_gpus , get_nucleotide_token_ids , get_plantcad_tokenizer
4846
4947
5048@dataclass
51- class DnaEvalConfig :
52- """Configuration for DNA model evolutionary conservation evaluation"""
53-
54- checkpoint_path : str | InputName
55- """Path to the model checkpoint directory"""
49+ class DnaEvalBaseConfig :
50+ """Base configuration for DNA evaluation with fields needed for training callbacks"""
5651
5752 model_config : str
5853 """Model configuration size (e.g., '300m', '100m', etc.)"""
5954
60- device : str = "cuda"
61- """Device to use for model inference (e.g., 'cuda', 'cpu')"""
62-
63- dtype : str | None = None
64- """Dtype to use for model inference (e.g., 'float32', 'float16', 'bfloat16' or any torch dtype)"""
65-
6655 dataset_path : str = "plantcad/evolutionary-constraint-example"
6756 """Dataset repository path"""
6857
@@ -75,15 +64,29 @@ class DnaEvalConfig:
7564 batch_size : int = 32
7665 """Batch size to use for inference"""
7766
78- num_workers : int | None = None
79- """Number of workers to use for parallel evaluation (defaults to number of GPUs if None)"""
80-
8167 max_samples : int | None = None
8268 """Maximum number of samples to evaluate (for quick testing)"""
8369
84- random_seed : int = versioned ( 42 )
70+ random_seed : int = 42
8571 """Random seed for data shuffling prior to downsampling"""
8672
73+
74+ @dataclass
75+ class DnaEvalConfig (DnaEvalBaseConfig ):
76+ """Configuration for standalone DNA model evolutionary conservation evaluation"""
77+
78+ checkpoint_path : str | InputName | None = None
79+ """Path to the model checkpoint directory (None for training callbacks)"""
80+
81+ device : str = "cuda"
82+ """Device to use for model inference (e.g., 'cuda', 'cpu')"""
83+
84+ dtype : str | None = None
85+ """Dtype to use for model inference (e.g., 'float32', 'float16', 'bfloat16' or any torch dtype)"""
86+
87+ num_workers : int | None = None
88+ """Number of workers to use for parallel evaluation (defaults to number of GPUs if None)"""
89+
8790 revision : str = versioned ("0.1" )
8891 """Revision number to force re-runs when needed"""
8992
@@ -406,6 +409,7 @@ def score_eval_dataset(
406409 eval_dataset : Dataset ,
407410 logit_function : Callable [[TokenArray ], LogitArray ],
408411 batch_size : int = 32 ,
412+ log_progress : bool = True ,
409413) -> ConservationResult :
410414 """Score evaluation dataset based on zero-shot conservation prediction."""
411415
@@ -420,7 +424,8 @@ def score_eval_dataset(
420424 batches = eval_dataset .with_format (None ).batch (batch_size = batch_size )
421425 total_batches = len (batches )
422426 progress_interval = max (1 , total_batches // 20 ) # Every 5%
423- logger .info (f"Processing { len (eval_dataset )} samples in { total_batches } batches (batch_size={ batch_size } )" )
427+ if log_progress :
428+ logger .info (f"Processing { len (eval_dataset )} samples in { total_batches } batches (batch_size={ batch_size } )" )
424429
425430 for batch_index , batch_data in enumerate (batches ):
426431 # Tokenize sequences
@@ -451,7 +456,7 @@ def score_eval_dataset(
451456 total_processed += len (sequences )
452457
453458 # Log progress every 5% of batches
454- if batch_index % progress_interval == 0 or batch_index == total_batches - 1 :
459+ if log_progress and ( batch_index % progress_interval == 0 or batch_index == total_batches - 1 ) :
455460 progress_pct = ((batch_index + 1 ) / total_batches ) * 100
456461 logger .info (
457462 f"Progress: { batch_index + 1 } /{ total_batches } batches ({ progress_pct :.1f} %) - "
@@ -466,44 +471,7 @@ def score_eval_dataset(
466471# ------------------------------------------------------------------------------------------------
467472
468473
469- def evaluate_dna_conservation (
470- tokenizer : AutoTokenizer ,
471- logit_function : Callable [[Any ], Any ],
472- eval_dataset : Dataset ,
473- batch_size : int = 32 ,
474- step : int | None = None ,
475- ) -> dict [str , float ]:
476- """
477- Core evaluation logic - works for both training callbacks and standalone evaluation.
478-
479- Args:
480- logit_function: Function that takes tokens and returns logits
481- eval_dataset: HuggingFace dataset with 'seq' field and binary 'label' field
482- batch_size: Batch size for evaluation
483- step: Training step (for logging), None for standalone
484-
485- Returns:
486- Dictionary with evaluation metrics including ROC AUC
487- """
488- # Collect scores and labels using shared function
489- result = score_eval_dataset (
490- tokenizer = tokenizer , logit_function = logit_function , eval_dataset = eval_dataset , batch_size = batch_size
491- )
492-
493- # Calculate metrics using shared function
494- results = evaluate_conservation_scores (result )
495-
496- # Log during training, log for standalone
497- if step is not None :
498- levanter .tracker .log ({"eval/dna_conservation/roc" : results ["roc_auc" ]}, step = step )
499- logger .info (f"Step { step } : ROC AUC = { results ['roc_auc' ]:.3f} " )
500- else :
501- logger .info (f"ROC AUC = { results ['roc_auc' ]:.4f} ({ results ['n_total' ]} valid nucleotides)" )
502-
503- return results
504-
505-
506- def create_dna_eval_callback (config : DnaEvalConfig ) -> Callable [[StepInfo ], None ]:
474+ def create_dna_eval_callback (config : DnaEvalBaseConfig ) -> Callable [[StepInfo ], None ]:
507475 """Create a training callback for DNA evaluation."""
508476
509477 # Load tokenizer
@@ -514,25 +482,39 @@ def create_dna_eval_callback(config: DnaEvalConfig) -> Callable[[StepInfo], None
514482 dataset = load_eval_dataset (config )
515483
516484 def dna_conservation_callback (step_info : StepInfo ) -> None :
517- # Put model in inference mode
518- eval_model = inference_mode ( step_info .state .model , True )
485+ logger . info ( f"Running PlantCAD DNA conservation evaluation (step= { step_info . step } )" )
486+ eval_model = step_info .state .eval_model
519487
520488 # Create logit function for Levanter model
521489 def logit_function (
522490 tokens : ht .Int [ht .NamedArray , "batch position" ],
523491 ) -> ht .Float [ht .NamedArray , "batch position vocab" ]:
524- # TODO: validate input / output types
525492 return eval_model (tokens )
526493
527- # Run evaluation
528- evaluate_dna_conservation (
494+ # Compute scores with binary labels
495+ scores = score_eval_dataset (
529496 tokenizer = tokenizer ,
530497 logit_function = logit_function ,
531- eval_dataset = dataset , # Use the loaded dataset
498+ eval_dataset = dataset ,
532499 batch_size = config .batch_size ,
500+ log_progress = False ,
501+ )
502+
503+ # Evaluate scores and labels
504+ metrics = evaluate_conservation_scores (scores )
505+
506+ # Log results
507+ levanter .tracker .log (
508+ {
509+ "eval/dna_conservation_roc" : metrics ["roc_auc" ],
510+ },
533511 step = step_info .step ,
534512 )
535513
514+ logger .info (
515+ f"PlantCAD evaluation complete: ROC AUC = { metrics ['roc_auc' ]:.4f} , " f"n_samples = { metrics ['n_total' ]} "
516+ )
517+
536518 return dna_conservation_callback
537519
538520
@@ -614,7 +596,11 @@ def logit_function(
614596
615597 # Generate raw conservation scores and labels
616598 result = score_eval_dataset (
617- tokenizer = tokenizer , logit_function = logit_function , eval_dataset = dataset , batch_size = config .batch_size
599+ tokenizer = tokenizer ,
600+ logit_function = logit_function ,
601+ eval_dataset = dataset ,
602+ batch_size = config .batch_size ,
603+ log_progress = True ,
618604 )
619605
620606 logger .info (f"Generated { len (result .scores )} conservation scores" )
@@ -640,10 +626,7 @@ def evaluate_conservation_scores(scores: ConservationResult) -> dict[str, float]
640626 if len (scores .scores ) == 0 :
641627 raise ValueError ("No valid conservation scores found" )
642628
643- # Log total before filtering and filter out NaN scores
644629 n_unmasked_total = len (scores .scores )
645- logger .info (f"n_unmasked_total: { n_unmasked_total } " )
646-
647630 valid_mask = ~ np .isnan (scores .scores )
648631 filtered_scores = np .array (scores .scores )[valid_mask ]
649632 filtered_labels = np .array (scores .labels )[valid_mask ]
0 commit comments