"""Small script to build a sequence classifier from a causal language model.
Copyright Michael Feil, 2025, MIT License
"""

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification


@torch.no_grad()
def convert_to_sequence_classifier(model_name: str, slice_single_token_list: None | list[int] | set[int] = None) -> torch.nn.Module:
    """Convert a causal language model into a sequence classifier. Built for mixedbread-ai/mxbai-rerank-base-v2.

    How it works:
        The reranker classifies a static prompt (prefill) and reads the next-token
        distribution over the tokens "0" (not relevant) and "1" (relevant):
        https://github.com/mixedbread-ai/mxbai-rerank/blob/ca0c55d03770d9bb183ca759850bf7cdfbcc9f50/mxbai_rerank/mxbai_rerank_v2.py#L34
        A typical prompt is:
        "You are a search relevance expert who evaluates how well documents match search queries. For each query-document pair, carefully analyze the semantic relationship between them, then provide your binary relevance judgment (0 for not relevant, 1 for relevant).\nRelevance:"
        To classify the prompt, we therefore only need the next-token logits for "0" and "1".
        For this tokenizer, their token ids are [15, 16], i.e. tokenizer("0").input_ids and tokenizer("1").input_ids.

    Args:
        model_name (str): model name of the causal language model, loaded via AutoModelForCausalLM.
        slice_single_token_list (None | list[int] | set[int], optional): token ids to slice the
            lm_head down to. Defaults to None, which keeps all vocab_size outputs.

    Returns:
        AutoModelForSequenceClassification: the converted sequence classifier.
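
    Example (sketch; assumes ids 15/16 map to tokens "0"/"1" for this tokenizer):
        >>> model = convert_to_sequence_classifier(
        ...     "mixedbread-ai/mxbai-rerank-base-v2", {15, 16}
        ... )
        >>> model.config.id2label
        {0: '0', 1: '1'}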
    """
    model_lm = AutoModelForCausalLM.from_pretrained(model_name)
    model_lm.model = None  # drop the transformer body to free memory; only the lm_head is needed
    assert model_lm.lm_head.bias is None, "expected a bias-free lm_head"
    model_classifier = AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    def get_input_ids(x):
        """Tokenize a string without special tokens and return its raw input ids."""
        return tokenizer(x, return_tensors=None, add_special_tokens=False)["input_ids"]
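    # sanity-check sketch (assumes the mxbai-rerank-v2 setup from the docstring):
    # each label should map to exactly one token, e.g. get_input_ids("0") == [15]
    # and get_input_ids("1") == [16]; verify this before choosing
    # slice_single_token_list for a different model.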
    # validate the requested token ids against the lm_head vocabulary
    num_total_tokens = model_lm.lm_head.out_features
    if slice_single_token_list is not None:
        slice_single_token_list = sorted(set(slice_single_token_list))
        assert max(slice_single_token_list) < num_total_tokens  # ids are 0-indexed
        assert min(slice_single_token_list) >= 0

    tokens = tokenizer.convert_ids_to_tokens(list(range(num_total_tokens)))
    if slice_single_token_list is not None:
        tokens_selected = [tokens[i] for i in slice_single_token_list]
        # slice the lm_head rows and build a new linear head on the fly;
        # advanced indexing with a list copies the selected rows
        new_score = model_lm.lm_head.weight[slice_single_token_list]
    else:
        tokens_selected = tokens
        new_score = model_lm.lm_head.weight
    num_tokens = len(tokens_selected)
    # swap in a classifier head built from the (sliced) lm_head weights
    model_classifier.config.num_labels = num_tokens
    linear = torch.nn.Linear(model_lm.lm_head.in_features, num_tokens, bias=False)
    linear.weight.data = new_score
    model_classifier.score = linear
    # register label names so classifier outputs map back to token strings
    model_classifier.config.id2label = dict(enumerate(tokens_selected))
    model_classifier.config.label2id = {label: num for num, label in enumerate(tokens_selected)}
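    # e.g. with slice_single_token_list = {15, 16} on the mxbai-rerank-v2 tokenizer
    # (the docstring's assumption), id2label becomes {0: "0", 1: "1"}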
    return model_classifier


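# Minimal sketch (hypothetical helper, not part of the original code): derive the
# token-id slice from label strings instead of hard-coding [15, 16]. Assumes each
# label encodes to exactly one token under the model's tokenizer.
def token_ids_for_labels(model_name: str, labels: list[str]) -> list[int]:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    ids = []
    for label in labels:
        encoded = tokenizer(label, add_special_tokens=False)["input_ids"]
        assert len(encoded) == 1, f"label {label!r} does not map to a single token"
        ids.append(encoded[0])
    return ids


# usage sketch: token_ids_for_labels("mixedbread-ai/mxbai-rerank-base-v2", ["0", "1"])
# should return [15, 16] if the tokenizer behaves as the docstring assumes
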
def test_mxbai_rerank_v2():
    # ids of the single tokens "0" (not relevant) and "1" (relevant),
    # i.e. tokenizer("0").input_ids == [15] and tokenizer("1").input_ids == [16]
    no_id_yes_id = [15, 16]

    with torch.no_grad():
        model_cls = convert_to_sequence_classifier("mixedbread-ai/mxbai-rerank-base-v2", no_id_yes_id)
        model_lm = AutoModelForCausalLM.from_pretrained("mixedbread-ai/mxbai-rerank-base-v2")
        tokenizer = AutoTokenizer.from_pretrained("mixedbread-ai/mxbai-rerank-base-v2")

        example = {
            "instruction": "You are a search relevance expert who evaluates how well documents match search queries. For each query-document pair, carefully analyze the semantic relationship between them, then provide your binary relevance judgment (0 for not relevant, 1 for relevant).",
            "query": ["What is the capital of France?"],
            "document": ["The capital of France is Paris.", "Who is the president of France?"],
        }
        examples_formatted = [
            f"{example['instruction']}\n{example['query'][0]}\n{example['document'][0]}",
            f"{example['instruction']}\n{example['query'][0]}\n{example['document'][1]}",
        ]
        # tokenize each example and check that the classifier head reproduces
        # the causal LM's next-token logits for the selected tokens
        for example_formatted in examples_formatted:
            tokenized = tokenizer(
                example_formatted,
                return_tensors="pt",
                truncation=True,
            )
            # forward pass through both models
            output = model_cls(**tokenized).logits  # shape (1, 2)
            output_lm = model_lm(**tokenized).logits[0, -1, no_id_yes_id]  # shape (2,)
            print(output_lm)
            print(output)
            assert torch.allclose(output_lm, output[0])
    print("done")

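# Follow-up sketch (assumption, not in the original code): once the equivalence
# check passes, the converted model could be persisted with the standard
# transformers API and reloaded directly as a sequence classifier, e.g.
#   model_cls = convert_to_sequence_classifier("mixedbread-ai/mxbai-rerank-base-v2", {15, 16})
#   model_cls.save_pretrained("./mxbai-rerank-base-v2-seq-cls")  # hypothetical output path
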
if __name__ == "__main__":
    test_mxbai_rerank_v2()