Skip to content

Commit 84e4bd4

Browse files
Fix some old tests.
1 parent 8f3bbb9 commit 84e4bd4

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

tests/test_tokenize_shuffle.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def test_tokenize_shuffle_simple():
2525
for x in ds:
2626
assert len(x["json.gz"]) == content_len + 1
2727
total += len(x["json.gz"])
28-
# assert total == NUM_TOKENS
28+
assert total == NUM_TOKENS
2929

3030
with open("test_output/manifest.jsonl", "rb") as f:
3131
out = f.read()
@@ -57,7 +57,7 @@ def test_tokenize_shuffle_tar(content_key, NUM_TOKENS):
5757

5858
def test_tokenize_shuffle_simple_do_sample():
5959
content_len = 2048
60-
NUM_TOKENS = 32784
60+
NUM_TOKENS = 86058
6161
exit_value = os.system(
6262
f"python open_lm/datapreprocess/ray/tokenize_shuffle.py --input s3://dcnlp-west-test/tokenize_shuffle_test/C4_V3_tiny/ --content_key content --output test_output/ --seqlen {content_len} --do_sample"
6363
)
@@ -67,7 +67,12 @@ def test_tokenize_shuffle_simple_do_sample():
6767
for x in ds:
6868
assert len(x["json.gz"]) == content_len + 1
6969
total += len(x["json.gz"])
70-
assert total == NUM_TOKENS
70+
71+
# The sampling prob is 1.037142857 for the C4 source. This means that we will see all tokens at least once. For
72+
# error at most 1e-4, we will need an error of 13950 tokens (by Chernoff bounds).
73+
# TODO(gsmyrnis): Improve this.
74+
assert total <= 1.037142857 * NUM_TOKENS + 13950
75+
assert total >= 1.037142857 * NUM_TOKENS - 13950
7176

7277

7378
@pytest.mark.s3
@@ -242,7 +247,7 @@ def test_mixing_sampling(generation_length):
242247

243248
# Source b is sampled with probability 0.5, so the number of documents from source b follows Bin(10000, 0.5).
244249
# Via (multiplicative) Chernoff bounds, for margin delta the error probability is 2 * exp(-delta**2 * mu / 3)
245-
# In this case for error probability <= 1e-4, we need delta * mu = sqrt(-3 * ln(0.5e-10) / mu) * mu ~= 386
250+
# In this case for error probability <= 1e-4, we need delta * mu = sqrt(-3 * ln(0.5e-4) / mu) * mu ~= 386
246251
# TODO (gsmyrnis): I think you can get a better bound here.
247252
mixing_error = 386
248253
assert total_b <= (0.5 * docs_b + mixing_error) * generation_length

0 commit comments

Comments
 (0)