@@ -25,7 +25,7 @@ def test_tokenize_shuffle_simple():
25
25
for x in ds :
26
26
assert len (x ["json.gz" ]) == content_len + 1
27
27
total += len (x ["json.gz" ])
28
- # assert total == NUM_TOKENS
28
+ assert total == NUM_TOKENS
29
29
30
30
with open ("test_output/manifest.jsonl" , "rb" ) as f :
31
31
out = f .read ()
@@ -57,7 +57,7 @@ def test_tokenize_shuffle_tar(content_key, NUM_TOKENS):
57
57
58
58
def test_tokenize_shuffle_simple_do_sample ():
59
59
content_len = 2048
60
- NUM_TOKENS = 32784
60
+ NUM_TOKENS = 86058
61
61
exit_value = os .system (
62
62
f"python open_lm/datapreprocess/ray/tokenize_shuffle.py --input s3://dcnlp-west-test/tokenize_shuffle_test/C4_V3_tiny/ --content_key content --output test_output/ --seqlen { content_len } --do_sample"
63
63
)
@@ -67,7 +67,12 @@ def test_tokenize_shuffle_simple_do_sample():
67
67
for x in ds :
68
68
assert len (x ["json.gz" ]) == content_len + 1
69
69
total += len (x ["json.gz" ])
70
- assert total == NUM_TOKENS
70
+
71
+ # The sampling prob is 1.037142857 for the C4 source. This means that we will see all tokens at least once. For
72
+ # error probability at most 1e-4, we need a margin of 13950 tokens (by Chernoff bounds).
73
+ # TODO(gsmyrnis): Improve this.
74
+ assert total <= 1.037142857 * NUM_TOKENS + 13950
75
+ assert total >= 1.037142857 * NUM_TOKENS - 13950
71
76
72
77
73
78
@pytest .mark .s3
@@ -242,7 +247,7 @@ def test_mixing_sampling(generation_length):
242
247
243
248
# Source b is sampled with probability 0.5, so the number of documents from source b follows Bin(10000, 0.5).
244
249
# Via (multiplicative) Chernoff bounds, for margin delta the error probability is 2 * exp(-delta**2 * mu / 3)
245
- # In this case for error probability <= 1e-4, we need delta * mu = sqrt(-3 * ln(0.5e-10 ) / mu) * mu ~= 386
250
+ # In this case for error probability <= 1e-4, we need delta * mu = sqrt(-3 * ln(0.5e-4 ) / mu) * mu ~= 386
246
251
# TODO (gsmyrnis): I think you can get a better bound here.
247
252
mixing_error = 386
248
253
assert total_b <= (0.5 * docs_b + mixing_error ) * generation_length
0 commit comments