|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +import os |
| 16 | +import numpy as np |
15 | 17 | import pytest |
16 | 18 | import jax |
17 | 19 | import jax.numpy as jnp |
18 | 20 | import haliax as hax |
| 21 | +from datasets import load_dataset |
| 22 | +from huggingface_hub import snapshot_download |
| 23 | +from levanter.models.llama import LlamaConfig |
| 24 | +from levanter.utils.jax_utils import use_cpu_device |
| 25 | +from transformers import PretrainedConfig as HfConfig, AutoTokenizer |
19 | 26 | from experiments.plantcad.evaluation import ( |
20 | 27 | create_alternate_sequences, |
21 | 28 | compute_sequence_logprob, |
@@ -277,6 +284,68 @@ def test_compute_causal_conservation(): |
277 | 284 | ) |
278 | 285 |
|
279 | 286 |
|
def test_compute_causal_conservation_accuracy():
    """End-to-end parity test against reference scores.

    Reference scores come from https://github.com/Open-Athena/biofoundation/commit/23f6745defdd54cac09b43c066f249789bf74d56
    """
    # Fetch the reference fixture (model checkpoint + pre-scored sequences).
    fixture_root = snapshot_download(
        repo_id="plantcad/ci",
        repo_type="dataset",
        allow_patterns="unit_tests/evolutionary_constraint/ref_logprob_clm_sim/*",
    )
    dataset = load_dataset("plantcad/ci", name="ut_ec_ref_logprob_clm_sim", split="train")
    checkpoint_dir = os.path.join(fixture_root, "unit_tests/evolutionary_constraint/ref_logprob_clm_sim/model")

    # Build the Levanter model config and tokenizer from the HF checkpoint.
    hf_config = HfConfig.from_pretrained(checkpoint_dir)
    lev_config = LlamaConfig.from_hf_config(hf_config)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)

    # Tokenize the raw sequences; the column name differs between fixture versions.
    seq_column = "seq" if "seq" in dataset.column_names else "sequence"
    token_rows = [tokenizer(s, add_special_tokens=False)["input_ids"] for s in dataset[seq_column]]
    tokens = hax.named(jnp.array(np.asarray(token_rows, dtype=np.int32)), ("batch", "position"))
    nucleotide_positions = hax.named(
        jnp.array(np.asarray(dataset["pos"], dtype=np.int32)), ("batch",)
    )
    nucleotide_token_ids = [int(tokenizer.convert_tokens_to_ids(base)) for base in "ACGT"]

    # Load the checkpoint on CPU in float32 so the comparison is deterministic.
    converter = lev_config.hf_checkpoint_converter().replaced(
        reference_checkpoint=checkpoint_dir, tokenizer=tokenizer
    )
    with use_cpu_device():
        model = converter.load_pretrained(
            lev_config.model_type,
            ref=checkpoint_dir,
            resize_vocab_to_match_tokenizer=False,
            dtype=jnp.float32,
        )

    # Compute conservation scores with the freshly loaded model as the logit source.
    actual = compute_causal_conservation(
        tokens=tokens,
        logit_function=lambda x: model(x),
        nucleotide_positions=nucleotide_positions,
        nucleotide_token_ids=nucleotide_token_ids,
    )

    expected = np.asarray(dataset["score"], dtype=np.float32)
    observed = np.asarray(actual.array, dtype=np.float32)

    # Sanity checks: fixture size and finiteness on both sides.
    assert len(observed) == len(expected) == 8
    assert jnp.all(jnp.isfinite(actual.array))
    assert np.all(np.isfinite(expected))

    # Rankings must agree exactly; values must agree within tolerance.
    assert np.array_equal(np.argsort(-expected), np.argsort(-observed))
    np.testing.assert_allclose(observed, expected, rtol=1e-3, atol=1e-3)
| 348 | + |
280 | 349 | def _assert_batch_variants(alt_array, batch_idx, expected_variants, seq_length, batch_name): |
281 | 350 | """Helper to assert variant sequences match expected values for a batch.""" |
282 | 351 | for variant_idx in range(4): |
|
0 commit comments