Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions save_awq.py
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should not be in the main repo folder, probably src/rank_llm/scripts

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ack

Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""Converts and store AWQ-quantized model."""

import argparse
import json
import logging

import awq
import transformers

QUANT_CONFIG = {
"zero_point": True,
"q_group_size": 128,
"w_bit": 4,
"version": "GEMM",
}


def parse_args():
"""Parses command line arguments."""
parser = argparse.ArgumentParser()
parser.add_argument(
"--dataset",
type=str,
default="msp_open_ai_ada2_random_s5000_gpt4_da0_mr20_sampled_mix.jsonl",
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it possible to have rank zephy's training data or a subset of it as the default value of the calibration dataset?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would probably requires changes to the load dataset logic too

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the file that @ronakice shared. Wasn't this one used for training?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, this is not the data that we used for finetuning rankzephyr, but I leave it to Ronak to decide if we want to you the training dataset or the one that we shared with you.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ronakice PTAL!

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

/u3/rpradeep/RankVicuna/data/msp_open_ai_ada2_random_s5000_gpt4_da0_mr20_sampled_mix.jsonl

This is the file I used to train RankZephyr @sahel-sh?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

either way, something is off in AWQ quantizing, i will advice against merging until this is properly sorted

help="Path to the calibration dataset.",
)
parser.add_argument(
"--model_path",
type=str,
default="castorini/rank_zephyr_7b_v1_full",
help="Path/slug to the original model.",
)
parser.add_argument(
"--quant_path",
type=str,
default="awq_rank_zephyr_7b_v1_full",
help="Path/slug where the quantized model is to be stored.",
)
args = parser.parse_args()
return args


def load_dataset(dataset: str):
"""Returns list of prompts for given dataset."""
with open(dataset, "r") as file:
data = json.load(file)
prompts = []
for content in data:
content = content["conversations"]
prompt = ""
for prompt_dict in content:
if prompt_dict["from"] == "system":
prompt += prompt_dict["value"] + "\n"
for prompt_dict in content:
if prompt_dict["from"] == "human":
prompt += prompt_dict["value"] + "\n"
for prompt_dict in content:
if prompt_dict["from"] == "gpt":
prompt += prompt_dict["value"]
prompts.append(prompt)
return prompts


def main():
"""Entry point of the script."""
args = parse_args()
model_path = args.model_path
quant_path = args.quant_path
dataset = args.dataset

# Load model
logging.info(f"Loading model from {model_path}.")
model = awq.AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_path, trust_remote_code=True
)

logging.info(f"Starting AWQ with data {dataset}.")
model.quantize(
tokenizer=tokenizer,
quant_config=QUANT_CONFIG,
calib_data=load_dataset(dataset=dataset),
)

# Convert config into appropriate format.
quantization_config = transformers.AwqConfig(
bits=QUANT_CONFIG["w_bit"],
group_size=QUANT_CONFIG["q_group_size"],
zero_point=QUANT_CONFIG["zero_point"],
version=QUANT_CONFIG["version"].lower(),
).to_dict()
model.model.config.quantization_config = quantization_config

logging.info(f"Saving quantized model at {quant_path}.")
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)


if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
main()
16 changes: 15 additions & 1 deletion src/rank_llm/rerank/rank_listwise_os_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
from typing import Optional, Tuple

import torch
from awq import AutoAWQForCausalLM
from fastchat.model import get_conversation_template, load_model
from ftfy import fix_text
from transformers import AutoTokenizer
from transformers.generation import GenerationConfig

from rank_llm.rerank.rankllm import PromptMode, RankLLM
Expand Down Expand Up @@ -62,7 +64,19 @@ def __init__(
f"Unsupported prompt mode: {prompt_mode}. The only prompt mode currently supported is a slight variation of Rank_GPT prompt."
)
# ToDo: Make repetition_penalty configurable
self._llm, self._tokenizer = load_model(model, device=device, num_gpus=num_gpus)
if "awq" in model:
self._llm = AutoAWQForCausalLM.from_quantized(
model,
fuse_layers=True,
max_seq_len=context_size,
).to(0)
self._tokenizer = AutoTokenizer.from_pretrained(model)
else:
self._llm, self._tokenizer = load_model(
model,
device=device,
num_gpus=num_gpus,
)
self._variable_passages = variable_passages
self._window_size = window_size
self._system_message = system_message
Expand Down