Skip to content

Commit 418ea07

Browse files
committed
refactor: Replace _to_number function with gold_from_gsm8k in dataset_handler and update load_generation_model parameters in core.py
1 parent 9dc83f9 commit 418ea07

3 files changed

Lines changed: 42 additions & 73 deletions

File tree

src/dataset_handler.py

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from fractions import Fraction
44
import random
55

6+
from answer_utils import _to_number
7+
68

79
def load_hard_dataset(name, split, n, seed):
810
if name == "competition_math":
@@ -32,9 +34,7 @@ def load_hard_dataset(name, split, n, seed):
3234
elif name == "gsm8k":
3335
ds = load_dataset("gsm8k", "main")
3436
data = ds[split].shuffle(seed=seed).select(range(n)) if n else ds[split]
35-
gold_fn = lambda ex: _to_number(
36-
re.findall(r"####\s*([-\$]?\s*\d[\d,]*(?:\.\d+)?(?:\s+\w+)?)", ex["answer"])[0]
37-
)
37+
gold_fn = lambda ex: gold_from_gsm8k(ex["answer"])
3838
q_fn = lambda ex: ex["question"]
3939

4040
elif name == "svamp":
@@ -276,26 +276,6 @@ def _norm(s):
276276
return re.sub(r"\s+", " ", str(s)).strip().lower()
277277

278278

279-
def _to_number(s):
280-
if s is None:
281-
return None
282-
s = s.strip()
283-
s = s.replace(",", "")
284-
s = re.sub(r"^\$", "", s)
285-
s = re.sub(r"\s+(dollars?|tickets?|units?|boxes?|people|students?)$", "", s, flags=re.I)
286-
if re.fullmatch(r"-?\d+/\d+", s):
287-
return float(Fraction(s))
288-
if s.endswith("%"):
289-
try:
290-
return float(s[:-1]) / 100.0
291-
except:
292-
return None
293-
try:
294-
return float(s)
295-
except:
296-
return None
297-
298-
299279
def _gpqa_perm(ex):
300280
"""
301281
Deterministically shuffle the 4 answer options so that:

src/experiments/core.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,11 @@ def load_generation_model(
560560
base_model_id: str,
561561
tokenizer: Optional[AutoTokenizer] = None,
562562
torch_dtype: torch.dtype = torch.float16,
563+
*,
564+
device_map: Optional[str] = "auto",
565+
load_in_4bit: bool = False,
566+
attn_impl: str = "sdpa",
567+
compile_model: bool = False,
563568
):
564569
"""
565570
Loads a Causal LM for generation. If `tokenizer` provided, embeddings are resized accordingly.
@@ -568,18 +573,42 @@ def load_generation_model(
568573
tokenizer = AutoTokenizer.from_pretrained(
569574
base_model_id, use_fast=True, trust_remote_code=True
570575
)
571-
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
572-
tokenizer.pad_token = tokenizer.eos_token
576+
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
577+
tokenizer.pad_token = tokenizer.eos_token
578+
579+
kwargs: Dict[str, Any] = {
580+
"torch_dtype": torch_dtype,
581+
"trust_remote_code": True,
582+
}
583+
if device_map is not None:
584+
kwargs["device_map"] = device_map
585+
if attn_impl:
586+
kwargs["attn_implementation"] = attn_impl
587+
if load_in_4bit:
588+
compute_dtype = (
589+
torch_dtype if torch_dtype in (torch.float16, torch.bfloat16) else torch.float16
590+
)
591+
kwargs["quantization_config"] = BitsAndBytesConfig(
592+
load_in_4bit=True,
593+
bnb_4bit_compute_dtype=compute_dtype,
594+
bnb_4bit_use_double_quant=True,
595+
)
596+
kwargs["torch_dtype"] = compute_dtype
573597

574598
model = AutoModelForCausalLM.from_pretrained(
575599
base_model_id,
576-
torch_dtype=torch_dtype,
577-
device_map="auto",
578-
trust_remote_code=True,
600+
**kwargs,
579601
)
580602
if len(tokenizer) != model.get_input_embeddings().weight.size(0):
581603
model.resize_token_embeddings(len(tokenizer))
582604
model.eval()
605+
606+
if compile_model:
607+
try:
608+
model = torch.compile(model, mode="reduce-overhead", fullgraph=False)
609+
except Exception as exc:
610+
print(f"[warn] torch.compile failed: {exc}. Continuing without compile().")
611+
583612
return model, tokenizer
584613

585614

src/run_mcts.py

Lines changed: 5 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,15 @@
88
import warnings
99

1010
import torch
11-
from transformers import AutoTokenizer, AutoModelForCausalLM
11+
from transformers import AutoTokenizer
1212
from transformers.utils import logging as hf_logging
1313
import yaml
1414

1515
from answer_utils import is_correct
1616
from mas import build_mas_from_specs
1717
from mcts import Node, MAS_MCTS
1818
from dataset_handler import load_hard_dataset
19+
from experiments.core import load_generation_model
1920
from show_tree import build_graph, draw_tree
2021

2122
import ray
@@ -30,47 +31,6 @@
3031
)
3132

3233

33-
def load_policy(
34-
model_id: str,
35-
device_map: str,
36-
load_in_4bit: bool = False,
37-
attn_impl: str = "sdpa",
38-
compile_model: bool = True,
39-
):
40-
"""
41-
Loads tokenizer and model with either fp16/fp32 or 4-bit quantization.
42-
device_map is now provided by argparse for flexibility.
43-
"""
44-
kwargs = dict(device_map=device_map)
45-
kwargs["attn_implementation"] = attn_impl
46-
if load_in_4bit:
47-
from transformers import BitsAndBytesConfig
48-
49-
kwargs["quantization_config"] = BitsAndBytesConfig(
50-
load_in_4bit=True,
51-
bnb_4bit_compute_dtype=torch.float16,
52-
bnb_4bit_use_double_quant=True,
53-
)
54-
kwargs["torch_dtype"] = torch.float16
55-
else:
56-
kwargs["torch_dtype"] = torch.float16 if torch.cuda.is_available() else torch.float32
57-
58-
tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
59-
mdl = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
60-
61-
# tok = AutoTokenizer.from_pretrained(model_id, use_fast=True, local_files_only=True)
62-
# mdl = AutoModelForCausalLM.from_pretrained(model_id, local_files_only=True, **kwargs)
63-
64-
mdl.eval()
65-
if compile_model:
66-
try:
67-
mdl = torch.compile(mdl, mode="reduce-overhead", fullgraph=False)
68-
except Exception as e:
69-
print(f"[warn] torch.compile failed: {e}. Continuing without compile().")
70-
71-
return tok, mdl
72-
73-
7434
def node_to_dict(node: Node, max_children: int = 8) -> Dict[str, Any]:
7535
return {
7636
"steps": node.steps,
@@ -187,7 +147,7 @@ def __init__(
187147
self.mdl = None
188148
else:
189149
self.client = None
190-
self.tok, self.mdl = load_policy(
150+
self.mdl, self.tok = load_generation_model(
191151
model_id,
192152
device_map=device_map,
193153
load_in_4bit=load_in_4bit,
@@ -428,9 +388,9 @@ def main():
428388
openai_model=args.model_id,
429389
)
430390
else:
431-
tok, mdl = load_policy(
391+
mdl, tok = load_generation_model(
432392
args.model_id,
433-
args.device_map,
393+
device_map=args.device_map,
434394
load_in_4bit=args.load_in_4bit,
435395
attn_impl=args.attn_impl,
436396
compile_model=not args.no_compile,

0 commit comments

Comments
 (0)