2222from pipeline .common .datasets import (
2323 FilteringStep ,
2424 Statistics ,
25- WeakStringSet ,
25+ WeakStringDict ,
2626 shuffle_with_max_lines ,
2727)
2828from pipeline .common .downloads import get_human_readable_file_size , read_lines , write_lines
@@ -58,17 +58,24 @@ def log_dataset(location: str):
5858 logger .info (f"Reading dataset { location } " )
5959
6060
def dummy_score_generator():
    """Yield the placeholder score string "1.0" forever.

    Used when no score files are provided, so the deduplication pass can
    still zip source/target lines against a score stream of matching
    length. Yields strings (not floats) to mirror lines read from a real
    scores file; the caller parses them with float().
    """
    # An explicit infinite loop is clearer than the original
    # `for i in iter(int, 1)` trick (int() returns 0, never the sentinel 1).
    while True:
        yield "1.0"
64+
65+
6166class DeduplicateCorpus :
    def __init__(
        self,
        datasets_src: list[Path],
        datasets_trg: list[Path],
        datasets_scores: list[Path],
        src_outpath: Path,
        trg_outpath: Path,
        stats: FilteringStatistics,
    ) -> None:
        """Store the dataset paths, output paths, and statistics collector.

        datasets_src / datasets_trg / datasets_scores are parallel lists of
        files; datasets_scores may be empty, in which case deduplication
        runs without sentence-pair scores.
        """
        # Input corpora: source side, target side, and optional per-pair scores.
        self.datasets_src: list[Path] = datasets_src
        self.datasets_trg: list[Path] = datasets_trg
        self.datasets_scores: list[Path] = datasets_scores
        # Where the deduplicated corpus halves are written.
        self.src_outpath: Path = src_outpath
        self.trg_outpath: Path = trg_outpath
        # Aggregated filtering statistics, updated while deduplicating.
        self.stats: FilteringStatistics = stats
