Commit 85ef370

Add LLM running code and docs (#1147)
1 parent efe27e0 commit 85ef370

21 files changed: +4399 −0 lines

experiments/README.md

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
# Very experimental code lives here

experiments/llmaat/README.md

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
# LLM as a teacher (llmaat)

The goal is to produce high-quality parallel translation datasets with LLMs.
This will allow fine-tuning NMT models to improve quality, and possibly replacing the teacher training stage by using the LLM-produced data directly.

This work follows the paper [Introducing the NewsPaLM MBR and QE Dataset: LLM-Generated High-Quality Parallel Data Outperforms Traditional Web-Crawled Data](https://arxiv.org/pdf/2408.06537).

It also uses the evaluation dataset and the prompt from [WMT24++: Expanding the Language Coverage of WMT24 to 55 Languages & Dialects](https://arxiv.org/html/2502.12404v1).

## Selecting a corpus

The idea is to have a diverse monolingual dataset to translate with an LLM.
It's more efficient to cluster a sample first and then assign clusters to the full corpus based on the centroids.
This part is not fully automated.

Steps:
1. Find a big monolingual corpus (100+M sentences). It can be a part of HPLT and NewsCrawl or just one side of our typical merged parallel corpus. It should be deduplicated.
2. Sample a part of it using `shuf -n 1000000`.
3. Calculate and save embeddings for the sample (see [notebooks/Select corpus.ipynb]()). We use https://huggingface.co/intfloat/multilingual-e5-small. To speed it up and utilize all GPUs on a machine, we split the sample with `split` and run [scripts/emb_corpus_ddp.py]() with `torchrun --nproc_per_node=8 emb_corpus_ddp.py`.
4. Load the embeddings, cluster them with K-Means (5000 clusters) and save the centroids to a file.
5. Go through the whole corpus and assign clusters based on the closest centroids by doing a nearest-neighbor search. Run `torchrun --nproc_per_node=8 cluster_corpus_ddp.py`; the cluster IDs are saved to a file.
6. Select 1M, 10M and 50M lines by sampling uniformly from the clusters.

The diverse samples are located here: `gs://releng-translations-dev/data/mono-llm/diverse_sample.{1,10,50}M.en.zst`

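The cluster-assignment and uniform-sampling logic (steps 4–6) can be sketched as follows. `assign_clusters` and `sample_uniformly` are hypothetical helpers for illustration, not the actual DDP scripts; the real runs shard this work across GPUs with `torchrun`.

```python
import numpy as np


def assign_clusters(embeddings: np.ndarray, centroids: np.ndarray) -> np.ndarray:
    """Assign each embedding to its nearest centroid (brute-force NN search)."""
    # Pairwise squared Euclidean distances, shape (n_sentences, n_clusters)
    dists = ((embeddings[:, None, :] - centroids[None, :, :]) ** 2).sum(-1)
    return dists.argmin(axis=1)


def sample_uniformly(cluster_ids: np.ndarray, n_total: int, rng=None) -> np.ndarray:
    """Pick roughly n_total line indices, spread evenly over the clusters."""
    rng = rng or np.random.default_rng(0)
    clusters = np.unique(cluster_ids)
    per_cluster = max(1, n_total // len(clusters))
    picked = []
    for c in clusters:
        idx = np.flatnonzero(cluster_ids == c)
        take = min(per_cluster, len(idx))
        picked.append(rng.choice(idx, size=take, replace=False))
    return np.concatenate(picked)
```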
## Evaluating LLMs

Run [flows/llm_eval_flow.py]() on Mozilla Outerbounds Metaflow:

```bash
export HUGGING_FACE_HUB_TOKEN=...
export WANDB_API_KEY=...
python llm_eval_flow.py --environment=pypi --config config ./configs/config.vllm.json run --experiment greedy --model gemma-3-27b-vllm
```

The evaluation results are available on Weights and Biases: https://wandb.ai/moz-translations/llm-evals?nw=nwuserepavlov

It's possible to add more LLMs and inference methods to [flows/llm_runner.py](). `--model gemma-3-27b-vllm` points to one of the available implementations.

The decoding config can be modified in [flows/configs/config.vllm.json]().

The prompt can be set in the config. Available prompt templates are in [flows/prompts.py]().

The flow can run evaluation for multiple language pairs in one run; just add more languages to the config. All pairs are en-xx.

The translations produced by an LLM during evaluation are uploaded to `gs://releng-translations-dev/data/llm-evals/wmt24pp/`.

We calculate COMET-22 and MetricX-24 scores. The size of the MetricX model is set in the `eval_metricx` step.

It's preferable to use vLLM as it has up to 10x higher throughput than naive inference with HF Transformers.

vLLM config:

```python
{
    "batch_size": 1024,  # Should be big enough to get the most out of the vLLM optimizations
    "langs": ["ru_RU"],  # Languages to evaluate
    "max_tok_alpha": 2.0,  # A factor to multiply the number of input tokens by to get the maximum number of output tokens. It might depend on the output language. An optimization.
    "prompt": "noomit_fewshot",  # Prompt template key
    "llm": {
        "max_model_len": 1024,  # The model context size (maximum total of input and output tokens)
        "tensor_parallel_size": 1  # The number of GPUs
    },
    "decoding": {
        "temperature": 0,  # Temperature 0 means greedy decoding; change to activate sampling
        "n": 1  # Produce only 1 candidate; increase for QE reranking
    }
}
```

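As an illustration of how `max_tok_alpha` interacts with `max_model_len`, a hypothetical helper (not the flow's actual code) could compute the per-request output budget like this:

```python
def output_token_budget(n_input_tokens: int, max_tok_alpha: float, max_model_len: int) -> int:
    """Allow up to alpha * input-length output tokens, while keeping
    input + output within the model's context window."""
    budget = int(n_input_tokens * max_tok_alpha)
    return max(1, min(budget, max_model_len - n_input_tokens))
```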
## Generating datasets

Run [flows/llm_run_flow.py]() on Mozilla Outerbounds Metaflow:

```bash
export HUGGING_FACE_HUB_TOKEN=...
python llm_run_flow.py \
  --environment=pypi --config config ./configs/config.vllm.json run --experiment finetune10M \
  --model gemma-3-27b-vllm --data_size 10 --lang ru_RU --part_size 500000 --max-workers 4
```

`--data_size 10` - use the 10M dataset to produce 10M translations

`--part_size 500000` - how many lines to process in one Metaflow task

`--max-workers 4` - run at most 4 tasks simultaneously (the current limitation on the number of GPUs)

The translations will be uploaded to `gs://releng-translations-dev/data/llm/`.

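For illustration, sharding the corpus into `--part_size` chunks (one per Metaflow task) amounts to something like the following hypothetical helper; the flow's actual sharding code may differ:

```python
def split_into_parts(n_lines: int, part_size: int) -> list[tuple[int, int]]:
    """Return (start, end) line ranges covering the corpus, one range per task."""
    return [
        (start, min(start + part_size, n_lines))
        for start in range(0, n_lines, part_size)
    ]
```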
## Quality-aware decoding (QE reranking)

Following the NewsPaLM paper, it's possible to replace regular greedy decoding with sampling multiple candidates and choosing the best one using the MetricX-24-Hybrid quality estimation model.

It requires activating the code branch with the `pick_best` Metaflow step and changing the decoding config (for vLLM, set `decoding.n` > 1, e.g. `decoding.n: 32`).
Decoding will become significantly slower as the model needs to generate N samples instead of one.

Also, the activated `pick_best` step that runs the MetricX model is currently unoptimized and quite slow.

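The core of QE reranking is simple: generate N candidates per source sentence and keep the one MetricX scores best. A minimal sketch of the selection step (MetricX-24 scores are error estimates, so lower is better):

```python
def pick_best_candidate(candidates: list[str], qe_scores: list[float]) -> str:
    """Keep the candidate with the lowest (best) MetricX-24 QE score."""
    return candidates[qe_scores.index(min(qe_scores))]
```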
## Language codes

We use the WMT24++ format of language codes, which includes a reference to a country, because some prompts require specifying it.

See all available codes in [flows/langs.py]().
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
{
  "batch_size": 8,
  "langs": ["ru"],
  "max_tok_alpha": 2.0,
  "decoding": {
    "num_beams": 5,
    "do_sample": true,
    "temperature": 0.6,
    "top_p": 0.9
  }
}
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
{
  "batch_size": 64,
  "langs": ["ru"],
  "max_tok_alpha": 2.0,
  "decoding": {
    "num_beams": 1,
    "do_sample": false,
    "temperature": 0
  }
}
Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
{
  "batch_size": 4,
  "langs": ["ru"],
  "max_tok_alpha": 2.0,
  "decoding": {
    "num_beams": 1,
    "do_sample": true,
    "temperature": 1.0,
    "top_p": 0.9,
    "num_return_sequences": 16
  }
}
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
{
  "batch_size": 64,
  "langs": ["ru"],
  "max_tok_alpha": 2.0,
  "decoding": {
    "num_beams": 1,
    "do_sample": true,
    "temperature": 0.6,
    "top_p": 0.9
  }
}
Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
{
  "batch_size": 1024,
  "langs": ["ru_RU"],
  "max_tok_alpha": 2.0,
  "prompt": "noomit_fewshot",
  "llm": {
    "max_model_len": 1024,
    "tensor_parallel_size": 1
  },
  "decoding": {
    "temperature": 0,
    "n": 1
  }
}

experiments/llmaat/flows/evals.py

Lines changed: 208 additions & 0 deletions
@@ -0,0 +1,208 @@
from typing import List

EVAL_PAIRS = (
    "en-ar_EG",
    "en-ar_SA",
    "en-bg_BG",
    "en-bn_IN",
    "en-ca_ES",
    "en-cs_CZ",
    "en-da_DK",
    "en-de_DE",
    "en-el_GR",
    "en-es_MX",
    "en-et_EE",
    "en-fa_IR",
    "en-fi_FI",
    "en-fil_PH",
    "en-fr_CA",
    "en-fr_FR",
    "en-gu_IN",
    "en-he_IL",
    "en-hi_IN",
    "en-hr_HR",
    "en-hu_HU",
    "en-id_ID",
    "en-is_IS",
    "en-it_IT",
    "en-ja_JP",
    "en-kn_IN",
    "en-ko_KR",
    "en-lt_LT",
    "en-lv_LV",
    "en-ml_IN",
    "en-mr_IN",
    "en-nl_NL",
    "en-no_NO",
    "en-pa_IN",
    "en-pl_PL",
    "en-pt_BR",
    "en-pt_PT",
    "en-ro_RO",
    "en-ru_RU",
    "en-sk_SK",
    "en-sl_SI",
    "en-sr_RS",
    "en-sv_SE",
    "en-sw_KE",
    "en-sw_TZ",
    "en-ta_IN",
    "en-te_IN",
    "en-th_TH",
    "en-tr_TR",
    "en-uk_UA",
    "en-ur_PK",
    "en-vi_VN",
    "en-zh_CN",
    "en-zh_TW",
    "en-zu_ZA",
)


lang_map = {
    pair.split("_")[0].split("-")[1]: pair
    for pair in EVAL_PAIRS
    if pair.split("_")[1] not in {"TW", "PT", "CA", "EG", "TZ"}
}


def load_data(lang):
    from datasets import load_dataset

    # if lang not in lang_map:
    #     raise ValueError(f"Language {lang} is not supported")

    # Login using e.g. `huggingface-cli login` to access this dataset
    print(f"Downloading dataset for {lang}")
    lp = f"en-{lang}"
    ds = load_dataset("google/wmt24pp", lp)
    filtered = ds.filter(lambda ex: not ex["is_bad_source"] and ex["lp"] == lp)["train"]
    return filtered["source"], filtered["target"]


def eval_comet(source_texts, target_translations, target_references):
    import comet

    comet_checkpoint = comet.download_model("Unbabel/wmt22-comet-da")
    comet_model = comet.load_from_checkpoint(comet_checkpoint)
    comet_data = []
    for source, target, target_ref in zip(source_texts, target_translations, target_references):
        comet_data.append({"src": source, "mt": target, "ref": target_ref})
    comet_results = comet_model.predict(comet_data, gpus=1)
    return round(comet_results.system_score * 100, 2)


def eval_metricx(
    source_texts,
    target_translations,
    target_references,
    model_size="xl",
    fp16=True,
    batch_size=8,
):
    """
    https://huggingface.co/google/metricx-24-hybrid-xxl-v2p6

    Available model sizes: "large" (1.2B), "xl" (3.7B), "xxl" (13B)
    """

    import json
    from statistics import mean

    from metricx.predict import predict

    with open("input.jsonl", "w") as in_file:
        for source, target, target_ref in zip(
            source_texts, target_translations, target_references
        ):
            ex_dict = {"source": source, "reference": target_ref, "hypothesis": target}
            in_file.write(json.dumps(ex_dict) + "\n")

    model_name = f"google/metricx-24-hybrid-{model_size}-v2p6"
    if fp16:
        model_name += "-bfloat16"

    # batch size is divided by the number of GPUs, set it equal or higher
    print(f"Running evaluation with {model_name} reference based")
    predict(
        tokenizer=f"google/mt5-{model_size}",
        model_name_or_path=model_name,
        max_input_length=1536,
        batch_size=batch_size,
        input_file="input.jsonl",
        output_file="output.ref.jsonl",
        qe=False,
    )

    print(f"Running evaluation with {model_name} reference free QE")
    predict(
        tokenizer=f"google/mt5-{model_size}",
        model_name_or_path=model_name,
        max_input_length=1536,
        batch_size=batch_size,
        input_file="input.jsonl",
        output_file="output.qe.jsonl",
        qe=True,
    )

    with open("output.qe.jsonl") as out_qe:
        qe_score = mean([float(json.loads(line)["prediction"]) for line in out_qe])
    with open("output.ref.jsonl") as out_ref:
        ref_score = mean([float(json.loads(line)["prediction"]) for line in out_ref])

    return {f"metricx24-{model_size}-qe": qe_score, f"metricx24-{model_size}": ref_score}


def select_best(
    source: List[str], translations: List[List[str]], model_size="xl", fp16=True, batch_size=8
) -> List[str]:
    import json

    from metricx.predict import predict

    with open("input.jsonl", "w") as in_file:
        for src, tr_candidates in zip(source, translations):
            for translation in tr_candidates:
                ex_dict = {"source": src, "hypothesis": translation}
                in_file.write(json.dumps(ex_dict) + "\n")

    model_name = f"google/metricx-24-hybrid-{model_size}-v2p6"
    if fp16:
        model_name += "-bfloat16"

    print(f"Running evaluation with {model_name} reference free QE")
    predict(
        tokenizer=f"google/mt5-{model_size}",
        model_name_or_path=model_name,
        max_input_length=1536,
        batch_size=batch_size,
        input_file="input.jsonl",
        output_file="output.qe.jsonl",
        qe=True,
    )

    with open("output.qe.jsonl") as out_qe:
        scores = [json.loads(line)["prediction"] for line in out_qe]

    num_candidates = len(translations[0])

    best = []
    for i, candidates in enumerate(translations):
        start = i * num_candidates
        candidate_scores = scores[start : start + num_candidates]
        # MetricX scores are error estimates, so lower is better
        best_idx = candidate_scores.index(min(candidate_scores))
        best.append(candidates[best_idx])
    return best


def _run_cmd(cmd):
    import subprocess

    try:
        subprocess.run(cmd, check=True, capture_output=True, shell=True)
    except subprocess.CalledProcessError as e:
        print("STDOUT:", e.stdout.decode("utf-8", errors="replace"))
        print("STDERR:", e.stderr.decode("utf-8", errors="replace"))
        raise
