|
| 1 | +from typing import List |
| 2 | + |
# Language pairs under evaluation, in "<src>-<tgt>_<REGION>" form
# (e.g. "en-de_DE" = English -> German as written in Germany).  These strings
# are used as config names for the google/wmt24pp dataset in load_data().
EVAL_PAIRS = (
    "en-ar_EG",
    "en-ar_SA",
    "en-bg_BG",
    "en-bn_IN",
    "en-ca_ES",
    "en-cs_CZ",
    "en-da_DK",
    "en-de_DE",
    "en-el_GR",
    "en-es_MX",
    "en-et_EE",
    "en-fa_IR",
    "en-fi_FI",
    "en-fil_PH",
    "en-fr_CA",
    "en-fr_FR",
    "en-gu_IN",
    "en-he_IL",
    "en-hi_IN",
    "en-hr_HR",
    "en-hu_HU",
    "en-id_ID",
    "en-is_IS",
    "en-it_IT",
    "en-ja_JP",
    "en-kn_IN",
    "en-ko_KR",
    "en-lt_LT",
    "en-lv_LV",
    "en-ml_IN",
    "en-mr_IN",
    "en-nl_NL",
    "en-no_NO",
    "en-pa_IN",
    "en-pl_PL",
    "en-pt_BR",
    "en-pt_PT",
    "en-ro_RO",
    "en-ru_RU",
    "en-sk_SK",
    "en-sl_SI",
    "en-sr_RS",
    "en-sv_SE",
    "en-sw_KE",
    "en-sw_TZ",
    "en-ta_IN",
    "en-te_IN",
    "en-th_TH",
    "en-tr_TR",
    "en-uk_UA",
    "en-ur_PK",
    "en-vi_VN",
    "en-zh_CN",
    "en-zh_TW",
    "en-zu_ZA",
)
| 60 | + |
| 61 | + |
# Map a bare target-language code (e.g. "de") to its full evaluation pair
# (e.g. "en-de_DE").  Secondary regional variants (zh_TW, pt_PT, fr_CA,
# ar_EG, sw_TZ) are skipped so each language code maps to exactly one pair.
lang_map = {
    p.partition("-")[2].partition("_")[0]: p
    for p in EVAL_PAIRS
    if p.partition("_")[2] not in ("TW", "PT", "CA", "EG", "TZ")
}
| 67 | + |
| 68 | + |
def load_data(lang: str):
    """Download the WMT24++ test set for one target language.

    Args:
        lang: target-language suffix of a pair, e.g. "de_DE"; the dataset
            config name is built as "en-<lang>".

    Returns:
        A (sources, targets) pair of parallel lists of sentence strings,
        restricted to well-formed examples of exactly this language pair.
    """
    from datasets import load_dataset

    # Login using e.g. `huggingface-cli login` to access this dataset.
    print(f"Downloading dataset for {lang}")
    lp = f"en-{lang}"
    ds = load_dataset("google/wmt24pp", lp)
    # Drop flagged bad sources and anything not belonging to this pair.
    filtered = ds.filter(lambda ex: not ex["is_bad_source"] and ex["lp"] == lp)["train"]
    return filtered["source"], filtered["target"]
| 82 | + |
| 83 | + |
def eval_comet(source_texts, target_translations, target_references, gpus=1):
    """Score translations with COMET (Unbabel/wmt22-comet-da).

    Args:
        source_texts: source sentences.
        target_translations: system outputs, aligned with source_texts.
        target_references: reference translations, aligned with source_texts.
        gpus: number of GPUs for COMET prediction (default 1, matching the
            previously hard-coded value).

    Returns:
        Corpus-level COMET system score scaled to 0-100, rounded to 2 decimals.
    """
    import comet

    comet_checkpoint = comet.download_model("Unbabel/wmt22-comet-da")
    comet_model = comet.load_from_checkpoint(comet_checkpoint)
    # COMET expects one {"src", "mt", "ref"} dict per segment.
    comet_data = [
        {"src": src, "mt": mt, "ref": ref}
        for src, mt, ref in zip(source_texts, target_translations, target_references)
    ]
    comet_results = comet_model.predict(comet_data, gpus=gpus)
    return round(comet_results.system_score * 100, 2)