@@ -105,30 +112,63 @@ def run(
105112 stats .final_truncated .kept = stats .parallel_corpus .kept
106113 stats .final_truncated .visited = stats .parallel_corpus .kept
107114
108- def yield_lines_tuple (self , stack : ExitStack ) -> Generator [tuple [str , str ], None , None ]:
109- strings_seen = WeakStringSet ()
110- stats = self .stats
115+ def on_enter_location (self , location ):
116+ log_dataset (location )
117+ self .dataset_stats = self .stats .add_parallel_dataset (location )
118+
119+ def _yield_lines (self , stack : ExitStack , add_stats : bool = False ):
120+ if add_stats :
121+ enter_location_func = self .on_enter_location
122+ else :
123+ enter_location_func = log_dataset
124+
111125 src_lines : Generator [str , None , None ] = stack .enter_context (
112- read_lines (self .datasets_src , on_enter_location = self . on_enter_location )
126+ read_lines (self .datasets_src , on_enter_location = enter_location_func )
113127 )
114128 trg_lines : Generator [str , None , None ] = stack .enter_context (
115129 read_lines (self .datasets_trg , on_enter_location = log_dataset )
116130 )
131+ if self .datasets_scores == []:
132+ logger .info ("No scores found, deduping without score" )
133+ scores_lines = dummy_score_generator ()
134+ else :
135+ scores_lines : Generator [str , None , None ] = stack .enter_context (
136+ read_lines (self .datasets_scores , on_enter_location = log_dataset )
137+ )
117138
118- for src_line , trg_line in zip (src_lines , trg_lines ):
119- # No separator is needed as the newline is included.
120- line = src_line + trg_line
139+ for i , (src_line , trg_line , score_line ) in enumerate (
140+ zip (src_lines , trg_lines , scores_lines )
141+ ):
142+ try :
143+ score = float (score_line )
144+ except ValueError as e :
145+ raise ValueError (f"Could not parse score in line { i } " ) from e
121146
122- if line in strings_seen :
123- stats .parallel_corpus .filtered += 1
124- self .dataset_stats .filtered += 1
125- else :
147+ yield src_line , trg_line , score
148+
    def yield_lines_tuple(self, stack: ExitStack) -> Generator[tuple[str, str], None, None]:
        """Yield deduplicated (src_line, trg_line) pairs, deduping on the target side.

        Two passes over the data:
        1. Record, for every distinct target line, the best (highest) score seen.
        2. Re-read the data and keep only the first pair whose target carries
           that best score; every other pair with the same target is filtered.
        Only the second pass updates the filtering statistics.
        """
        strings_seen = WeakStringDict()
        stats = self.stats
        # Pass 1: store all possible targets.
        # For all the sentence pairs that have the same target, keep the best score.
        for src_line, trg_line, score in self._yield_lines(stack):
            if trg_line not in strings_seen or strings_seen[trg_line] < score:
                strings_seen[trg_line] = score

        # Pass 2 (with stats): keep a pair only when its target still maps to
        # the best score recorded above.
        for src_line, trg_line, score in self._yield_lines(stack, add_stats=True):
            # When a target has the same score as stored, therefore the best score,
            # we keep it.
            if trg_line in strings_seen and strings_seen[trg_line] == score:
                stats.parallel_corpus.kept += 1
                self.dataset_stats.kept += 1
                # The item is removed from the dict to avoid keeping two sentence
                # pairs that have the same target AND the same score.
                del strings_seen[trg_line]

                yield src_line, trg_line
            else:
                stats.parallel_corpus.filtered += 1
                self.dataset_stats.filtered += 1
132172
133173 def yield_lines_string (self , stack : ExitStack ) -> Generator [str , None , None ]:
134174 for src_line , trg_line in self .yield_lines_tuple (stack ):
@@ -139,10 +179,6 @@ def yield_lines_string(self, stack: ExitStack) -> Generator[str, None, None]:
139179 else :
140180 yield f"{ src_line } \t { trg_line } "
141181
142- def on_enter_location (self , location ):
143- log_dataset (location )
144- self .dataset_stats = self .stats .add_parallel_dataset (location )
145-
146182
147183def sample_corpus (
148184 artifacts : Path , name : str , sample_size : int , src_outpath : Path , trg_outpath : Path
@@ -204,24 +240,43 @@ def get_datasets(src: str, trg: str, datasets_glob: str):
204240 dataset_paths : list [str ] = glob (datasets_glob )
205241 datasets_src : list [Path ] = []
206242 datasets_trg : list [Path ] = []
243+ datasets_scores : list [Path ] = []
207244 dataset_paths .sort ()
208245
209246 total_corpus_bytes = 0
210247
211248 for dataset in dataset_paths :
212249 path = Path (dataset )
250+ countbytes = True
213251 if dataset .endswith (f"{ src } .zst" ):
214252 datasets_src .append (path )
215253 elif dataset .endswith (f"{ trg } .zst" ):
216254 datasets_trg .append (path )
255+ elif dataset .endswith (".best-scores.zst" ):
256+ datasets_scores .append (path )
257+ countbytes = False
217258 else :
218259 raise Exception (f"Dataset does not match naming scheme: { dataset } " )
219260
220- formatted_size , bytes = get_human_readable_file_size (path )
221- logger .info (f" - { path } ({ formatted_size } )" )
222- total_corpus_bytes += bytes
261+ # Do not count bytes of the scores
262+ if countbytes :
263+ formatted_size , bytes = get_human_readable_file_size (path )
264+ logger .info (f" - { path } ({ formatted_size } )" )
265+ total_corpus_bytes += bytes
266+
267+ # Fail if different amount of files per dataset
268+ # but do not file if no .scores are provided (when running for devsets)
269+ if (
270+ len (datasets_src ) != len (datasets_trg ) or len (datasets_src ) != len (datasets_scores )
271+ ) and datasets_scores != []:
272+ logger .info (datasets_src )
273+ logger .info (datasets_trg )
274+ logger .info (datasets_scores )
275+ raise Exception (
276+ f"Number of files per dataset is different src: { len (datasets_src )} trg: { len (datasets_trg )} scores: { len (datasets_scores )} "
277+ )
223278
224- return datasets_src , datasets_trg , total_corpus_bytes
279+ return datasets_src , datasets_trg , datasets_scores , total_corpus_bytes
225280
226281
227282def main () -> None :
@@ -273,7 +328,7 @@ def main() -> None:
273328
274329 args = parser .parse_args ()
275330
276- datasets_src , datasets_trg , total_corpus_bytes = get_datasets (
331+ datasets_src , datasets_trg , datasets_scores , total_corpus_bytes = get_datasets (
277332 args .src , args .trg , args .datasets_glob
278333 )
279334
@@ -291,6 +346,7 @@ def main() -> None:
291346 deduplicate_corpus = DeduplicateCorpus (
292347 datasets_src ,
293348 datasets_trg ,
349+ datasets_scores ,
294350 src_outpath ,
295351 trg_outpath ,
296352 stats ,
0 commit comments