Skip to content

Commit 9c82596

Browse files
author
The tunix Authors
committed
Merge pull request #936 from google:lance-dataset
PiperOrigin-RevId: 848295475
2 parents e165692 + 8201a54 commit 9c82596

File tree

7 files changed

+424
-111
lines changed

7 files changed

+424
-111
lines changed

.github/workflows/tpu-tests.yml

Lines changed: 8 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -166,36 +166,15 @@ jobs:
166166
env:
167167
HF_TOKEN: ${{ secrets.HF_TOKEN }}
168168
run: |
169-
170-
# Download GSM8K dataset
171-
mkdir -p /tmp/grpo_test/rl/grpo/data
172-
python3 -c "
173-
from datasets import load_dataset
174-
import json
175-
176-
# Download and save GSM8K train split
177-
dataset = load_dataset('openai/gsm8k', 'main', split='train')
178-
train_data = [{'question': item['question'], 'answer': item['answer']} for item in dataset]
179-
with open('/tmp/grpo_test/rl/grpo/data/gsm8k_train.json', 'w') as f:
180-
json.dump(train_data, f)
181-
182-
# Download and save GSM8K test split
183-
dataset = load_dataset('openai/gsm8k', 'main', split='test')
184-
test_data = [{'question': item['question'], 'answer': item['answer']} for item in dataset]
185-
with open('/tmp/grpo_test/rl/grpo/data/gsm8k_test.json', 'w') as f:
186-
json.dump(test_data, f)
187-
188-
print('GSM8K dataset downloaded successfully')
189-
"
190-
191-
# TODO(lancewang): Re-enable this test once the segfault is fixed.
192169
# Run GRPO demo script with minimal configuration
193-
# python3 scripts/grpo_demo_llama3_qwen2.py \
194-
# --root-dir=/tmp/grpo_test \
195-
# --model-version=Qwen/Qwen2.5-0.5B-Instruct \
196-
# --num-batches=1 \
197-
# --num-test-batches=1 \
198-
# --rollout-engine=vanilla
170+
python3 scripts/grpo_demo_llama3_qwen2.py \
171+
--root-dir=/tmp/grpo_test \
172+
--num-batches=2 \
173+
--num-test-batches=1 \
174+
--global-batch-size=2 \
175+
--train-mini-batch-size=2 \
176+
--train-micro-batch-size=2 \
177+
--rollout-engine=vanilla
199178
- name: Run vllm tests
200179
env:
201180
HF_TOKEN: ${{ secrets.HF_TOKEN }}

scripts/grpo_demo_llama3_qwen2.py

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import qwix
4141
from tqdm.auto import tqdm
4242
import transformers
43+
from tunix.cli.utils import data as data_lib
4344
from tunix.examples.data import math_dataset
4445
from tunix.models.llama3 import model as llama_lib
4546
from tunix.models.llama3 import params as llama_params
@@ -573,37 +574,37 @@ def extract_hash_answer(text: str) -> str | None:
573574
dataset = create_dataset(
574575
args.data_source,
575576
args.dataset if args.data_source == "tfds" else LOCAL_TRAIN_DATA_DIR,
576-
args.global_batch_size,
577-
NUM_BATCHES,
577+
tokenizer=model_tokenizer,
578578
tfds_download=True,
579+
split="train",
579580
)
580581

581-
if TRAIN_FRACTION == 1.0:
582-
train_dataset = dataset.repeat(NUM_EPOCHS)
583-
val_dataset = None
584-
else:
585-
train_dataset = dataset[: int(len(dataset) * TRAIN_FRACTION)]
586-
train_dataset = train_dataset.repeat(NUM_EPOCHS)
587-
588-
val_dataset = dataset[int(len(dataset) * TRAIN_FRACTION) :].repeat(NUM_EPOCHS)
582+
train_dataset, val_dataset = data_lib.post_init_dataset(
583+
dataset,
584+
model_tokenizer,
585+
batch_size=args.global_batch_size,
586+
num_batches=NUM_BATCHES,
587+
max_prompt_length=MAX_PROMPT_LENGTH,
588+
fraction=TRAIN_FRACTION,
589+
num_epochs=NUM_EPOCHS,
590+
)
589591

590592
test_dataset = create_dataset(
591593
args.data_source,
592594
args.dataset if args.data_source == "tfds" else LOCAL_TRAIN_DATA_DIR,
593-
args.global_batch_size,
594-
NUM_TEST_BATCHES,
595+
tokenizer=model_tokenizer,
595596
tfds_download=True,
597+
split="test",
596598
)
597599

