import logging
import os
import re
+import time
from collections.abc import Iterator, Sequence

import draccus
@@ -252,14 +253,33 @@ def _tokenize_batches(*, config: TokenizeConfig | HfTokenizeConfig, batches: Ite

    batch_count = 0
    record_count = 0
+    token_count = 0
+    start_time = time.monotonic()
+
    for batch in batches:
        batch_count += 1
        for record in batch_processor(batch):
            record_count += 1
+            token_count += len(record.get("input_ids", []))
            yield record
        if batch_count % 100 == 0:
-            logger.info("Tokenized %d batches, %d records so far", batch_count, record_count)
-    logger.info("Tokenization done: %d batches, %d records total", batch_count, record_count)
+            elapsed = time.monotonic() - start_time
+            tok_per_sec = token_count / elapsed if elapsed > 0 else 0
+            doc_per_sec = record_count / elapsed if elapsed > 0 else 0
+            avg_tok_per_doc = token_count / record_count if record_count > 0 else 0
+            logger.info(
+                f"Tokenized {batch_count:,} batches, {record_count:,} docs, {token_count:,} tokens in {elapsed:.1f}s "
+                f"({tok_per_sec:,.0f} tokens/s, {doc_per_sec:,.1f} docs/s, {avg_tok_per_doc:,.0f} avg tokens/doc)"
+            )
+
+    elapsed = time.monotonic() - start_time
+    tok_per_sec = token_count / elapsed if elapsed > 0 else 0
+    doc_per_sec = record_count / elapsed if elapsed > 0 else 0
+    avg_tok_per_doc = token_count / record_count if record_count > 0 else 0
+    logger.info(
+        f"Tokenization done: {batch_count:,} batches, {record_count:,} docs, {token_count:,} tokens in {elapsed:.1f}s "
+        f"({tok_per_sec:,.0f} tokens/s, {doc_per_sec:,.1f} docs/s, {avg_tok_per_doc:,.0f} avg tokens/doc)"
+    )


def tokenize(config: TokenizeConfigBase):
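The periodic logging added above boils down to a simple pattern: count tokens as records are yielded and divide by elapsed wall-clock time from time.monotonic(). A minimal self-contained sketch of that pattern, where the stub batch processor and batch sizes are invented for illustration and are not part of this module:

import logging
import time

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("throughput-demo")


def fake_batch_processor(batch):
    # Stand-in for the real tokenizer: pretend every doc becomes 5 tokens.
    return [{"input_ids": [0] * 5} for _ in batch]


def tokenize_with_throughput(batches):
    batch_count = record_count = token_count = 0
    start_time = time.monotonic()
    for batch in batches:
        batch_count += 1
        for record in fake_batch_processor(batch):
            record_count += 1
            token_count += len(record.get("input_ids", []))
            yield record
        if batch_count % 100 == 0:
            elapsed = time.monotonic() - start_time
            logger.info(
                "%d batches, %d docs, %d tokens, %.0f tokens/s",
                batch_count,
                record_count,
                token_count,
                token_count / elapsed if elapsed > 0 else 0,
            )


# Drive the generator: 200 batches of 8 "documents" each.
for _ in tokenize_with_throughput([["doc"] * 8 for _ in range(200)]):
    pass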
@@ -311,16 +331,23 @@ def run_pipeline(ctx: ZephyrContext, paths: list[str], split_name: str) -> None:
            )
            return

+        pipeline_start = time.monotonic()
+
        # Use local backend for lightweight file stats - no remote workers needed
+        filescan_start = time.monotonic()
        with ZephyrContext(client=LocalClient(), max_workers=8, name="tokenize-filescan") as local_ctx:
            file_stats = list(
                local_ctx.execute(
                    Dataset.from_list(paths).map(lambda path: {"filename": path, "size": fsspec_size(path)}),
                    verbose=False,
                )
            )
+        total_input_bytes = sum(f["size"] for f in file_stats)
        file_groups = list(_bundle_files_by_size(file_stats, config.window_size_bytes))
-        logger.info(f"Grouped {len(paths)} files into {len(file_groups)} groups by size.")
+        logger.info(
+            f"Grouped {len(paths):,} files ({total_input_bytes / 1e9:.2f} GB) into {len(file_groups):,} groups "
+            f"in {time.monotonic() - filescan_start:.1f}s."
+        )

        ds = Dataset.from_list(file_groups).flat_map(lambda file_list: file_list).flat_map(load_file)
@@ -337,7 +364,9 @@ def run_pipeline(ctx: ZephyrContext, paths: list[str], split_name: str) -> None:
        # Broadcast the tokenizer to all workers via ZephyrContext
        ctx.put("tokenizer", transformers.AutoTokenizer.from_pretrained(config.tokenizer))

+        tokenize_start = time.monotonic()
        shard_paths = ctx.execute(temp_shards)
+        tokenize_elapsed = time.monotonic() - tokenize_start

        logger.info("Computing exemplar for cache consolidation")
        exemplar = ctx.execute(
@@ -348,8 +377,10 @@ def run_pipeline(ctx: ZephyrContext, paths: list[str], split_name: str) -> None:
            verbose=False,
        )[0]

-        logger.info(f"Tokenization complete, consolidating {len(shard_paths)} shards into {prefix}")
+        consolidate_start = time.monotonic()
+        logger.info(f"Consolidating {len(shard_paths)} shards into {prefix}")
        consolidate_shard_caches(shard_cache_paths=shard_paths, output_path=prefix, exemplar=exemplar)
+        consolidate_elapsed = time.monotonic() - consolidate_start

        # Aggregate token counts from shard stats
        total_tokens = 0
@@ -362,12 +393,19 @@ def run_pipeline(ctx: ZephyrContext, paths: list[str], split_name: str) -> None:
            total_elements += stats.get("num_rows", 0)

        stats_path = os.path.join(prefix, ".stats.json")
-        logger.info(
-            f"Writing total token count ({total_tokens:,}) and element count ({total_elements:,}) to {stats_path}"
-        )
        with fsspec.open(stats_path, "w") as f:
            json.dump({"total_tokens": total_tokens, "total_elements": total_elements}, f)

+        pipeline_elapsed = time.monotonic() - pipeline_start
+        overall_tok_per_sec = total_tokens / tokenize_elapsed if tokenize_elapsed > 0 else 0
+        overall_doc_per_sec = total_elements / tokenize_elapsed if tokenize_elapsed > 0 else 0
+        logger.info(
+            f"{split_name} pipeline complete: {total_elements:,} docs, {total_tokens:,} tokens "
+            f"in {pipeline_elapsed:.1f}s (tokenize: {tokenize_elapsed:.1f}s at {overall_tok_per_sec:,.0f} tokens/s "
+            f"{overall_doc_per_sec:,.1f} docs/s, consolidate: {consolidate_elapsed:.1f}s). "
+            f"Wrote stats to {stats_path}"
+        )
+
    with ZephyrContext(
        resources=ResourceConfig(ram="16g", disk="16g"),
        max_workers=min(128, len(train_paths) + len(validation_paths)),
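The .stats.json written above can be read back with the same fsspec/json pair, for example to sanity-check a finished split. The bucket path below is a placeholder, not a path from this repository:

import json

import fsspec

# Placeholder path; in the pipeline this is <prefix>/.stats.json for a given split.
stats_path = "gs://my-bucket/tokenized/train/.stats.json"

with fsspec.open(stats_path, "r") as f:
    stats = json.load(f)

print(f"{stats['total_elements']:,} docs, {stats['total_tokens']:,} tokens")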