Allow mixing for pretokenized data. #230

Open · wants to merge 7 commits into main
Changes from all commits
46 changes: 39 additions & 7 deletions open_lm/datapreprocess/ray/tokenize_shuffle.py
@@ -90,6 +90,7 @@ class RawFileType(enum.Enum):
ZSTD_JSONL_COMPRESSED = 2
GZIP_JSONL_COMPRESSED = 3
TAR = 4
TAR_PRETOK = 5
UNKNOWN = -1


@@ -118,22 +119,38 @@ def tar_reader(fh: BinaryIO, content_key: str):
"""
content_key: where in the tarfile to find the text/tokens. Options:
"txt" - read text file as string
"json" - read the whole json file (e.g. an already-tokenized list)
"json:key" - read json[key] as string
"json.gz" - same as "json", but gzipped
"json.gz:key" - same as "json:key", but gzipped
"npy" - read numpy array as tokens
"""
# TODO(gsmyrnis) - I think some of the modes (namely npy) are not clean on whether they are still useful - consider
# removing them in the future.
content_ext = content_key.split(":")[0]
buffer = io.BytesIO(fh.read())
with tarfile.open(fileobj=buffer, mode="r") as tar:
samples = []
for member in tar.getmembers():
if member.isfile() and member.name.endswith(f".{content_ext}"):
with tar.extractfile(member) as fileobj:
if fileobj: # Ensure fileobj is not None
if content_ext == "txt":
content = fileobj.read().decode("utf-8")
elif content_ext == "json":
json_dict, json_key = json.load(fileobj), content_key.split(":")[1]
content = json_dict[json_key]
json_data = json.load(fileobj)
if isinstance(json_data, dict):
json_key = content_key.split(":")[1]
content = json_data[json_key]
else:
content = json_data
elif content_ext == "json.gz":
with gzip.open(fileobj, "rb") as fileobj_unzip:
json_data = json.load(fileobj_unzip)
if isinstance(json_data, dict):
json_key = content_key.split(":")[1]
content = json_data[json_key]
else:
content = json_data
elif content_ext == "npy":
token_array = np.load(io.BytesIO(fileobj.read()), allow_pickle=True)
content = token_array.reshape(-1).tolist()
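As a reading aid, here is a small self-contained sketch of the content_key convention implemented above (illustrative only, not part of this diff; the helper name is made up and the "txt" and "npy" branches are omitted): the part before ":" selects tar members by extension, and the optional part after it selects a field inside the decoded JSON.

```python
# Illustrative sketch of the content_key convention used by tar_reader above.
# Not part of this PR; the helper name is hypothetical.
import gzip
import io
import json
import tarfile


def iter_tar_contents(tar_bytes: bytes, content_key: str = "json.gz:text"):
    """Yield the selected content from each matching member of an in-memory tar."""
    content_ext = content_key.split(":")[0]  # e.g. "json.gz"
    with tarfile.open(fileobj=io.BytesIO(tar_bytes), mode="r") as tar:
        for member in tar.getmembers():
            if member.isfile() and member.name.endswith(f".{content_ext}"):
                raw = tar.extractfile(member).read()
                if content_ext.endswith(".gz"):
                    raw = gzip.decompress(raw)  # "json.gz" members are gzip-compressed
                data = json.loads(raw)
                if ":" in content_key and isinstance(data, dict):
                    yield data[content_key.split(":")[1]]  # "json:key" / "json.gz:key"
                else:
                    yield data  # plain "json" / "json.gz", e.g. an already-tokenized list
```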
@@ -234,7 +251,7 @@ def _flush_buffer(self, folder, counter):
tokens = [int(x) for x in self.buffer[i]["tokens"]]
token_count += len(tokens)
json_string = json.dumps(tokens)
uid = hashlib.md5(json_string.encode()).hexdigest()
uid = f"{tar_index_str}_{i:0{digits}}"
sample = {"__key__": uid, "json.gz": json_string}
sink.write(sample)
bio.seek(0)
@@ -256,6 +273,7 @@ def preprocess(
do_sample: bool = False,
sources: enum.Enum = None,
source_counter: GlobalCounter = None,
pretok_tars: bool = False,
):
tokenizer_fn, vocab_size = tokenizer
rng = random.Random(hash(key) + seed)
@@ -273,8 +291,11 @@
pbar = tqdm(file_reader(fh), mininterval=10)
pbar.set_description(key)
for string in pbar:
tokens = tokenizer_fn(string)
tokens.append(EOT)
if file_type == RawFileType.TAR and pretok_tars:
tokens = string
else:
tokens = tokenizer_fn(string)
tokens.append(EOT)
buffer += tokens
while len(buffer) >= seqlen:
if do_sample:
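In the branch above, when pretok_tars is set the tar_reader output is already a token list, so it goes straight into the buffer without re-tokenizing or appending EOT (EOT was added when the tars were first produced). A toy sketch of the resulting concatenate-and-rechunk behavior (illustrative only; the elided loop body also handles sampling and EOT, and per the new test each emitted sample ends up with seqlen + 1 tokens):

```python
# Toy illustration (not the real implementation): pretokenized samples are
# concatenated into a running buffer and re-cut into fixed-length chunks.
seqlen = 4
pretok_samples = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]  # e.g. token lists from json.gz members
buffer, chunks = [], []
for tokens in pretok_samples:
    buffer += tokens
    while len(buffer) >= seqlen:
        chunks.append(buffer[:seqlen])
        buffer = buffer[seqlen:]
print(chunks)  # [[1, 2, 3, 4], [5, 6, 7, 8]]
```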
@@ -308,7 +329,9 @@ def preprocess(
return []


def process_keys(data, tokenizer, seqlen, seed, content_key, do_sample, sources=None, source_counters=None):
def process_keys(
data, tokenizer, seqlen, seed, content_key, do_sample, pretok_tars, sources=None, source_counters=None
):
path = data["path"]

if path.startswith("s3"):
@@ -337,6 +360,7 @@ def process_keys(data, tokenizer, seqlen, seed, content_key, do_sample, sources=
do_sample=do_sample,
sources=sources,
source_counter=source_counter,
pretok_tars=pretok_tars,
)

# Ensure that all operations on the file handle are done within this block
@@ -569,8 +593,14 @@ def main(args):
"--ray_dashboard_host", type=str, default="127.0.0.1"
) # default is localhost; for slurm jobs do 0.0.0.0
parser.add_argument("--suffixes", nargs="+", default=[".json", ".jsonl", ".zst", ".zstd", ".tar", ".gz"])
parser.add_argument("--pretok_tars", action="store_true", help="Assume tars contain pretokenized data.")

args = parser.parse_args(args)

assert not args.pretok_tars or args.suffixes == [
".tar"
], "Mixing pretokenized and untokenized data in the same run is currently not supported."

if args.do_sample:
Sources, SAMPLING_FREQUENCIES = load_from_yaml(args.default_dataset_yaml)
logger.info(f"SOURCES:\n {Sources}")
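For context, a sketch of the two-stage flow this flag enables (paths and seqlen are placeholders; it mirrors the new test at the bottom of this PR): the first run tokenizes raw jsonl into webdataset tars, and the second run remixes those tars without re-tokenizing, which the assertion above restricts to --suffixes .tar.

```python
# Illustrative two-stage usage with placeholder paths, mirroring the new test.
import os

# Stage 1: tokenize raw jsonl into webdataset tars of token sequences.
os.system(
    "python open_lm/datapreprocess/ray/tokenize_shuffle.py "
    "--input ./raw_jsonl --content_key text --seqlen 2048 --output ./tokenized"
)

# Stage 2: remix the pretokenized tars directly; --pretok_tars requires --suffixes .tar.
os.system(
    "python open_lm/datapreprocess/ray/tokenize_shuffle.py "
    "--input ./tokenized --content_key json.gz --seqlen 2048 "
    "--output ./remixed --pretok_tars --suffixes .tar"
)
```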
@@ -612,6 +642,7 @@ def main(args):
input_paths = input_paths[: args.subset]
if args.subfraction is not None:
input_paths = input_paths[: int(args.subfraction * len(input_paths))]

print("Files considered: \n", input_paths)
print(f"num files ={len(input_paths)}")
num_files = len(input_paths)
@@ -650,6 +681,7 @@ def main(args):
seed=args.seed,
content_key=content_key,
do_sample=args.do_sample,
pretok_tars=args.pretok_tars,
sources=Sources,
source_counters=source_counters,
)
34 changes: 34 additions & 0 deletions tests/test_tokenize_shuffle.py
@@ -115,3 +115,37 @@ def test_tokenize_shuffle_local_read_local_write():
total += len(x["json.gz"])
assert total == NUM_TOKENS
assert exit_value == 0


def test_tokenize_shuffle_with_pretokenized():
content_len = 2048
NUM_TOKENS = 24508089
# download a small test json file and store at ./test_input
os.system("mkdir test_input")
os.system("mkdir test_output")
os.system(
"wget -O ./test_input/wikipedia_sample.jsonl https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample/resolve/main/wikipedia_sample.jsonl"
)
# run tokenize script
exit_value_1 = os.system(
f"python open_lm/datapreprocess/ray/tokenize_shuffle.py --input ./test_input --content_key text --seqlen {content_len} --output ./test_output/"
)
assert exit_value_1 == 0

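# duplicate the tokenized output into two sources, so the second run mixes pretokenized tars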
os.system("cp -r ./test_output ./test_input/2a/")
os.system("cp -r ./test_output ./test_input/2b/")

exit_value_2 = os.system(
f"python open_lm/datapreprocess/ray/tokenize_shuffle.py --input ./test_input/2a,./test_input/2b --content_key json.gz --seqlen {content_len} --output ./test_output/2 --pretok_tars --suffixes .tar"
)
assert exit_value_2 == 0

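# every emitted sample should still hold content_len + 1 tokens, and the total doubles (two copies of the input)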
tars = [os.path.join("test_output/2", fname) for fname in os.listdir("test_output/2") if fname.endswith(".tar")]
total = 0
for tar in tars:
ds = wds.WebDataset(tar).decode()
for x in ds:
assert len(x["json.gz"]) == content_len + 1
total += len(x["json.gz"])

assert total == 2 * NUM_TOKENS