|
| 1 | +# Copyright The Marin Authors |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +from fray.v2.types import ResourceConfig |
| 5 | + |
| 6 | +from experiments.defaults import default_raw_validation_sets |
| 7 | +from experiments.evals.fineweb2_multilingual import fineweb2_multilingual_raw_validation_sets |
| 8 | +from marin.evaluation.perplexity_gap import GapFinderModelConfig, default_model_perplexity_gap |
| 9 | +from marin.execution.executor import executor_main |
| 10 | + |
| 11 | +RESOURCE_CONFIG = ResourceConfig.with_tpu("v5p-8", regions=["us-central1"]) |
| 12 | +MAX_DOCS_PER_DATASET = 256 |
| 13 | +MAX_DOC_BYTES = 32_768 |
| 14 | + |
| 15 | +DATASETS = { |
| 16 | + **default_raw_validation_sets(), |
| 17 | + **fineweb2_multilingual_raw_validation_sets(), |
| 18 | +} |
| 19 | + |
| 20 | +MARIN_MODEL = GapFinderModelConfig( |
| 21 | + checkpoint_path="marin-community/marin-8b-base", |
| 22 | + checkpoint_is_hf=True, |
| 23 | + tokenizer="meta-llama/Llama-3.1-8B", |
| 24 | +) |
| 25 | + |
| 26 | +MARIN_VS_LLAMA = default_model_perplexity_gap( |
| 27 | + name="fineweb2-multilingual-marin-8b-base-vs-llama-3.1-8b-base-doccap256", |
| 28 | + model_a=MARIN_MODEL, |
| 29 | + model_b=GapFinderModelConfig( |
| 30 | + checkpoint_path="meta-llama/Llama-3.1-8B", |
| 31 | + checkpoint_is_hf=True, |
| 32 | + tokenizer="meta-llama/Llama-3.1-8B", |
| 33 | + ), |
| 34 | + datasets=DATASETS, |
| 35 | + resource_config=RESOURCE_CONFIG, |
| 36 | + per_device_batch_size=4, |
| 37 | + max_eval_length=4096, |
| 38 | + max_docs_per_dataset=MAX_DOCS_PER_DATASET, |
| 39 | + max_doc_bytes=MAX_DOC_BYTES, |
| 40 | + wandb_tags=[ |
| 41 | + "eval=perplexity-gap", |
| 42 | + "rerun=fineweb2-multilingual", |
| 43 | + "model_a=marin-community/marin-8b-base", |
| 44 | + "model_b=meta-llama/Llama-3.1-8B", |
| 45 | + "dataset_bundle=default_raw_plus_fineweb2_multilingual", |
| 46 | + "region=us-central1", |
| 47 | + f"max_docs_per_dataset={MAX_DOCS_PER_DATASET}", |
| 48 | + ], |
| 49 | +) |
| 50 | + |
| 51 | +MARIN_VS_QWEN3 = default_model_perplexity_gap( |
| 52 | + name="fineweb2-multilingual-marin-8b-base-vs-qwen3-8b-base-doccap256", |
| 53 | + model_a=MARIN_MODEL, |
| 54 | + model_b=GapFinderModelConfig( |
| 55 | + checkpoint_path="Qwen/Qwen3-8B-Base", |
| 56 | + checkpoint_is_hf=True, |
| 57 | + tokenizer="Qwen/Qwen3-8B", |
| 58 | + ), |
| 59 | + datasets=DATASETS, |
| 60 | + resource_config=RESOURCE_CONFIG, |
| 61 | + per_device_batch_size=4, |
| 62 | + max_eval_length=4096, |
| 63 | + max_docs_per_dataset=MAX_DOCS_PER_DATASET, |
| 64 | + max_doc_bytes=MAX_DOC_BYTES, |
| 65 | + wandb_tags=[ |
| 66 | + "eval=perplexity-gap", |
| 67 | + "rerun=fineweb2-multilingual", |
| 68 | + "model_a=marin-community/marin-8b-base", |
| 69 | + "model_b=Qwen/Qwen3-8B-Base", |
| 70 | + "dataset_bundle=default_raw_plus_fineweb2_multilingual", |
| 71 | + "region=us-central1", |
| 72 | + f"max_docs_per_dataset={MAX_DOCS_PER_DATASET}", |
| 73 | + ], |
| 74 | +) |
| 75 | + |
| 76 | + |
| 77 | +if __name__ == "__main__": |
| 78 | + executor_main( |
| 79 | + [MARIN_VS_LLAMA, MARIN_VS_QWEN3], |
| 80 | + description="Run Marin perplexity-gap reports with FineWeb2 multilingual held-out eval sets.", |
| 81 | + ) |
0 commit comments