| 94 | + |
| 95 | + |
def eval_metricx(
    source_texts,
    target_translations,
    target_references,
    model_size="xl",
    fp16=True,
    batch_size=8,
):
    """Score translations with MetricX-24, reference-based and reference-free.

    https://huggingface.co/google/metricx-24-hybrid-xxl-v2p6

    Available model sizes: "large" (1.2B), "xl" (3.7B), "xxl" (13b)

    Args:
        source_texts: source sentences.
        target_translations: system outputs, aligned with source_texts.
        target_references: reference translations, aligned with source_texts.
        model_size: MetricX checkpoint size, see above.
        fp16: use the reduced-precision checkpoint (NOTE: the checkpoint
            suffix is "-bfloat16", despite this parameter's name).
        batch_size: prediction batch size; it is divided by the number of
            GPUs, so set it equal to or higher than the GPU count.

    Returns:
        Dict with the mean segment-level predictions for both modes, keyed
        "metricx24-<size>-qe" (reference-free) and "metricx24-<size>".
    """

    import json
    from statistics import mean
    from metricx.predict import predict

    # MetricX reads its inputs from a JSONL file, one example per line.
    with open("input.jsonl", "w") as in_file:
        for source, target, target_ref in zip(
            source_texts, target_translations, target_references
        ):
            ex_dict = {"source": source, "reference": target_ref, "hypothesis": target}
            in_file.write(json.dumps(ex_dict) + "\n")

    model_name = f"google/metricx-24-hybrid-{model_size}-v2p6"
    if fp16:
        model_name += "-bfloat16"

    def _predict_and_average(output_file, qe):
        """Run one MetricX pass and return the mean segment prediction."""
        mode = "reference free QE" if qe else "reference based"
        print(f"Running evaluation with {model_name} {mode}")
        # batch size is divided by number of GPUs, set equal or higher
        predict(
            tokenizer=f"google/mt5-{model_size}",
            model_name_or_path=model_name,
            max_input_length=1536,
            batch_size=batch_size,
            input_file="input.jsonl",
            output_file=output_file,
            qe=qe,
        )
        with open(output_file) as out:
            return mean(float(json.loads(line)["prediction"]) for line in out)

    ref_score = _predict_and_average("output.ref.jsonl", qe=False)
    qe_score = _predict_and_average("output.qe.jsonl", qe=True)

    return {f"metricx24-{model_size}-qe": qe_score, f"metricx24-{model_size}": ref_score}
| 154 | + |
| 155 | + |
def select_best(
    source: List[str], translations: List[List[str]], model_size="xl", fp16=True, batch_size=8
) -> List[str]:
    """Pick the best candidate translation per source via MetricX-24 QE.

    Every (source, candidate) pair is scored with the reference-free (QE)
    MetricX model, and for each source the candidate with the LOWEST score is
    kept — MetricX predicts an error score, so lower is better.

    Args:
        source: source sentences.
        translations: one list of candidate translations per source sentence;
            candidate lists may have different lengths.
        model_size: MetricX checkpoint size ("large", "xl", "xxl").
        fp16: use the reduced-precision ("-bfloat16") checkpoint.
        batch_size: prediction batch size (divided across GPUs).

    Returns:
        The selected translation for each source sentence, in order.
    """
    import json
    from metricx.predict import predict

    if not translations:
        return []

    # Emit one QE example per (source, candidate) pair, in order, so the flat
    # score list below can be partitioned back per source sentence.
    # (Previously the loop variable shadowed the `source` parameter.)
    with open("input.jsonl", "w") as in_file:
        for src, tr_candidates in zip(source, translations):
            for translation in tr_candidates:
                ex_dict = {"source": src, "hypothesis": translation}
                in_file.write(json.dumps(ex_dict) + "\n")

    model_name = f"google/metricx-24-hybrid-{model_size}-v2p6"
    if fp16:
        model_name += "-bfloat16"

    print(f"Running evaluation with {model_name} reference free QE")
    predict(
        tokenizer=f"google/mt5-{model_size}",
        model_name_or_path=model_name,
        max_input_length=1536,
        batch_size=batch_size,
        input_file="input.jsonl",
        output_file="output.qe.jsonl",
        qe=True,
    )

    with open("output.qe.jsonl") as out_qe:
        scores = [json.loads(line)["prediction"] for line in out_qe]

    # Walk the flat score list with a running offset instead of assuming all
    # candidate lists share the length of translations[0] — the old fixed
    # stride silently mis-indexed when candidate counts varied.
    best = []
    offset = 0
    for candidates in translations:
        candidate_scores = scores[offset : offset + len(candidates)]
        offset += len(candidates)
        best.append(candidates[candidate_scores.index(min(candidate_scores))])
    return best
| 198 | + |
| 199 | + |
| 200 | +def _run_cmd(cmd): |
| 201 | + import subprocess |
| 202 | + |
| 203 | + try: |
| 204 | + subprocess.run(cmd, check=True, capture_output=True, shell=True) |
| 205 | + except subprocess.CalledProcessError as e: |
| 206 | + print("STDOUT:", e.stdout.decode("utf-8", errors="replace")) |
| 207 | + print("STDERR:", e.stderr.decode("utf-8", errors="replace")) |
| 208 | + raise |
0 commit comments