Skip to content

Commit 952fb88

Browse files
committed
add convert lm script
1 parent 154160c commit 952fb88

File tree

1 file changed

+98
-0
lines changed

1 file changed

+98
-0
lines changed
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
"""small code to generate a sequence classifier from a causal language model
2+
Copyright Michael Feil, 2025, MIT License
3+
"""
4+
5+
import torch
6+
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification
7+
8+
@torch.no_grad()
9+
def convert_to_sequence_classifier(model_name: str, slice_single_token_list: None | set[int] = None) -> torch.nn.Module:
10+
"""Convert a causal language model to a sequence classifier. build for mixedbread-ai/mxbai-rerank-base-v2
11+
12+
Example usage:
13+
The model classifies a static prompt (prefill) and uses the next token distribution of no and yes to classify the prompt.
14+
https://github.com/mixedbread-ai/mxbai-rerank/blob/ca0c55d03770d9bb183ca759850bf7cdfbcc9f50/mxbai_rerank/mxbai_rerank_v2.py#L34
15+
a good example is:
16+
"You are a search relevance expert who evaluates how well documents match search queries. For each query-document pair, carefully analyze the semantic relationship between them, then provide your binary relevance judgment (0 for not relevant, 1 for relevant).\nRelevance:"
17+
thus, we need to get the next token distribution of no and yes to classify the prompt.
18+
the token ids of no and yes are [15, 16] == [tokenizer("0").input_ids, tokenizer("1").input_ids]
19+
Args:
20+
model_name (str): model name for the causal language model / AutoModelForCausalLM
21+
slice_single_token_list (None | set[int], optional): slice the lm_head to a subset of tokens. Defaults to None, which will give vocab_size outputs.
22+
23+
Returns:
24+
AutoModelForSequenceClassification: model classifier
25+
"""
26+
model_lm = AutoModelForCausalLM.from_pretrained(model_name)
27+
model_lm.model = None # free up memory
28+
assert model_lm.lm_head.bias is None
29+
model_classifier = AutoModelForSequenceClassification.from_pretrained(model_name)
30+
tokenizer = AutoTokenizer.from_pretrained(model_name)
31+
def get_input_ids(x):
32+
return tokenizer(x, return_tensors=None, add_special_tokens=False)["input_ids"]
33+
# tokenizer select tokens
34+
num_total_tokens = model_lm.lm_head.out_features
35+
if slice_single_token_list is not None:
36+
slice_single_token_list = list(sorted(set(slice_single_token_list)))
37+
assert max(slice_single_token_list) <= num_total_tokens
38+
assert min(slice_single_token_list) >= 0
39+
40+
tokens = tokenizer.convert_ids_to_tokens(range(num_total_tokens))
41+
if slice_single_token_list is not None:
42+
tokens_selected = [tokens[i] for i in slice_single_token_list]
43+
# slice the score head and build a linear on the fly
44+
new_score = model_lm.lm_head.weight[slice_single_token_list]
45+
else:
46+
tokens_selected = tokens
47+
new_score = model_lm.lm_head.weight
48+
num_tokens = len(tokens_selected)
49+
# add classifier head from lm head
50+
model_classifier.config.num_labels = num_tokens
51+
linear = torch.nn.Linear(model_lm.lm_head.in_features, num_tokens, bias=False)
52+
linear.weight.data = new_score
53+
model_classifier.score = linear
54+
# add id2label and label2id
55+
model_classifier.config.id2label = {
56+
num: label for num, label in enumerate(tokens_selected)
57+
}
58+
model_classifier.config.label2id = {
59+
label: num for num, label in enumerate(tokens_selected)
60+
}
61+
return model_classifier
62+
63+
64+
def test_mxbai_rerank_v2():
65+
no_id_yes_id = [15, 16] # no, yes, == [tokenizer("0").input_ids, tokenizer("1").input_ids]
66+
67+
with torch.no_grad():
68+
model_cls = convert_to_sequence_classifier("mixedbread-ai/mxbai-rerank-base-v2", no_id_yes_id)
69+
model_lm = AutoModelForCausalLM.from_pretrained("mixedbread-ai/mxbai-rerank-base-v2")
70+
71+
tokenizer = AutoTokenizer.from_pretrained("mixedbread-ai/mxbai-rerank-base-v2")
72+
73+
example = {
74+
"instruction": "You are a search relevance expert who evaluates how well documents match search queries. For each query-document pair, carefully analyze the semantic relationship between them, then provide your binary relevance judgment (0 for not relevant, 1 for relevant).",
75+
"query": ["What is the capital of France?"],
76+
"document": ["The capital of France is Paris.", "Who is the president of France?"],
77+
}
78+
examples_formatted = [
79+
f"{example['instruction']}\n{example['query'][0]}\n{example['document'][0]}",
80+
f"{example['instruction']}\n{example['query'][0]}\n{example['document'][1]}",
81+
]
82+
# create pytorch tensors
83+
for example_formatted in examples_formatted:
84+
tokenized = tokenizer(
85+
example_formatted,
86+
return_tensors="pt",
87+
truncation=True,
88+
)
89+
# forward pass
90+
output = model_cls(**tokenized).logits
91+
output_lm = model_lm(**tokenized).logits[0,-1,no_id_yes_id]
92+
print(output_lm)
93+
print(output)
94+
assert torch.allclose(output_lm, output)
95+
print("done")
96+
97+
if __name__ == "__main__":
98+
test_mxbai_rerank_v2()

0 commit comments

Comments
 (0)