Skip to content

Commit 74abb88

Browse files
authored
Target side deduplication based on bicleaner scores (#1201)
* Target side deduplication using Bicleaner scores * Add scores.zst artifact * bicleaner: create a dummy scores when no filtering * Fix linting * Yield dummy scores when scores is not provided This allows to use merge-parallel with devsets * Fix linting * Rename Bicleaner step .scores.zst to best-scores * Adapt test_merge_corpus to dedup by target Extend the cases to check that different sentence pairs with different sources but same target are deduplicated by target and the source with the best scores kept. * Move generator logic to a function
1 parent 455225a commit 74abb88

File tree

6 files changed

+222
-75
lines changed

6 files changed

+222
-75
lines changed

pipeline/bicleaner/bicleaner.sh

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ if [ "${bicleaner_threshold}" == "0" ] || [ "${bicleaner_threshold}" == "0.0" ];
3434
echo "Threshold is 0, skipping filtering"
3535
cp "${corpus_prefix}.${SRC}.zst" "${output_prefix}.${SRC}.zst"
3636
cp "${corpus_prefix}.${TRG}.zst" "${output_prefix}.${TRG}.zst"
37+
# Create a dummy best-scores.zst: if no filtering is done, every sentence gets a perfect score
38+
# this is needed for target side dedup in merge-parallel
39+
num_sents=$(zstdcat "${corpus_prefix}.${TRG}.zst" | wc -l)
40+
awk -v n=$num_sents 'BEGIN {for(i=0;i<n;i++) print "1.0";}' | zstdmt >"${output_prefix}.best-scores.zst"
3741
else
3842

3943
export scol=1
@@ -98,7 +102,8 @@ else
98102
echo "### Writing output corpus"
99103
zstdmt -dc "${output_prefix}.best.zst" |
100104
tee >(cut -f1 | zstdmt >"${output_prefix}.${SRC}.zst") |
101-
cut -f2 | zstdmt >"${output_prefix}.${TRG}.zst"
105+
tee >(cut -f2 | zstdmt >"${output_prefix}.${TRG}.zst") |
106+
cut -f3 | zstdmt >"${output_prefix}.best-scores.zst"
102107

103108
# do not delete intermediate files to inspect them and tune the threshold
104109
fi

pipeline/clean/merge-parallel.py

Lines changed: 79 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from pipeline.common.datasets import (
2323
FilteringStep,
2424
Statistics,
25-
WeakStringSet,
25+
WeakStringDict,
2626
shuffle_with_max_lines,
2727
)
2828
from pipeline.common.downloads import get_human_readable_file_size, read_lines, write_lines
@@ -58,17 +58,24 @@ def log_dataset(location: str):
5858
logger.info(f"Reading dataset {location}")
5959

6060

61+
def dummy_score_generator():
62+
for i in iter(int, 1):
63+
yield "1.0"
64+
65+
6166
class DeduplicateCorpus:
6267
def __init__(
6368
self,
6469
datasets_src: list[Path],
6570
datasets_trg: list[Path],
71+
datasets_scores: list[Path],
6672
src_outpath: Path,
6773
trg_outpath: Path,
6874
stats: FilteringStatistics,
6975
) -> None:
7076
self.datasets_src: list[Path] = datasets_src
7177
self.datasets_trg: list[Path] = datasets_trg
78+
self.datasets_scores: list[Path] = datasets_scores
7279
self.src_outpath: Path = src_outpath
7380
self.trg_outpath: Path = trg_outpath
7481
self.stats: FilteringStatistics = stats
@@ -105,30 +112,63 @@ def run(
105112
stats.final_truncated.kept = stats.parallel_corpus.kept
106113
stats.final_truncated.visited = stats.parallel_corpus.kept
107114

108-
def yield_lines_tuple(self, stack: ExitStack) -> Generator[tuple[str, str], None, None]:
109-
strings_seen = WeakStringSet()
110-
stats = self.stats
115+
def on_enter_location(self, location):
116+
log_dataset(location)
117+
self.dataset_stats = self.stats.add_parallel_dataset(location)
118+
119+
def _yield_lines(self, stack: ExitStack, add_stats: bool = False):
120+
if add_stats:
121+
enter_location_func = self.on_enter_location
122+
else:
123+
enter_location_func = log_dataset
124+
111125
src_lines: Generator[str, None, None] = stack.enter_context(
112-
read_lines(self.datasets_src, on_enter_location=self.on_enter_location)
126+
read_lines(self.datasets_src, on_enter_location=enter_location_func)
113127
)
114128
trg_lines: Generator[str, None, None] = stack.enter_context(
115129
read_lines(self.datasets_trg, on_enter_location=log_dataset)
116130
)
131+
if self.datasets_scores == []:
132+
logger.info("No scores found, deduping without score")
133+
scores_lines = dummy_score_generator()
134+
else:
135+
scores_lines: Generator[str, None, None] = stack.enter_context(
136+
read_lines(self.datasets_scores, on_enter_location=log_dataset)
137+
)
117138

118-
for src_line, trg_line in zip(src_lines, trg_lines):
119-
# No separator is needed as the newline is included.
120-
line = src_line + trg_line
139+
for i, (src_line, trg_line, score_line) in enumerate(
140+
zip(src_lines, trg_lines, scores_lines)
141+
):
142+
try:
143+
score = float(score_line)
144+
except ValueError as e:
145+
raise ValueError(f"Could not parse score in line {i}") from e
121146

122-
if line in strings_seen:
123-
stats.parallel_corpus.filtered += 1
124-
self.dataset_stats.filtered += 1
125-
else:
147+
yield src_line, trg_line, score
148+
149+
def yield_lines_tuple(self, stack: ExitStack) -> Generator[tuple[str, str], None, None]:
150+
strings_seen = WeakStringDict()
151+
stats = self.stats
152+
for src_line, trg_line, score in self._yield_lines(stack):
153+
# store all possible targets
154+
# for all the sentence pairs that have the same target, keep the best score
155+
if trg_line not in strings_seen or strings_seen[trg_line] < score:
156+
strings_seen[trg_line] = score
157+
158+
for src_line, trg_line, score in self._yield_lines(stack, add_stats=True):
159+
# When a target's score equals the stored score (i.e. the best score),
160+
# we keep the sentence pair
161+
if trg_line in strings_seen and strings_seen[trg_line] == score:
126162
stats.parallel_corpus.kept += 1
127163
self.dataset_stats.kept += 1
128-
129-
strings_seen.add(line)
164+
# the item is removed from the dict to avoid keeping two sentence pairs
165+
# that have the same target AND the same score
166+
del strings_seen[trg_line]
130167

131168
yield src_line, trg_line
169+
else:
170+
stats.parallel_corpus.filtered += 1
171+
self.dataset_stats.filtered += 1
132172

133173
def yield_lines_string(self, stack: ExitStack) -> Generator[str, None, None]:
134174
for src_line, trg_line in self.yield_lines_tuple(stack):
@@ -139,10 +179,6 @@ def yield_lines_string(self, stack: ExitStack) -> Generator[str, None, None]:
139179
else:
140180
yield f"{src_line}\t{trg_line}"
141181

142-
def on_enter_location(self, location):
143-
log_dataset(location)
144-
self.dataset_stats = self.stats.add_parallel_dataset(location)
145-
146182

147183
def sample_corpus(
148184
artifacts: Path, name: str, sample_size: int, src_outpath: Path, trg_outpath: Path
@@ -204,24 +240,43 @@ def get_datasets(src: str, trg: str, datasets_glob: str):
204240
dataset_paths: list[str] = glob(datasets_glob)
205241
datasets_src: list[Path] = []
206242
datasets_trg: list[Path] = []
243+
datasets_scores: list[Path] = []
207244
dataset_paths.sort()
208245

209246
total_corpus_bytes = 0
210247

211248
for dataset in dataset_paths:
212249
path = Path(dataset)
250+
countbytes = True
213251
if dataset.endswith(f"{src}.zst"):
214252
datasets_src.append(path)
215253
elif dataset.endswith(f"{trg}.zst"):
216254
datasets_trg.append(path)
255+
elif dataset.endswith(".best-scores.zst"):
256+
datasets_scores.append(path)
257+
countbytes = False
217258
else:
218259
raise Exception(f"Dataset does not match naming scheme: {dataset}")
219260

220-
formatted_size, bytes = get_human_readable_file_size(path)
221-
logger.info(f" - {path} ({formatted_size})")
222-
total_corpus_bytes += bytes
261+
# Do not count bytes of the scores
262+
if countbytes:
263+
formatted_size, bytes = get_human_readable_file_size(path)
264+
logger.info(f" - {path} ({formatted_size})")
265+
total_corpus_bytes += bytes
266+
267+
# Fail if a different number of files per dataset is found,
268+
# but do not fail if no .scores are provided (when running for devsets)
269+
if (
270+
len(datasets_src) != len(datasets_trg) or len(datasets_src) != len(datasets_scores)
271+
) and datasets_scores != []:
272+
logger.info(datasets_src)
273+
logger.info(datasets_trg)
274+
logger.info(datasets_scores)
275+
raise Exception(
276+
f"Number of files per dataset is different src: {len(datasets_src)} trg: {len(datasets_trg)} scores: {len(datasets_scores)}"
277+
)
223278

224-
return datasets_src, datasets_trg, total_corpus_bytes
279+
return datasets_src, datasets_trg, datasets_scores, total_corpus_bytes
225280

226281

227282
def main() -> None:
@@ -273,7 +328,7 @@ def main() -> None:
273328

274329
args = parser.parse_args()
275330

276-
datasets_src, datasets_trg, total_corpus_bytes = get_datasets(
331+
datasets_src, datasets_trg, datasets_scores, total_corpus_bytes = get_datasets(
277332
args.src, args.trg, args.datasets_glob
278333
)
279334

@@ -291,6 +346,7 @@ def main() -> None:
291346
deduplicate_corpus = DeduplicateCorpus(
292347
datasets_src,
293348
datasets_trg,
349+
datasets_scores,
294350
src_outpath,
295351
trg_outpath,
296352
stats,

pipeline/common/datasets.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from io import TextIOWrapper
1010
from pathlib import Path
1111
from random import Random
12-
from typing import Callable, Iterator, Literal, Optional, Set, Union
12+
from typing import Callable, Iterator, Literal, Optional, Set, Union, Dict
1313
from urllib.parse import urlparse
1414
import unicodedata
1515

@@ -448,6 +448,57 @@ def _hash_string(string: str) -> int:
448448
return hash(cleaned_line)
449449

450450

451+
class WeakStringDict(Dict):
452+
"""
453+
A Dict that weakly holds on to key strings by storing a hashed `int`. Using this class
454+
makes it easy to see if a string is duplicated across large datasets without holding
455+
the entire set of strings in memory.
456+
457+
This is an alternate version of WeakStringSet that also stores a float value (score)
458+
associated to the string.
459+
460+
Usage:
461+
unique_strings = WeakStringDict()
462+
unique_strings["string a"] = 0.78
463+
unique_strings["string b"] = 0.911
464+
465+
assert "string a" in unique_strings
466+
assert "string b" in unique_strings
467+
assert "string c" not in unique_strings
468+
"""
469+
470+
def __init__(self, iter: Optional[Iterable[str]] = None) -> None:
471+
if iter:
472+
super().__init__((WeakStringDict._hash_string(string) for string in iter))
473+
else:
474+
super().__init__()
475+
476+
def __contains__(self, string: str) -> bool:
477+
return super().__contains__(WeakStringDict._hash_string(string))
478+
479+
def __setitem__(self, string: str, val: float) -> None:
480+
"""
481+
Add/set a string in the weak dict as key, together with its associated value.
482+
The strings are stored uniquely based on their
483+
contents with the whitespace surrounding them stripped.
484+
"""
485+
super().__setitem__(WeakStringDict._hash_string(string), val)
486+
487+
def __delitem__(self, string: str):
488+
super().__delitem__(WeakStringDict._hash_string(string))
489+
490+
def __getitem__(self, string: str) -> float:
491+
return super().__getitem__(WeakStringDict._hash_string(string))
492+
493+
def _hash_string(string: str) -> int:
494+
"""
495+
Return a hash of a line. The line has its whitespace stripped and text representation
496+
normalized to ensure a consistent representation.
497+
"""
498+
cleaned_line = unicodedata.normalize("NFC", string.strip())
499+
return hash(cleaned_line)
500+
501+
451502
def decompress(
452503
source: Union[str, Path],
453504
destination: Optional[Union[Path, str]] = None,

taskcluster/kinds/corpus-merge-parallel/kind.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ tasks:
7878
upstream-artifacts:
7979
- "{dataset_sanitized}.{src_locale}.zst"
8080
- "{dataset_sanitized}.{trg_locale}.zst"
81+
- "{dataset_sanitized}.best-scores.zst"
8182
upstream-task-attributes:
8283
cleaning-type:
8384
by-cleaning-type:

tests/test_common_datasets.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pipeline.common.logging import get_logger
1010
from pipeline.common.datasets import (
1111
WeakStringSet,
12+
WeakStringDict,
1213
compress,
1314
decompress,
1415
shuffle_in_temp_files,
@@ -218,6 +219,31 @@ def test_weak_string_set():
218219
assert len(unique_strings2) == 2
219220

220221

222+
def test_weak_string_dict():
223+
unique_strings_scores = WeakStringDict()
224+
unique_strings_scores["aa"] = 0.87
225+
unique_strings_scores["aa"] = 0.92
226+
unique_strings_scores["ab"] = 4.1
227+
228+
assert "aa" in unique_strings_scores
229+
assert "ab" in unique_strings_scores
230+
assert unique_strings_scores["aa"] == 0.92
231+
assert unique_strings_scores["aa"] != 0.87
232+
233+
del unique_strings_scores["aa"]
234+
assert "aa" not in unique_strings_scores
235+
236+
assert len(unique_strings_scores) == 1
237+
238+
unique_strings_scores["cdf"] = 33.2
239+
assert "cdf" in unique_strings_scores
240+
assert unique_strings_scores["cdf"] == 33.2
241+
unique_strings_scores["aa"] = 0.33
242+
unique_strings_scores["ab"] = 0.34
243+
unique_strings_scores["aa"] = 0.01
244+
assert unique_strings_scores["aa"] == 0.01
245+
246+
221247
@pytest.mark.parametrize("suffix", ["zst", "gz"])
222248
@pytest.mark.parametrize("remove_or_keep", ["remove", "keep"])
223249
def test_compress(suffix: str, remove_or_keep: str):

0 commit comments

Comments
 (0)