|
| 1 | +import argparse |
| 2 | +import json |
| 3 | +import os |
| 4 | + |
| 5 | +import model_training.models.reward_model # noqa: F401 (registers reward model for AutoModel loading) |
| 6 | +import numpy as np |
| 7 | +import pandas as pd |
| 8 | +import torch |
| 9 | +from eval_datasets import get_sampling_dataloader |
| 10 | +from transformers import AutoModelForSequenceClassification, AutoTokenizer |
| 11 | + |
| 12 | + |
def load_sampling_data(path):
    """
    Load sampling data from a JSON file and validate its structure.

    Args:
        path: Filesystem path to the sampling JSON file.

    Returns:
        The parsed JSON object: a dict with a "prompts" key whose entries
        contain at least "prompt" and "results".

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        KeyError: If "prompts" is missing/empty or required per-prompt keys
            are absent.
    """
    # EAFP open + context manager: the original `json.load(open(path))`
    # leaked the file handle, and the exists() pre-check was race-prone.
    try:
        with open(path) as f:
            data = json.load(f)
    except FileNotFoundError:
        raise FileNotFoundError(f"Sampling data {path} not found") from None

    if "prompts" not in data:
        raise KeyError("sampling data should contain prompts key")
    if not data["prompts"]:
        # Guard before indexing [0] below.
        raise KeyError("sampling data contains an empty prompts list")

    required_keys = {"prompt", "results"}
    missing = required_keys - set(data["prompts"][0].keys())
    if missing:
        raise KeyError(f"Missing keys {missing} ")

    return data
| 33 | + |
| 34 | + |
def batch_inference(model, dataloader):
    """
    Score every batch from ``dataloader`` with the reward model.

    Args:
        model: A sequence-classification model; ``model(**batch).logits``
            column 0 is read as the scalar reward score. Must expose a
            ``device`` attribute.
        dataloader: Yields dicts of tensors; each batch carries a
            "sampling" tensor (label ids) plus the model inputs.

    Returns:
        Tuple of 1-D numpy arrays ``(sampling, scores)`` concatenated
        across all batches.
    """
    scores, sampling = [], []
    device = model.device
    # Inference only: disabling autograd avoids building the graph,
    # saving memory and time over the original grad-enabled loop.
    with torch.no_grad():
        for batch in dataloader:
            sampling.append(batch.pop("sampling").cpu().numpy())
            # squeeze(1) instead of squeeze(): a bare squeeze() drops ALL
            # size-1 dims, so a final batch of size 1 would collapse the
            # batch axis and break the model call. Assumes inputs arrive
            # as (B, 1, L) from the collator — TODO confirm against
            # get_sampling_dataloader.
            inputs = {k: v.squeeze(1).to(device) for k, v in batch.items()}
            scores.append(model(**inputs).logits[:, 0].cpu().numpy())

    return np.hstack(sampling), np.hstack(scores)
| 49 | + |
| 50 | + |
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Score sampled model outputs with a reward model")
    parser.add_argument("--data_path", type=str, help="Path of the sampling data file")
    parser.add_argument("--model", type=str, help="Path or url of the model file")
    parser.add_argument("--max_length", type=int, help="max length of input")
    parser.add_argument("--batch_size", type=int, help="batch size", default=4)
    parser.add_argument("--device", type=str, help="device", default="cpu")
    # `type=bool` is broken with argparse: bool("False") is True, so the
    # original `--save False` could never disable saving. Parse the string
    # explicitly; default stays True.
    parser.add_argument(
        "--save",
        type=lambda s: str(s).lower() not in ("false", "0", "no"),
        help="whether to save the results",
        default=True,
    )
    args = parser.parse_args()

    # Honor --device, but fall back to CPU when CUDA is unavailable.
    if args.device != "cpu" and torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    data = load_sampling_data(args.data_path)

    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model = AutoModelForSequenceClassification.from_pretrained(args.model)
    model.eval()
    model.to(device)

    dataloader = get_sampling_dataloader(data, tokenizer, args.max_length, args.batch_size)
    sampling, scores = batch_inference(model, dataloader)

    # Aggregate: mean reward per sampling configuration plus overall mean.
    df = pd.DataFrame({"sampling": sampling, "score": scores})
    id2label = {v: k for k, v in dataloader.dataset.label2id.items()}
    df["sampling"] = df["sampling"].map(id2label)
    results = df.groupby("sampling")["score"].mean().to_dict()
    results["mean_reward"] = str(df["score"].mean())
    print("RESULTS: ", results)

    results = {"model_name": data["model_name"], "results": results, "reward_model": args.model}
    # Slashes in the HF model name would otherwise create directories.
    name = "-".join(data["model_name"].split("/"))

    if args.save:
        with open(f"{name}.json", "w") as file:
            json.dump(results, file, indent=4)