Skip to content

Commit 0b86740

Browse files
committed
fix a bug in mc evaluation
1 parent b6fa1ff commit 0b86740

File tree

4 files changed

+69
-35
lines changed

4 files changed

+69
-35
lines changed

README.md

+2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
<!-- # Dromedary -->
1010

11+
### NeurIPS 2023 (Spotlight)
12+
1113
## Principle-Driven Self-Alignment of Language Models from Scratch with Minimal Human Supervision
1214

1315
</div>

llama_dromedary/llama_dromedary/generation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ def score(
449449
params.max_seq_len,
450450
params.max_shared_seq_len,
451451
)
452-
tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cuda().long()
452+
tokens = torch.full((bsz, total_len), self.tokenizer.eos_id).cuda().long()
453453

454454
for i, (prompt_t, target_t) in enumerate(zip(prompt_tokens, target_tokens)):
455455
tokens[i, : len(prompt_t)] = torch.tensor(prompt_t).long()

mc_evaluation/evaluate_hhh_eval.py

+32-16
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import numpy as np
1010
import tqdm
1111

12-
from llama_dromedary.utils import setup_model_parallel, sync_model_parallel, load_model, llama_scoring
12+
from llama_dromedary import Llama
1313

1414

1515
def measure_multiple_choice_grade(samples):
@@ -24,7 +24,9 @@ def measure_multiple_choice_grade(samples):
2424
def argmax(array):
2525
"""argmax with deterministic pseudorandom tie breaking."""
2626
max_indices = np.arange(len(array))[array == np.max(array)]
27-
idx = int(hashlib.sha256(np.asarray(array).tobytes()).hexdigest(),16) % len(max_indices)
27+
idx = int(hashlib.sha256(np.asarray(array).tobytes()).hexdigest(), 16) % len(
28+
max_indices
29+
)
2830
return max_indices[idx]
2931

