Skip to content

Commit 5847dc2

Browse files
authored
Evaluate sampling report using RM (#2190)
Framework to evaluate sampling report results using any Reward model. Proposed in #1908
1 parent d631e93 commit 5847dc2

4 files changed

Lines changed: 179 additions & 0 deletions

File tree

model/model_eval/README.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
## Evaluate Sampling Reports using Reward Model
2+
3+
### Requirements
4+
5+
- `cd model/`
6+
- `pip install -e . `
7+
- `cd oasst-data`
8+
- `pip install -e .`
9+
10+
### Run
11+
12+
```
13+
python model/model_eval/sampling_score.py --model andreaskoepf/oasst-rm-1-pythia-1b --data_path model/model_eval/manual/sampling_reports/2023-03-01_theblackcat102_pythia-12b-deduped-sft_sampling.json
14+
```
15+
16+
### Example results
17+
18+
```
19+
{'beam5': -1.592665433883667, 'greedy': -1.592665433883667, 'k50': -1.592665433883667, 'magic_numbers': -1.592665433883667, 'mean_reward': '-1.5926653'}
20+
```

model/model_eval/__init__.py

Whitespace-only changes.

model/model_eval/eval_datasets.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import torch
2+
from model_training.custom_datasets.ranking_collator import RankingDataCollator
3+
from torch.utils.data import DataLoader, Dataset
4+
5+
6+
def get_sampling_dataloader(data, tokenizer, max_length, batch_size):
    """Build a DataLoader over a sampling report.

    Wraps the parsed report in a SamplingDataset and batches it with a
    SamplingDataCollator, which tokenizes and pads each (prefix, reply) pair.
    """
    sampling_dataset = SamplingDataset(data)
    sampling_collator = SamplingDataCollator(tokenizer, max_length=max_length)
    return DataLoader(sampling_dataset, collate_fn=sampling_collator, batch_size=batch_size)
10+
11+
12+
class SamplingDataCollator(RankingDataCollator):
    """Collator that tokenizes (prefix, reply) pairs and carries each
    example's sampling-config label id through into the batch under the
    "sampling" key.
    """

    def __call__(self, examples):
        # Tokenize every example and remember which sampling config
        # (integer id produced by SamplingDataset) it came from.
        tokenized_inputs, config_ids = [], []
        for prefix, reply, config_id in examples:
            config_ids.append(config_id)
            tokenized_inputs.extend(self.process_one((prefix, reply)))

        batch = self.tokenizer.pad(
            flat_tokenized := tokenized_inputs,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        # The reward model does not consume token_type_ids; drop them
        # if this tokenizer happened to produce any.
        batch.pop("token_type_ids", None)

        batch["sampling"] = torch.tensor(config_ids)
        return batch
35+
36+
37+
class SamplingDataset(Dataset):
    """Flattens a sampling report into (prefix, reply, sampling-id) examples.

    Args:
        dataset: parsed sampling report; ``dataset["prompts"]`` is a list of
            entries, each with a ``prompt`` string and a list of ``results``,
            where every result holds a ``sampling_config`` and ``outputs``.
        max_prompts: number of leading prompt entries to evaluate; ``None``
            uses the whole report. Defaults to 4 to preserve the previous
            hard-coded truncation (NOTE(review): that looked like a debug
            leftover — consider passing None to score the full report).
    """

    def __init__(self, dataset, max_prompts=4):
        super().__init__()

        self.dataset = []
        sampling_list = []
        prompts = dataset["prompts"] if max_prompts is None else dataset["prompts"][:max_prompts]
        for data in prompts:
            prompt = data["prompt"]
            for result in data["results"]:
                sampling = result["sampling_config"]
                for output in result["outputs"]:
                    self.dataset.append((prompt, output, sampling))
                # Record each sampling config once, in order of first appearance,
                # so the label ids are stable across runs of the same report.
                if sampling not in sampling_list:
                    sampling_list.append(sampling)

        self.label2id = self.get_label2id(sampling_list)

    def get_label2id(self, sampling_list):
        """Return a {sampling_config: integer_id} mapping."""
        return {name: idx for idx, name in enumerate(sampling_list)}

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        prefix, reply, sampling = self.dataset[idx]
        # Wrap in single-element lists: the collator's process_one expects
        # list-of-messages style inputs for prefix and reply.
        return ([prefix], [reply], self.label2id[sampling])

model/model_eval/sampling_score.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import argparse
2+
import json
3+
import os
4+
5+
import model_training.models.reward_model # noqa: F401 (registers reward model for AutoModel loading)
6+
import numpy as np
7+
import pandas as pd
8+
import torch
9+
from eval_datasets import get_sampling_dataloader
10+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
11+
12+
13+
def load_sampling_data(path):
    """Load a sampling report from *path* and validate its structure.

    Returns:
        The parsed JSON report as a dict.

    Raises:
        FileNotFoundError: if *path* does not exist.
        KeyError: if the report lacks "prompts", or the first prompt entry
            lacks the "prompt"/"results" keys.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Sampling data {path} not found")

    # Context manager so the file handle is closed promptly
    # (the original json.load(open(path)) leaked it).
    with open(path) as f:
        data = json.load(f)

    if "prompts" not in data:
        raise KeyError("sampling data should contain prompts key")

    # Only the first entry is checked; assumes the report is homogeneous.
    required_keys = {"prompt", "results"}
    missing = required_keys - set(data["prompts"][0].keys())
    if missing:
        raise KeyError(f"Missing keys {missing} ")

    return data
33+
34+
35+
def batch_inference(model, dataloader):
    """Score every batch of a sampling dataloader with the reward model.

    Args:
        model: sequence-classification reward model; column 0 of its
            ``.logits`` is taken as the scalar reward per example.
        dataloader: yields dicts of tensors that include a "sampling" key
            holding each example's sampling-config label id.

    Returns:
        (sampling_ids, scores): two aligned 1-D numpy arrays.
    """
    scores, sampling = [], []
    device = model.device
    # Inference only: disable autograd to save memory and time.
    with torch.no_grad():
        for data in dataloader:
            sampling.append(data.pop("sampling").cpu().numpy())
            # NOTE(review): squeeze() drops ALL size-1 dims — with batch_size=1
            # this would collapse the batch axis too; confirm against the
            # collator's output shape.
            data = {k: v.squeeze().to(device) for k, v in data.items()}
            pred = model(**data).logits[:, 0].cpu().numpy()
            scores.append(pred)

    return np.hstack(sampling), np.hstack(scores)
49+
50+
51+
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate a sampling report with a reward model")
    parser.add_argument("--data_path", type=str, help="Path of the sampling data file")
    parser.add_argument("--model", type=str, help="Path or url of the model file")
    parser.add_argument("--max_length", type=int, help="max length of input")
    parser.add_argument("--batch_size", type=int, help="batch size", default=4)
    parser.add_argument("--device", type=str, help="device", default="cpu")
    # BUG FIX: type=bool treats ANY non-empty string as True, so
    # "--save False" used to enable saving; parse the string explicitly.
    parser.add_argument(
        "--save",
        type=lambda s: s.strip().lower() not in {"false", "0", "no"},
        help="whether to save the results",
        default=True,
    )

    args = parser.parse_args()

    # Fall back to CPU when CUDA was requested but is unavailable.
    if args.device != "cpu" and torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    data = load_sampling_data(args.data_path)

    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model = AutoModelForSequenceClassification.from_pretrained(args.model)
    model.eval()
    model.to(device)

    dataloader = get_sampling_dataloader(data, tokenizer, args.max_length, args.batch_size)
    sampling, scores = batch_inference(model, dataloader)

    # Mean reward per sampling configuration, plus the overall mean.
    df = pd.DataFrame({"sampling": sampling, "score": scores})
    id2label = {v: k for k, v in dataloader.dataset.label2id.items()}
    df["sampling"] = df["sampling"].map(id2label)
    results = df.groupby("sampling")["score"].mean().to_dict()
    results["mean_reward"] = str(df["score"].mean())
    print("RESULTS: ", results)

    # NOTE(review): assumes the report contains a "model_name" key —
    # load_sampling_data does not validate it; confirm for all report files.
    results = {"model_name": data["model_name"], "results": results, "reward_model": args.model}
    name = "-".join(data["model_name"].split("/"))

    if args.save:
        with open(f"{name}.json", "w") as file:
            json.dump(results, file, indent=4)

0 commit comments

Comments
 (0)