Skip to content

Commit f567589

Browse files
committed
feat: complete ursa math prm phase0 and phase1
1 parent 9c8877a commit f567589

File tree

2 files changed

+404
-23
lines changed

2 files changed

+404
-23
lines changed
Lines changed: 321 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,321 @@
1+
"""
2+
Prepare a LightRFT-compatible Stage 3 manifest from URSA-MATH raw data.
3+
4+
This script converts the raw `MMathCoT-1M` jsonl schema:
5+
6+
{"image_url": "...", "instruction": "...", "output": "..."}
7+
8+
into a LightRFT prompt dataset schema:
9+
10+
{
11+
"prompt": "...",
12+
"images": ["/abs/path/to/image.png"],
13+
"reference": "...",
14+
"label": "math_prm"
15+
}
16+
17+
It also performs a lightweight `PromptDatasetVL` smoke validation on the
18+
converted records so the output can be consumed directly by
19+
`examples/math_prm/train_colocate.py`.
20+
"""
21+
22+
from __future__ import annotations
23+
24+
import argparse
25+
import json
26+
import re
27+
from collections import Counter
28+
from pathlib import Path
29+
from types import SimpleNamespace
30+
from typing import Any
31+
32+
33+
# Resolve the repository root (two directory levels above this script) so the
# in-repo `lightrft` package can be imported without installing it.
REPO_ROOT = Path(__file__).resolve().parents[2]

import sys

# Prepend (insert at 0) rather than append so the in-repo package shadows any
# same-named package installed in site-packages.
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
39+
40+
from lightrft.datasets.prompts_dataset_vl import PromptDatasetVL
41+
42+
43+
# Default locations for the raw URSA-MATH assets and the converted outputs.
# NOTE(review): the input/image defaults are machine-specific absolute paths
# (/home/ubuntu/...) — override them via CLI flags on other hosts.
DEFAULT_INPUT_PATH = "/home/ubuntu/URSA-MATH/datasets/URSA-MATH/MMathCoT-1M/train.jsonl"
DEFAULT_IMAGE_ROOT = "/home/ubuntu/URSA-MATH/datasets/URSA-MATH/images"
# Converted manifest and its summary land under the repo's tmp/ directory.
DEFAULT_OUTPUT_PATH = str(REPO_ROOT / "tmp" / "ursa_stage3" / "mmathcot_stage3_math_prm.jsonl")
DEFAULT_SUMMARY_PATH = str(REPO_ROOT / "tmp" / "ursa_stage3" / "mmathcot_stage3_math_prm.summary.json")
47+
48+
49+
def parse_args() -> argparse.Namespace:
    """Build the converter's command-line interface and parse ``sys.argv``."""
    cli = argparse.ArgumentParser(
        description=(
            "Convert URSA-MATH MMathCoT-1M raw jsonl into a LightRFT "
            "Stage 3 prompt manifest and validate it with PromptDatasetVL."
        )
    )
    # Input/output locations.
    cli.add_argument("--input-path", type=str, default=DEFAULT_INPUT_PATH,
                     help="Path to MMathCoT-1M raw train.jsonl.")
    cli.add_argument("--image-root", type=str, default=DEFAULT_IMAGE_ROOT,
                     help="Root directory for URSA-MATH image assets.")
    cli.add_argument("--output-path", type=str, default=DEFAULT_OUTPUT_PATH,
                     help="Output path for the converted LightRFT jsonl manifest.")
    cli.add_argument("--summary-path", type=str, default=DEFAULT_SUMMARY_PATH,
                     help="Path to write the conversion/validation summary json.")
    # Conversion behavior.
    cli.add_argument("--label", type=str, default="math_prm",
                     help="Label written into the converted manifest.")
    cli.add_argument("--prompt-mode", type=str,
                     choices=["question_only", "instruction"],
                     default="question_only",
                     help="How to build the LightRFT prompt from raw instruction.")
    cli.add_argument("--max-samples", type=int, default=None,
                     help="Optional cap for the number of raw rows to process.")
    cli.add_argument("--smoke-samples", type=int, default=4,
                     help="How many converted samples to use for PromptDatasetVL smoke validation.")
    return cli.parse_args()
106+
107+
108+
# Compiled once at module scope: matches duplicated/malformed "Question:"-style
# prefixes (e.g. "Question:estion: ...").  Hoisted out of extract_prompt so the
# pattern is not re-looked-up/compiled on every call — this function runs once
# per row of a potentially million-line dataset.
_QUESTION_PREFIX_RE = re.compile(r"^(?:(?:[Qq]uestion|[Qq]estion|[Ee]stion|[Uu]estion)\s*:)\s*")