3032
for sample in samples:
@@ -64,21 +66,19 @@ def main(
6466
meta_prompt = "".join(data)
6567
meta_prompt = meta_prompt.strip()
6668

67-
global_rank, world_size = setup_model_parallel()
68-
if global_rank > 0:
69-
sys.stdout = open(os.devnull, "w")
70-
7169
t0 = time.time()
72-
generator = load_model(
73-
ckpt_dir, tokenizer_path, global_rank, world_size,
74-
max_seq_len, max_batch_size, max_shared_seq_len,
75-
disable_cache=True,
70+
generator = Llama.build(
71+
ckpt_dir=ckpt_dir,
72+
tokenizer_path=tokenizer_path,
73+
max_seq_len=max_seq_len,
74+
max_batch_size=max_batch_size,
75+
max_shared_seq_len=max_shared_seq_len,
7676
)
7777
t1 = time.time()
78-
loading_time = t1-t0
78+
loading_time = t1 - t0
7979
print("Model loading time on %d: " % group_size, loading_time)
8080

81-
sync_model_parallel()
81+
global_rank = int(os.environ.get("RANK", 0))
8282
tasks = ["harmless", "helpful", "honest", "other"]
8383

8484
all_predictions = []
@@ -93,7 +93,15 @@ def main(
9393
# only show tqdm at rank 0
9494
for example in tqdm.tqdm(examples, disable=global_rank > 0):
9595
targets = list(example["target_scores"].keys())
96-
log_prob = get_log_prob(generator, example, targets, meta_prompt, generate_prompt_fn, temperature, max_seq_len)
96+
log_prob = get_log_prob(
97+
generator,
98+
example,
99+
targets,
100+
meta_prompt,
101+
generate_prompt_fn,
102+
temperature,
103+
max_seq_len,
104+
)
97105
full_pred = {}
98106
full_pred["choice"] = targets
99107
full_pred["log_prob"] = log_prob
@@ -108,7 +116,15 @@ def main(
108116
print(f"Overall HHH Eval MC grade over {len(all_predictions)} examples: {mc_grad}")
109117

110118

111-
def get_log_prob(generator, example, targets, meta_prompt, generate_prompt_fn, temperature, max_seq_len):
119+
def get_log_prob(
120+
generator,
121+
example,
122+
targets,
123+
meta_prompt,
124+
generate_prompt_fn,
125+
temperature,
126+
max_seq_len,
127+
):
112128
answer_candidates = targets
113129

114130
def truncate_seq(seq, prefix="", suffix=""):
@@ -121,7 +137,7 @@ def truncate_seq(seq, prefix="", suffix=""):
121137
tokenized_inputs = tokenized_inputs[-safe_seq_len:]
122138
seq = generator.tokenizer.decode(tokenized_inputs).strip()
123139
if flag:
124-
seq= prefix + seq + suffix
140+
seq = prefix + seq + suffix
125141
return seq
126142

127143
inputs = truncate_seq(example["input"], prefix="... ")
@@ -149,7 +165,7 @@ def truncate_seq(seq, prefix="", suffix=""):
149165
all_prompts = [prompt_1, prompt_1, prompt_2, prompt_2]
150166
all_targets = [" A", " B", " A", " B"]
151167

152-
log_prob = llama_scoring(generator, all_prompts, all_targets, temperature)
168+
log_prob = generator.score(generator, all_prompts, all_targets, temperature)
153169

154170
aggregate_log_prob = [log_prob[0] + log_prob[3], log_prob[1] + log_prob[2]]
155171
return aggregate_log_prob

mc_evaluation/evaluate_truthfulqa_mc.py

+34-18
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import time
88
import tqdm
99

10-
from llama_dromedary.utils import setup_model_parallel, sync_model_parallel, load_model, llama_scoring
10+
from llama_dromedary import Llama
1111

1212
from datasets import load_dataset
1313

@@ -24,7 +24,9 @@ def measure_multiple_choice_grade(samples):
2424
def argmax(array):
2525
"""argmax with deterministic pseudorandom tie breaking."""
2626
max_indices = np.arange(len(array))[array == np.max(array)]
27-
idx = int(hashlib.sha256(np.asarray(array).tobytes()).hexdigest(),16) % len(max_indices)
27+
idx = int(hashlib.sha256(np.asarray(array).tobytes()).hexdigest(), 16) % len(
28+
max_indices
29+
)
2830
return max_indices[idx]
2931

3032
for sample in samples:
@@ -61,18 +63,16 @@ def main(
6163
meta_prompt = "".join(data)
6264
meta_prompt = meta_prompt.strip()
6365

64-
global_rank, world_size = setup_model_parallel()
65-
if global_rank > 0:
66-
sys.stdout = open(os.devnull, "w")
67-
6866
t0 = time.time()
69-
generator = load_model(
70-
ckpt_dir, tokenizer_path, global_rank, world_size,
71-
max_seq_len, max_batch_size, max_shared_seq_len,
72-
disable_cache=True,
67+
generator = Llama.build(
68+
ckpt_dir=ckpt_dir,
69+
tokenizer_path=tokenizer_path,
70+
max_seq_len=max_seq_len,
71+
max_batch_size=max_batch_size,
72+
max_shared_seq_len=max_shared_seq_len,
7373
)
7474
t1 = time.time()
75-
loading_time = t1-t0
75+
loading_time = t1 - t0
7676
print("Model loading time on %d: " % group_size, loading_time)
7777

7878
dataset = load_dataset("truthful_qa", "multiple_choice", split="validation")
@@ -83,20 +83,29 @@ def main(
8383
example = {}
8484
example["input"] = data_point["question"]
8585
example["target_scores"] = {}
86-
mc1_choices = data_point["mc1_targets"]['choices']
87-
mc1_scores = data_point["mc1_targets"]['labels']
86+
mc1_choices = data_point["mc1_targets"]["choices"]
87+
mc1_scores = data_point["mc1_targets"]["labels"]
8888

8989
for choice, score in zip(mc1_choices, mc1_scores):
9090
example["target_scores"][choice] = score
9191
examples.append(example)
9292

9393
predictions = []
9494

95-
sync_model_parallel()
95+
global_rank = int(os.environ.get("RANK", 0))
96+
9697
# only show tqdm at rank 0
9798
for example in tqdm.tqdm(examples, disable=global_rank > 0):
9899
targets = list(example["target_scores"].keys())
99-
log_prob = get_log_prob(generator, example, targets, meta_prompt, generate_prompt_fn, temperature, max_seq_len)
100+
log_prob = get_log_prob(
101+
generator,
102+
example,
103+
targets,
104+
meta_prompt,
105+
generate_prompt_fn,
106+
temperature,
107+
max_seq_len,
108+
)
100109
full_pred = {}
101110
full_pred["choice"] = targets
102111
full_pred["log_prob"] = log_prob
@@ -107,12 +116,19 @@ def main(
107116
print(f"MC1 grade: {mc_grad}")
108117

109118

110-
def get_log_prob(generator, example, targets, meta_prompt, generate_prompt_fn, temperature, max_seq_len):
119+
def get_log_prob(
120+
generator,
121+
example,
122+
targets,
123+
meta_prompt,
124+
generate_prompt_fn,
125+
temperature,
126+
max_seq_len,
127+
):
111128
del max_seq_len
112129
answer_candidates = targets
113130
inputs = example["input"]
114131

115-
116132
input_story = f"""Question: {inputs}
117133
118134
Answer: {'{}'} (true or false)
@@ -139,7 +155,7 @@ def get_log_prob(generator, example, targets, meta_prompt, generate_prompt_fn, t
139155
all_prompts.append(prompt)
140156
all_targets.append(" false")
141157

142-
log_prob = llama_scoring(generator, all_prompts, all_targets, temperature)
158+
log_prob = generator.score(all_prompts, all_targets, temperature)
143159
true_log_prob = []
144160
for i in range(0, len(answer_candidates), 2):
145161
true_log_prob.append(log_prob[i] - log_prob[i + 1])

0 commit comments

Comments
 (0)