Skip to content

Commit 0f77f45

Browse files
Fix duplicate file names.
1 parent 71c3369 commit 0f77f45

File tree

2 files changed, with 6 additions (+6) and 13 deletions (-13).

open_lm/datapreprocess/ray/tokenize_shuffle.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -251,7 +251,7 @@ def _flush_buffer(self, folder, counter):
251251
tokens = [int(x) for x in self.buffer[i]["tokens"]]
252252
token_count += len(tokens)
253253
json_string = json.dumps(tokens)
254-
uid = hashlib.md5(json_string.encode()).hexdigest()
254+
uid = f"{tar_index_str}_{i:0{digits}}"
255255
sample = {"__key__": uid, "json.gz": json_string}
256256
sink.write(sample)
257257
bio.seek(0)

tests/test_tokenize_shuffle.py

Lines changed: 5 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -132,27 +132,20 @@ def test_tokenize_shuffle_with_pretokenized():
132132
)
133133
assert exit_value_1 == 0
134134

135-
os.system("mkdir test_input_2a")
136-
os.system("mkdir test_input_2b")
137-
os.system("cp -r ./test_output/00000001.tar ./test_input_2a/")
138-
os.system("cp -r ./test_output/00000002.tar ./test_input_2b/")
139-
os.system("mkdir test_output_2")
135+
os.system("cp -r ./test_output ./test_input/2a/")
136+
os.system("cp -r ./test_output ./test_input/2b/")
140137

141138
exit_value_2 = os.system(
142-
f"python open_lm/datapreprocess/ray/tokenize_shuffle.py --input ./test_input_2a,./test_input_2b --content_key json.gz --seqlen {content_len} --output ./test_output_2 --pretok_tars --suffixes .tar"
139+
f"python open_lm/datapreprocess/ray/tokenize_shuffle.py --input ./test_input/2a,./test_input/2b --content_key json.gz --seqlen {content_len} --output ./test_output/2 --pretok_tars --suffixes .tar"
143140
)
144141
assert exit_value_2 == 0
145142

146-
tars = [os.path.join("test_output_2", fname) for fname in os.listdir("test_output_2") if fname.endswith(".tar")]
143+
tars = [os.path.join("test_output/2", fname) for fname in os.listdir("test_output/2") if fname.endswith(".tar")]
147144
total = 0
148145
for tar in tars:
149146
ds = wds.WebDataset(tar).decode()
150147
for x in ds:
151148
assert len(x["json.gz"]) == content_len + 1
152149
total += len(x["json.gz"])
153150

154-
os.system("rm -rf test_input_2a")
155-
os.system("rm -rf test_input_2b")
156-
os.system("rm -rf test_output_2")
157-
158-
assert total == NUM_TOKENS
151+
assert total == 2 * NUM_TOKENS

0 commit comments

Comments (0)