def extract_prompt(raw_instruction: str, prompt_mode: str) -> tuple[str, bool]:
    """Build the LightRFT prompt text from a raw instruction string.

    Args:
        raw_instruction: The raw ``instruction`` field from MMathCoT-1M.
        prompt_mode: ``"instruction"`` keeps the full text; ``"question_only"``
            extracts the text after the last-resort "Question:" marker.

    Returns:
        A ``(prompt, used_fallback)`` tuple.  ``used_fallback`` is True when
        the marker was absent, or stripping it left nothing, in which case
        the full stripped instruction is returned unchanged.
    """
    text = (raw_instruction or "").strip()
    if not text:
        return "", False

    if prompt_mode == "instruction":
        return text, False

    marker = "Question:"
    idx = text.find(marker)
    if idx == -1:
        # No marker at all: fall back to the whole instruction.
        return text, True

    question = text[idx + len(marker):].strip()
    # Some raw rows contain duplicated or malformed prefixes such as
    # "Question:estion: ...". Strip these repeatedly before returning.
    while True:
        cleaned = _QUESTION_PREFIX_RE.sub("", question)
        if cleaned == question:
            break
        question = cleaned.strip()
    if not question:
        # Stripping removed everything: fall back to the whole instruction.
        return text, True
    return question, False
133+
134+
135+
def extract_reference(raw_output: str) -> tuple[str, bool]:
    """Pull the final answer out of a raw chain-of-thought output string.

    Splits on the *last* "†Answer:" marker and returns everything after it.
    Falls back to the full stripped output when the marker is missing or
    nothing follows it.

    Returns:
        A ``(reference, used_fallback)`` tuple; ``used_fallback`` flags the
        fallback path.
    """
    cleaned = (raw_output or "").strip()
    if not cleaned:
        return "", False

    # rpartition finds the last occurrence; an empty middle element means
    # the marker was not present at all.
    _, marker, tail = cleaned.rpartition("†Answer:")
    if not marker:
        return cleaned, True

    answer = tail.strip()
    return (answer, False) if answer else (cleaned, True)