598-
print(
599-
f"train_dataset size: {len(train_dataset)}, val_dataset size:"
600-
f"{len(val_dataset) if val_dataset is not None else 0},"
601-
f"test_dataset size: {len(test_dataset)}"
600+
test_dataset, _ = data_lib.post_init_dataset(
601+
test_dataset,
602+
model_tokenizer,
603+
batch_size=args.global_batch_size,
604+
num_batches=NUM_TEST_BATCHES,
605+
max_prompt_length=MAX_PROMPT_LENGTH,
602606
)
603607

604-
for ele in train_dataset[:1]:
605-
pprint.pprint(ele)
606-
607608
MODEL_CONFIG = {
608609
"meta-llama/Llama-3.2-1B-Instruct": llama_lib.ModelConfig.llama3p2_1b,
609610
"meta-llama/Llama-3.2-3B-Instruct": llama_lib.ModelConfig.llama3p2_3b,
@@ -774,8 +775,7 @@ def check_answer(prompts, completions, answer, **kargs): # pylint: disable=unus
774775
responses = completions
775776

776777
extracted_responses = [
777-
guess.group(1) if (guess := match_format.search(r)) is not None else None
778-
for r in responses
778+
(m[-1] if (m := match_numbers.findall(r)) else None) for r in responses
779779
]
780780

781781
scores = []
@@ -808,7 +808,8 @@ def check_answer(prompts, completions, answer, **kargs): # pylint: disable=unus
808808

809809

810810
match_numbers = re.compile(
811-
rf"{solution_start}.*?([\d\.]{{1,}})", flags=re.MULTILINE | re.DOTALL
811+
rf"{solution_start}.*?([+-]?(?:\d[\d,]*)(?:\.\d+)?|[+-]?\.\d+)",
812+
flags=re.MULTILINE | re.DOTALL,
812813
)
813814
match_numbers.findall(f"{solution_start} 0.34 {solution_end}")
814815

@@ -829,8 +830,7 @@ def check_numbers(prompts, completions, answer, **kargs): # pylint: disable=unu
829830
responses = completions
830831

831832
extracted_responses = [
832-
guess.group(1) if (guess := match_numbers.search(r)) is not None else None
833-
for r in responses
833+
(m[-1] if (m := match_numbers.findall(r)) else None) for r in responses
834834
]
835835

836836
scores = []
@@ -846,8 +846,8 @@ def check_numbers(prompts, completions, answer, **kargs): # pylint: disable=unu
846846
continue
847847
# Convert to numbers
848848
try:
849-
true_answer = float(true_answer.strip())
850-
guess = float(guess.strip())
849+
true_answer = float(true_answer.replace(",", "").strip())
850+
guess = float(guess.replace(",", "").strip())
851851
scores.append(1.5 if guess == true_answer else 0.0)
852852
except Exception: # pylint: disable=broad-except
853853
scores.append(0)
@@ -938,20 +938,20 @@ def evaluate(
938938
partially_corr_per_question = 0
939939
corr_format_per_question = 0
940940
for response in multiple_call_response:
941-
extracted_response = (
942-
guess.group(1)
943-
if (guess := match_numbers.search(response)) is not None
944-
else "-1000000"
945-
)
941+
# Grab the last matched number from this response (not a generator)
942+
matches = match_numbers.findall(response)
943+
extracted_response = matches[-1] if matches else "-1000000"
946944
try:
947-
if float(extracted_response.strip()) == float(answer.strip()):
945+
response_num = float(extracted_response.replace(",", "").strip())
946+
answer_num = float(answer.replace(",", "").strip())
947+
if response_num == answer_num:
948948
corr_ctr_per_question += 1
949949

950-
ratio = float(extracted_response.strip()) / float(answer.strip())
950+
ratio = response_num / answer_num
951951
if ratio >= 0.9 and ratio <= 1.1:
952952
partially_corr_per_question += 1
953-
except (ValueError, ZeroDivisionError):
954-
print("SKIPPED")
953+
except (ValueError, ZeroDivisionError) as e:
954+
print(f"SKIPPED: {e}")
955955

956956
# check format
957957
if match_format.search(response) is not None:

tests/cli/utils/data_test.py

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for tunix.cli.utils.data.post_init_dataset."""
16+
17+
from __future__ import annotations
18+
19+
from absl.testing import absltest
20+
from tunix.cli.utils import data as data_lib
21+
22+
23+
class _FakeTokenizer:
24+
25+
def tokenize(self, text: str):
26+
# Simple tokenization: one token per whitespace-separated chunk
27+
return text.split()
28+
29+
30+
class _BaseDataset:
31+
"""Minimal dataset to mimic grain interfaces used in post_init_dataset."""
32+
33+
def __init__(self, records):
34+
self._records = list(records)
35+
36+
def __len__(self):
37+
return len(self._records)
38+
39+
def __getitem__(self, idx):
40+
if isinstance(idx, slice):
41+
return _BaseDataset(self._records[idx])
42+
return self._records[idx]
43+
44+
def filter(self, fn):
45+
return _BaseDataset([x for x in self._records if fn(x)])
46+
47+
def repeat(self, n):
48+
return _RepeatDataset(self, n)
49+
50+
def to_iter_dataset(self):
51+
return _IterDataset(self._records)
52+
53+
def map(self, fn): # Not used in tests, but kept for fidelity.
54+
return _BaseDataset([fn(x) for x in self._records])
55+
56+
57+
class _RepeatDataset:
58+
59+
def __init__(self, base: _BaseDataset, n: int):
60+
self._base = base
61+
self._n = n
62+
63+
def __len__(self):
64+
return len(self._base) * self._n
65+
66+
def to_iter_dataset(self):
67+
return _IterDataset(self._base._records * self._n)
68+
69+
70+
class _IterDataset:
71+
72+
def __init__(self, records):
73+
self._records = list(records)
74+
75+
def batch(self, batch_size: int):
76+
return _BatchedDataset(self._records, batch_size)
77+
78+
79+
class _BatchedDataset:
80+
81+
def __init__(self, records, batch_size: int):
82+
self._records = records
83+
self._batch_size = batch_size
84+
85+
def __iter__(self):
86+
for i in range(0, len(self._records), self._batch_size):
87+
yield self._records[i : i + self._batch_size]
88+
89+
90+
class PostInitDatasetTest(absltest.TestCase):
  """Exercises data_lib.post_init_dataset against lightweight fake datasets."""

  def test_filters_by_prompt_length(self):
    tok = _FakeTokenizer()
    ds = _BaseDataset([
        {"prompts": "short", "answer": 1},
        {"prompts": "this is too long", "answer": 2},
    ])

    train, val = data_lib.post_init_dataset(
        ds,
        tokenizer=tok,
        batch_size=2,
        num_batches=None,
        max_prompt_length=2,  # only the one-token record fits under the cap
    )

    self.assertIsNone(val)
    train_batches = list(train)
    self.assertLen(train_batches, 1)
    self.assertEqual(train_batches[0], [{"prompts": "short", "answer": 1}])

  def test_limits_num_batches(self):
    tok = _FakeTokenizer()
    ds = _BaseDataset([{"prompts": f"p{i}", "answer": i} for i in range(10)])

    train, _ = data_lib.post_init_dataset(
        ds,
        tokenizer=tok,
        batch_size=3,
        num_batches=2,  # cap: 2 batches * batch_size 3 = 6 examples
        max_prompt_length=None,
    )

    train_batches = list(train)
    self.assertLen(train_batches, 2)
    self.assertEqual([len(b) for b in train_batches], [3, 3])
    self.assertEqual(train_batches[0][0]["prompts"], "p0")
    self.assertEqual(train_batches[-1][-1]["prompts"], "p5")

  def test_fraction_split_and_repeat(self):
    tok = _FakeTokenizer()
    ds = _BaseDataset([{"prompts": f"p{i}", "answer": i} for i in range(8)])

    train, val = data_lib.post_init_dataset(
        ds,
        tokenizer=tok,
        batch_size=2,
        num_batches=None,
        max_prompt_length=None,
        fraction=0.5,
        num_epochs=1,
    )

    train_batches = list(train)
    val_batches = list(val)

    # 8 records split 50/50 -> 4 per side; batch_size 2 -> 2 batches each.
    self.assertLen(train_batches, 2)
    self.assertLen(val_batches, 2)
    self.assertEqual(train_batches[0][0]["prompts"], "p0")
    self.assertEqual(val_batches[-1][-1]["prompts"], "p7")

  def test_num_epochs_repeats_dataset(self):
    tok = _FakeTokenizer()
    ds = _BaseDataset(
        [{"prompts": "p0", "answer": 0}, {"prompts": "p1", "answer": 1}]
    )

    train, val = data_lib.post_init_dataset(
        ds,
        tokenizer=tok,
        batch_size=1,
        num_batches=None,
        max_prompt_length=None,
        num_epochs=3,
    )

    self.assertIsNone(val)
    train_batches = list(train)
    # Two records over three epochs -> six size-1 batches, in epoch order.
    self.assertLen(train_batches, 6)
    self.assertEqual(
        [b[0]["prompts"] for b in train_batches],
        ["p0", "p1", "p0", "p1", "p0", "p1"],
    )
178+
179+
180+
# Allow running this test module directly (e.g. `python data_test.py`).
if __name__ == "__main__":
  absltest.main()

0 commit comments

Comments (0)