Skip to content

Commit 33c5dc6

Browse files
committed
Add test for CLM score parity to biofoundation
1 parent 2537367 commit 33c5dc6

File tree

2 files changed

+74
-5
lines changed

2 files changed

+74
-5
lines changed

experiments/plantcad/evaluation.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -292,12 +292,12 @@ def create_alternate_sequences(
292292
assert 0 <= ref_cts.max().item() <= 1
293293
if (invalid := ref_cts == 0).any().item():
294294
pos = nucleotide_positions[Batch, invalid]
295-
tok = tokens_expanded[Batch, invalid, Position, pos]
295+
tok = tokens_expanded[Batch, invalid][Position, pos]
296296
raise ValueError(
297-
"Found invalid sequences in batch with OOV nucleotides at target positions; "
298-
f"Target positions: {pos} "
299-
f"Valid nucleotide token IDs: {nucleotide_token_ids} "
300-
f"Invalid tokens: {tok} "
297+
"Found invalid sequences in batch with OOV nucleotides at target positions;\n"
298+
f"Target positions: {pos.array} \n"
299+
f"Valid nucleotide token IDs: {nucleotide_token_ids} \n"
300+
f"Invalid tokens: {tok.array} "
301301
)
302302
ref = hax.argmax(ref_mask, axis=Variant)
303303
assert ref.axes == (Batch,)

experiments/plantcad/tests/test_evaluation.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,17 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import os
16+
import numpy as np
1517
import pytest
1618
import jax
1719
import jax.numpy as jnp
1820
import haliax as hax
21+
from datasets import load_dataset
22+
from huggingface_hub import snapshot_download
23+
from levanter.models.llama import LlamaConfig
24+
from levanter.utils.jax_utils import use_cpu_device
25+
from transformers import PretrainedConfig as HfConfig, AutoTokenizer
1926
from experiments.plantcad.evaluation import (
2027
create_alternate_sequences,
2128
compute_sequence_logprob,
@@ -277,6 +284,68 @@ def test_compute_causal_conservation():
277284
)
278285

279286

287+
def test_compute_causal_conservation_accuracy():
288+
"""End-to-end parity test against reference scores.
289+
290+
Reference scores come from https://github.com/Open-Athena/biofoundation/commit/23f6745defdd54cac09b43c066f249789bf74d56
291+
"""
292+
# Download model and dataset
293+
data_path = snapshot_download(
294+
repo_id="plantcad/ci",
295+
repo_type="dataset",
296+
allow_patterns="unit_tests/evolutionary_constraint/ref_logprob_clm_sim/*",
297+
)
298+
ds = load_dataset("plantcad/ci", name="ut_ec_ref_logprob_clm_sim", split="train")
299+
model_dir = os.path.join(data_path, "unit_tests/evolutionary_constraint/ref_logprob_clm_sim/model")
300+
301+
# Load tokenizer and config
302+
hf_config = HfConfig.from_pretrained(model_dir)
303+
config = LlamaConfig.from_hf_config(hf_config)
304+
tokenizer = AutoTokenizer.from_pretrained(model_dir)
305+
306+
# Load sequences and positions
307+
sequences = ds["seq"] if "seq" in ds.column_names else ds["sequence"]
308+
positions = np.asarray(ds["pos"], dtype=np.int32)
309+
tokens_np = np.asarray([tokenizer(s, add_special_tokens=False)["input_ids"] for s in sequences], dtype=np.int32)
310+
tokens = hax.named(jnp.array(tokens_np), ("batch", "position"))
311+
nucleotide_positions = hax.named(jnp.array(positions), ("batch",))
312+
nucleotide_token_ids = [int(tokenizer.convert_tokens_to_ids(nt)) for nt in "ACGT"]
313+
314+
# Load model
315+
converter = config.hf_checkpoint_converter().replaced(reference_checkpoint=model_dir, tokenizer=tokenizer)
316+
with use_cpu_device():
317+
model = converter.load_pretrained(
318+
config.model_type,
319+
ref=model_dir,
320+
resize_vocab_to_match_tokenizer=False,
321+
dtype=jnp.float32,
322+
)
323+
324+
def logit_fn(x):
325+
return model(x)
326+
327+
# Compute conservation scores
328+
actual = compute_causal_conservation(
329+
tokens=tokens,
330+
logit_function=logit_fn,
331+
nucleotide_positions=nucleotide_positions,
332+
nucleotide_token_ids=nucleotide_token_ids,
333+
)
334+
335+
# Compare with expected scores
336+
expected = np.asarray(ds["score"], dtype=np.float32)
337+
our_scores_np = np.asarray(actual.array, dtype=np.float32)
338+
339+
assert len(our_scores_np) == len(expected) == 8
340+
assert jnp.all(jnp.isfinite(actual.array))
341+
assert np.all(np.isfinite(expected))
342+
343+
# Order parity
344+
assert np.array_equal(np.argsort(-expected), np.argsort(-our_scores_np))
345+
# Value parity within tolerance
346+
np.testing.assert_allclose(our_scores_np, expected, rtol=1e-3, atol=1e-3)
347+
348+
280349
def _assert_batch_variants(alt_array, batch_idx, expected_variants, seq_length, batch_name):
281350
"""Helper to assert variant sequences match expected values for a batch."""
282351
for variant_idx in range(4):

0 commit comments

Comments (0)