
Commit c1cb835

Write Tokenized Data Sizes as metadata (#2431)
Writes out the token size info alongside the tokenized data itself (request from https://discord.com/channels/1354881461060243556/1366632114316906506/1458962443542724785). This doesn't help for already-tokenized data, but it means that, going forward, reasonable stats will live alongside the data itself and can be easily accessed to compute things like epochs.
1 parent 96fae76 commit c1cb835
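
The commit message mentions using these stats to compute epochs. A minimal consumer sketch, assuming a cache written after this change (the cache path and token budget below are made-up example values; the "total_tokens" key comes from the aggregate .stats.json written in tokenize.py below):

import json

import fsspec

cache_path = "gs://my-bucket/tokenized/my-dataset"  # hypothetical location
token_budget = 1_000_000_000  # hypothetical training budget, in tokens

# Read the aggregate stats file written next to the tokenized cache.
with fsspec.open(f"{cache_path}/.stats.json") as f:
    stats = json.load(f)

epochs = token_budget / stats["total_tokens"]
print(f"{stats['total_tokens']:,} tokens in cache -> {epochs:.2f} epochs at this budget")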

2 files changed

Lines changed: 30 additions & 2 deletions


lib/marin/src/marin/processing/tokenize/tokenize.py

Lines changed: 20 additions & 0 deletions
@@ -21,12 +21,15 @@
 
 import abc
 import dataclasses
+import json
 import logging
 import os
 import re
 from collections.abc import Iterator, Sequence
 from typing import Any
 
+import fsspec
+
 import draccus
 from fray.job.context import JobContext
 import humanfriendly
@@ -367,6 +370,23 @@ def run_pipeline(paths: list[str], split_name: str) -> None:
             shard_cache_paths=shard_paths, output_path=prefix, exemplar=exemplar, context=cluster_ctx
         )
 
+        # Aggregate token counts from shard stats
+        total_tokens = 0
+        total_elements = 0
+        for shard_path in shard_paths:
+            stats_path = f"{shard_path}/.stats.json"
+            with fsspec.open(stats_path) as f:
+                stats = json.load(f)
+            total_tokens += stats.get("token_count", 0)
+            total_elements += stats.get("num_rows", 0)
+
+        stats_path = os.path.join(prefix, ".stats.json")
+        logger.info(
+            f"Writing total token count ({total_tokens:,}) and element count ({total_elements:,}) to {stats_path}"
+        )
+        with fsspec.open(stats_path, "w") as f:
+            json.dump({"total_tokens": total_tokens, "total_elements": total_elements}, f)
+
     if train_paths:
         run_pipeline(train_paths, "train")
 

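Taken together with the shard-level writer in zephyr below, a finished run leaves stats at both levels. A rough sketch of the resulting layout, assuming the shard caches sit under the output prefix (directory names and numbers are illustrative):

<prefix>/
    .stats.json          {"total_tokens": 6172839, "total_elements": 0}
    shard-00000/
        .success
        .stats.json      {"count": 4938, "token_count": 3086420}
    shard-00001/
        .success
        .stats.json      {"count": 4938, "token_count": 3086419}

One wrinkle for consumers: the aggregation above reads stats.get("num_rows", 0) for the element count, while the shard writer records it under "count", so total_elements falls back to 0 for shards written as in this commit.
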
lib/zephyr/src/zephyr/writers.py

Lines changed: 10 additions & 2 deletions
@@ -18,6 +18,7 @@
 
 from dataclasses import asdict, is_dataclass
 import itertools
+import json
 import os
 from collections.abc import Iterable
 from contextlib import contextmanager
@@ -252,21 +253,28 @@ def write_levanter_cache(records: Iterable[dict[str, Any]], output_path: str, me
     try:
         exemplar = next(iter(records))
     except StopIteration:
-        return {"path": output_path, "count": 0}
+        return {"path": output_path, "count": 0, "token_count": 0}
 
     count = 1
+    token_count = len(exemplar.get("input_ids", []))
     with atomic_rename(output_path) as tmp_path:
         with SerialCacheWriter(tmp_path, exemplar, shard_name=output_path, metadata=CacheMetadata(metadata)) as writer:
             writer.write_batch([exemplar])
             for batch in batchify(records):
                 writer.write_batch(batch)
                 count += len(batch)
+                for record in batch:
+                    token_count += len(record.get("input_ids", []))
 
     # write success sentinel
     with fsspec.open(f"{output_path}/.success", "w") as f:
         f.write("")
 
-    return {"path": output_path, "count": count}
+    # write stats for aggregation
+    with fsspec.open(f"{output_path}/.stats.json", "w") as f:
+        json.dump({"count": count, "token_count": token_count}, f)
+
+    return {"path": output_path, "count": count, "token_count": token_count}
 
 
 def write_binary_file(records: Iterable[bytes], output_path: str) -> dict:
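
For intuition, the token tally above reduces to summing the lengths of each record's "input_ids", with records that lack the key contributing zero via the .get default. A tiny self-contained example:

# Each record's token contribution is the length of its "input_ids" list.
records = [
    {"input_ids": [101, 2023, 2003, 102]},  # 4 tokens
    {"input_ids": [101, 102]},              # 2 tokens
    {"text": "not yet tokenized"},          # 0 tokens: no "input_ids" key
]
token_count = sum(len(r.get("input_ids", [])) for r in records)
assert token_count == 6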
