Skip to content

Commit 28b1984

Browse files
authored
Improve logging in tokenize (#2908)
* Entire-Checkpoint: 0117fe3150b1 * re #2829 Changes: * improve stats logging in tokenize, including things like logging tokens/s CC: @rjpower
1 parent 6118297 commit 28b1984

File tree

1 file changed

+45
-7
lines changed

1 file changed

+45
-7
lines changed

lib/marin/src/marin/processing/tokenize/tokenize.py

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import logging
1515
import os
1616
import re
17+
import time
1718
from collections.abc import Iterator, Sequence
1819

1920
import draccus
@@ -252,14 +253,33 @@ def _tokenize_batches(*, config: TokenizeConfig | HfTokenizeConfig, batches: Ite
252253

253254
batch_count = 0
254255
record_count = 0
256+
token_count = 0
257+
start_time = time.monotonic()
258+
255259
for batch in batches:
256260
batch_count += 1
257261
for record in batch_processor(batch):
258262
record_count += 1
263+
token_count += len(record.get("input_ids", []))
259264
yield record
260265
if batch_count % 100 == 0:
261-
logger.info("Tokenized %d batches, %d records so far", batch_count, record_count)
262-
logger.info("Tokenization done: %d batches, %d records total", batch_count, record_count)
266+
elapsed = time.monotonic() - start_time
267+
tok_per_sec = token_count / elapsed if elapsed > 0 else 0
268+
doc_per_sec = record_count / elapsed if elapsed > 0 else 0
269+
avg_tok_per_doc = token_count / record_count if record_count > 0 else 0
270+
logger.info(
271+
f"Tokenized {batch_count:,} batches, {record_count:,} docs, {token_count:,} tokens in {elapsed:.1f}s "
272+
f"({tok_per_sec:,.0f} tokens/s, {doc_per_sec:,.1f} docs/s, {avg_tok_per_doc:,.0f} avg tokens/doc)"
273+
)
274+
275+
elapsed = time.monotonic() - start_time
276+
tok_per_sec = token_count / elapsed if elapsed > 0 else 0
277+
doc_per_sec = record_count / elapsed if elapsed > 0 else 0
278+
avg_tok_per_doc = token_count / record_count if record_count > 0 else 0
279+
logger.info(
280+
f"Tokenization done: {batch_count:,} batches, {record_count:,} docs, {token_count:,} tokens in {elapsed:.1f}s "
281+
f"({tok_per_sec:,.0f} tokens/s, {doc_per_sec:,.1f} docs/s, {avg_tok_per_doc:,.0f} avg tokens/doc)"
282+
)
263283

264284

265285
def tokenize(config: TokenizeConfigBase):
@@ -311,16 +331,23 @@ def run_pipeline(ctx: ZephyrContext, paths: list[str], split_name: str) -> None:
311331
)
312332
return
313333

334+
pipeline_start = time.monotonic()
335+
314336
# Use local backend for lightweight file stats - no remote workers needed
337+
filescan_start = time.monotonic()
315338
with ZephyrContext(client=LocalClient(), max_workers=8, name="tokenize-filescan") as local_ctx:
316339
file_stats = list(
317340
local_ctx.execute(
318341
Dataset.from_list(paths).map(lambda path: {"filename": path, "size": fsspec_size(path)}),
319342
verbose=False,
320343
)
321344
)
345+
total_input_bytes = sum(f["size"] for f in file_stats)
322346
file_groups = list(_bundle_files_by_size(file_stats, config.window_size_bytes))
323-
logger.info(f"Grouped {len(paths)} files into {len(file_groups)} groups by size.")
347+
logger.info(
348+
f"Grouped {len(paths):,} files ({total_input_bytes / 1e9:.2f} GB) into {len(file_groups):,} groups "
349+
f"in {time.monotonic() - filescan_start:.1f}s."
350+
)
324351

325352
ds = Dataset.from_list(file_groups).flat_map(lambda file_list: file_list).flat_map(load_file)
326353

@@ -337,7 +364,9 @@ def run_pipeline(ctx: ZephyrContext, paths: list[str], split_name: str) -> None:
337364
# Broadcast the tokenizer to all workers via ZephyrContext
338365
ctx.put("tokenizer", transformers.AutoTokenizer.from_pretrained(config.tokenizer))
339366

367+
tokenize_start = time.monotonic()
340368
shard_paths = ctx.execute(temp_shards)
369+
tokenize_elapsed = time.monotonic() - tokenize_start
341370

342371
logger.info("Computing exemplar for cache consolidation")
343372
exemplar = ctx.execute(
@@ -348,8 +377,10 @@ def run_pipeline(ctx: ZephyrContext, paths: list[str], split_name: str) -> None:
348377
verbose=False,
349378
)[0]
350379

351-
logger.info(f"Tokenization complete, consolidating {len(shard_paths)} shards into {prefix}")
380+
consolidate_start = time.monotonic()
381+
logger.info(f"Consolidating {len(shard_paths)} shards into {prefix}")
352382
consolidate_shard_caches(shard_cache_paths=shard_paths, output_path=prefix, exemplar=exemplar)
383+
consolidate_elapsed = time.monotonic() - consolidate_start
353384

354385
# Aggregate token counts from shard stats
355386
total_tokens = 0
@@ -362,12 +393,19 @@ def run_pipeline(ctx: ZephyrContext, paths: list[str], split_name: str) -> None:
362393
total_elements += stats.get("num_rows", 0)
363394

364395
stats_path = os.path.join(prefix, ".stats.json")
365-
logger.info(
366-
f"Writing total token count ({total_tokens:,}) and element count ({total_elements:,}) to {stats_path}"
367-
)
368396
with fsspec.open(stats_path, "w") as f:
369397
json.dump({"total_tokens": total_tokens, "total_elements": total_elements}, f)
370398

399+
pipeline_elapsed = time.monotonic() - pipeline_start
400+
overall_tok_per_sec = total_tokens / tokenize_elapsed if tokenize_elapsed > 0 else 0
401+
overall_doc_per_sec = total_elements / tokenize_elapsed if tokenize_elapsed > 0 else 0
402+
logger.info(
403+
f"{split_name} pipeline complete: {total_elements:,} docs, {total_tokens:,} tokens "
404+
f"in {pipeline_elapsed:.1f}s (tokenize: {tokenize_elapsed:.1f}s at {overall_tok_per_sec:,.0f} tokens/s "
405+
f"{overall_doc_per_sec:,.1f} docs/s, consolidate: {consolidate_elapsed:.1f}s). "
406+
f"Wrote stats to {stats_path}"
407+
)
408+
371409
with ZephyrContext(
372410
resources=ResourceConfig(ram="16g", disk="16g"),
373411
max_workers=min(128, len(train_paths) + len(validation_paths)),

0 commit comments

Comments (0)