149+
150+
151+
def build_record(
    raw: dict[str, Any],
    source_index: int,
    image_root: Path,
    prompt_mode: str,
    label: str,
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Convert one raw MMathCoT row into a LightRFT record plus QA metadata.

    Args:
        raw: One parsed jsonl row with ``image_url``/``instruction``/``output``.
        source_index: Zero-based row index in the raw file, kept for tracing.
        image_root: Root directory the relative ``image_url`` is resolved under.
        prompt_mode: Forwarded to :func:`extract_prompt`.
        label: Label string written into the manifest row.

    Returns:
        ``(record, meta)`` — the manifest row to serialize, and validation
        metadata (image existence, empty-field and fallback flags).
    """
    rel_url = str(raw.get("image_url", "")).strip()
    instruction_text = str(raw.get("instruction", "")).strip()
    output_text = str(raw.get("output", "")).strip()

    prompt, prompt_fell_back = extract_prompt(instruction_text, prompt_mode)
    reference, reference_fell_back = extract_reference(output_text)

    resolved_image = (image_root / rel_url).resolve()
    # First path component of the relative url, used for distribution stats.
    url_prefix = rel_url.split("/", 1)[0] if rel_url else ""

    record = {
        "data_source": "URSA-MATH/MMathCoT-1M",
        "prompt": prompt,
        "images": [str(resolved_image)],
        "reference": reference,
        "ground_truth": reference,
        "label": label,
        "reward_model": {
            "ground_truth": reference,
        },
        "extra_info": {
            "source_index": source_index,
            "raw_image_url": rel_url,
            "image_prefix": url_prefix,
            "prompt_mode": prompt_mode,
        },
    }

    meta = {
        "image_path_exists": resolved_image.exists(),
        "prompt_empty": prompt == "",
        "reference_empty": reference == "",
        "used_prompt_fallback": prompt_fell_back,
        "used_reference_fallback": reference_fell_back,
        "image_prefix": url_prefix,
        "image_path": str(resolved_image),
    }
    return record, meta
196+
197+
198+
def smoke_validate(converted_rows: list[dict[str, Any]], smoke_samples: int) -> dict[str, Any]:
    """Run a minimal ``PromptDatasetVL`` round-trip over a few converted rows.

    Instantiates the dataset with a stub strategy (no tokenizer/processor),
    indexes every sample, runs the collate function, and reports batch sizes
    plus a preview of the first item.
    """
    take = max(1, min(smoke_samples, len(converted_rows)))
    sample_rows = converted_rows[:take]

    # Minimal stand-in for the trainer strategy object PromptDatasetVL reads
    # its configuration keys from.
    stub_strategy = SimpleNamespace(
        args=SimpleNamespace(
            input_key="prompt",
            images_key="images",
            reference_key="reference",
            label_key="label",
            apply_chat_template=False,
            system_prompt=None,
        )
    )
    dataset = PromptDatasetVL(
        sample_rows,
        tokenizer=None,
        processor=None,
        max_length=0,
        strategy=stub_strategy,
    )

    items = [dataset[idx] for idx in range(len(dataset))]
    batch_prompts, batch_images, batch_refs, batch_labels = dataset.collate_fn(items)

    head_prompt, head_images, head_reference, head_label = items[0]
    head_is_list = isinstance(head_images, list)
    return {
        "sample_count": len(dataset),
        "first_item": {
            "prompt_preview": head_prompt[:240],
            "image_count": len(head_images) if head_is_list else 0,
            "first_image": head_images[0] if head_is_list and head_images else None,
            "reference": head_reference,
            "label": head_label,
        },
        "collate_sizes": {
            "prompts": len(batch_prompts),
            "images": len(batch_images),
            "references": len(batch_refs),
            "labels": len(batch_labels),
        },
    }
237+
238+
239+
def main() -> None:
    """Convert the raw jsonl, validate the output, and emit a summary.

    Steps:
      1. Resolve and sanity-check all input/output paths.
      2. Stream the raw jsonl, converting each row via ``build_record`` and
         writing the result to the output manifest as it goes.
      3. Fail fast if any referenced image file is missing on disk.
      4. Smoke-validate the first few converted rows with ``PromptDatasetVL``.
      5. Write and print a json summary of counters and validation results.

    Raises:
        FileNotFoundError: If the input jsonl, the image root, or any
            referenced image file does not exist.
        ValueError: If no rows were converted at all.
    """
    args = parse_args()

    input_path = Path(args.input_path).resolve()
    image_root = Path(args.image_root).resolve()
    output_path = Path(args.output_path).resolve()
    summary_path = Path(args.summary_path).resolve()

    if not input_path.exists():
        raise FileNotFoundError(f"input jsonl not found: {input_path}")
    if not image_root.exists():
        raise FileNotFoundError(f"image root not found: {image_root}")

    counters: Counter[str] = Counter()
    prefix_counter: Counter[str] = Counter()
    smoke_rows: list[dict[str, Any]] = []

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with input_path.open("r", encoding="utf-8") as fp, output_path.open("w", encoding="utf-8") as out_fp:
        for source_index, line in enumerate(fp):
            if args.max_samples is not None and source_index >= args.max_samples:
                break

            # Bug fix: tolerate stray blank lines in the raw jsonl instead of
            # crashing in json.loads with a JSONDecodeError.
            if not line.strip():
                continue

            counters["rows_seen"] += 1
            raw = json.loads(line)
            record, meta = build_record(
                raw=raw,
                source_index=source_index,
                image_root=image_root,
                prompt_mode=args.prompt_mode,
                label=args.label,
            )

            prefix_counter[meta["image_prefix"]] += 1
            if meta["used_prompt_fallback"]:
                counters["prompt_fallback_rows"] += 1
            if meta["used_reference_fallback"]:
                counters["reference_fallback_rows"] += 1
            if meta["prompt_empty"]:
                counters["empty_prompt_rows"] += 1
            if meta["reference_empty"]:
                counters["empty_reference_rows"] += 1
            # Fail fast on a missing image: it would otherwise only surface
            # much later inside the training dataloader.
            if not meta["image_path_exists"]:
                raise FileNotFoundError(
                    f"missing image for row {source_index}: {meta['image_path']}"
                )

            out_fp.write(json.dumps(record, ensure_ascii=False) + "\n")
            counters["rows_written"] += 1
            # Retain the first few converted records for smoke validation.
            if len(smoke_rows) < max(1, args.smoke_samples):
                smoke_rows.append(record)

    if not smoke_rows:
        raise ValueError("No rows were converted. Check the input path and --max-samples.")

    smoke = smoke_validate(smoke_rows, args.smoke_samples)

    summary = {
        "input_path": str(input_path),
        "image_root": str(image_root),
        "output_path": str(output_path),
        "summary_path": str(summary_path),
        "label": args.label,
        "prompt_mode": args.prompt_mode,
        "rows_seen": counters["rows_seen"],
        "rows_written": counters["rows_written"],
        "prompt_fallback_rows": counters["prompt_fallback_rows"],
        "reference_fallback_rows": counters["reference_fallback_rows"],
        "empty_prompt_rows": counters["empty_prompt_rows"],
        "empty_reference_rows": counters["empty_reference_rows"],
        "image_prefix_counts": dict(prefix_counter),
        # Every converted record carries exactly one image (see build_record).
        "images_per_sample_counts": {"1": counters["rows_written"]},
        "smoke_validation": smoke,
    }

    summary_path.parent.mkdir(parents=True, exist_ok=True)
    summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")

    # Echo the summary to stdout for interactive/CI visibility.
    print(json.dumps(summary, ensure_ascii=False, indent=2))
318+
319+
320+
# Script entry point: run the conversion and print the summary json.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)