import random
import time
from typing import Any, Optional

import numpy as np
import torch
from loguru import logger
from transformers import AutoModelForCausalLM, AutoTokenizer, AwqConfig, pipeline

from prompting.llms.base_llm import BaseLLM
from prompting.utils.cleaners import CleanerPipeline
from prompting.utils.timer import Timer
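

# NOTE (illustrative, inferred from how HF_LLM._forward uses `llm`): the injected
# `llm` object is expected to expose a pipeline-style interface, i.e.
# llm.generate(prompt, max_length=..., temperature=..., top_p=...) returning a list
# of {"generated_text": ...} dicts. A hypothetical wiring, assuming `wrapped_pipeline`
# is such an object:
#
#   llm = HF_LLM(wrapped_pipeline, system_prompt="You are a helpful assistant.")
#   answer = llm.query("What is the capital of France?")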
class HF_LLM(BaseLLM):
    def __init__(
        self,
        llm: Any,
        system_prompt: Optional[str],
        max_new_tokens: int = 256,
        temperature: float = 0.7,
        top_p: float = 0.95,
    ):
        model_kwargs = {
            "temperature": temperature,
            "top_p": top_p,
            "max_tokens": max_new_tokens,
        }
        super().__init__(llm, system_prompt, model_kwargs)

        # Keep track of generation data using messages and times
        self.system_prompt = system_prompt
        self.messages = [{"content": self.system_prompt, "role": "system"}] if self.system_prompt else []
        self.times: list[float] = [0]
        # Llama-3-style chat markup consumed by _make_prompt below
        self._role_template = {
            "system": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{{{{ {} }}}}<|eot_id|>",
            "user": "<|start_header_id|>user<|end_header_id|>\n{{{{ {} }}}}<|eot_id|>",
            "assistant": "<|start_header_id|>assistant<|end_header_id|>\n{{{{ {} }}}}<|eot_id|>",
            "end": "<|start_header_id|>assistant<|end_header_id|>",
        }

    def query_conversation(
        self,
        messages: list[str],
        roles: list[str],
        cleaner: Optional[CleanerPipeline] = None,
    ):
        """Query the LLM with the given lists of conversation history and roles.

        Args:
            messages (list[str]): List of messages in the conversation.
            roles (list[str]): List of roles for each message.
            cleaner (Optional[CleanerPipeline], optional): Cleaner pipeline to use, if any.
        """
        assert len(messages) == len(roles), "Length of messages and roles must be the same"
        inputs: list[dict[str, Any]] = [{"content": self.system_prompt, "role": "system"}]
        for role, message in zip(roles, messages):
            inputs.append({"content": message, "role": role})

        t0 = time.perf_counter()
        response = self.forward(messages=inputs)
        response = self.clean_response(cleaner, response)
        self.times.extend((0, time.perf_counter() - t0))
        return response

    def query(
        self,
        message: str,
        role: str = "user",
        cleaner: CleanerPipeline = CleanerPipeline(),
    ):
        # Adds the message to the list of messages for tracking purposes, even though it's not used downstream
        messages = self.messages + [{"content": message, "role": role}]

        t0 = time.time()
        response = self._forward(messages=messages)
        response = self.clean_response(cleaner, response)

        self.messages = messages
        self.messages.append({"content": response, "role": "assistant"})
        self.times.extend((0, time.time() - t0))

        return response

    def _make_prompt(self, messages: list[dict[str, str]]) -> str:
        composed_prompt: list[str] = []

        for message in messages:
            role = message["role"]
            if role not in self._role_template:
                continue
            content = message["content"]
            composed_prompt.append(self._role_template[role].format(content))

        # Adds final tag indicating the assistant's turn
        composed_prompt.append(self._role_template["end"])
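        # Illustrative rendering (assumed example, not exercised in this file): for
        #   messages = [{"role": "system", "content": "Be brief."},
        #               {"role": "user", "content": "Hi"}]
        # the joined prompt is (line breaks added here only for readability; the
        # doubled braces in the templates render as literal "{{ ... }}"):
        #   <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{{ Be brief. }}<|eot_id|>
        #   <|start_header_id|>user<|end_header_id|>\n{{ Hi }}<|eot_id|>
        #   <|start_header_id|>assistant<|end_header_id|>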
        return "".join(composed_prompt)

    def _forward(self, messages: list[dict[str, str]]):
        # make composed prompt from messages
        composed_prompt = self._make_prompt(messages)
        response = self.llm.generate(
            composed_prompt,
            max_length=self.model_kwargs["max_tokens"],
            temperature=self.model_kwargs["temperature"],
            top_p=self.model_kwargs["top_p"],
        )[0]

        try:
            logger.info(
                f"{self.__class__.__name__} generated the following output:\n{response['generated_text'].strip()}"
            )
        except Exception as e:
            logger.info(f"Response: {response}")
            logger.error(f"Error logging the response: {e}")

        return response["generated_text"].strip()


def set_random_seeds(seed=42):
    """
    Set random seeds for reproducibility across all relevant libraries
    """
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


class ReproducibleHF:
    def __init__(self, model_id="Qwen/Qwen2-0.5B", tensor_parallel_size=0, seed=42, **kwargs):
        """
        Initialize Hugging Face model with reproducible settings and optimizations
        """
        self.set_random_seeds(seed)

        # Load model and tokenizer with optimizations
        model_kwargs = {
            "device_map": "auto",
        }

        # Get the valid generation parameters from the model's generation config
        # (note: this loads the full model an extra time solely to read that config).
        self.valid_generation_params = set(
            AutoModelForCausalLM.from_pretrained(model_id).generation_config.to_dict().keys()
        )

        for k, v in kwargs.items():
            if k not in ["sampling_params"]:  # exclude sampling_params and any other generation-only args
                model_kwargs[k] = v

        # NOTE: AwqConfig is for loading checkpoints that are already AWQ-quantized;
        # with an unquantized checkpoint (e.g. the default Qwen/Qwen2-0.5B) this
        # config would need to be dropped or the model swapped for an AWQ variant.
        quantization_config = AwqConfig(
            bits=4,
            fuse_max_seq_len=512,
            do_fuse=True,
        )

        # NOTE: the `model_kwargs` assembled above are not forwarded here, so any
        # extra kwargs passed to __init__ are currently unused.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            device_map="auto",
            quantization_config=quantization_config,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(model_id)

        # self.model.generation_config.cache_implementation = "static"
        # self.model.forward = torch.compile(self.model.forward, mode="reduce-overhead", fullgraph=True)
        # self.valid_generation_params = set(self.model.generation_config.to_dict().keys())

        # Enable model optimizations
        self.model.eval()

        if tensor_parallel_size > 1:
            # NOTE: DataParallel does not expose .generate(), so this path would need
            # self.model.module.generate(); it is unused with the defaults above.
            self.model = torch.nn.DataParallel(self.model, device_ids=list(range(tensor_parallel_size)))

        # Create pipeline with optimized settings
        self.llm = pipeline("text-generation", model=self.model, tokenizer=self.tokenizer)

        # Default sampling parameters
        self.sampling_params = {
            "temperature": 0.7,
            "top_p": 0.95,
            "top_k": 50,
            "max_new_tokens": 256,
            "presence_penalty": 0,
            "frequency_penalty": 0,
            "seed": seed,
            "do_sample": True,
            "early_stopping": True,  # Only takes effect with beam search (num_beams > 1)
            "num_beams": 1,  # No beam search; token selection is governed by do_sample above
        }

    @torch.inference_mode()
    def generate(self, prompts, sampling_params=None):
        """
        Generate text with optimized performance
        """

        # Convert single prompt to list
        if isinstance(prompts, str):
            prompts = [prompts]

        # NOTE: batching multiple prompts of different lengths would also require padding=True and a pad token
        inputs = self.tokenizer(prompts, truncation=True, return_tensors="pt").to(self.model.device)

        params = sampling_params if sampling_params else self.sampling_params
        filtered_params = {k: v for k, v in params.items() if k in self.valid_generation_params}

        with Timer() as timer:
            # Generate with optimized settings
            outputs = self.model.generate(
                **inputs,
                **filtered_params,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        # Decode the full sequences (these include the prompt tokens as well as the completion)
        results = self.tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        results = [text.strip() for text in results]

        logger.debug(
            f"PROMPT: {prompts}\n\nRESPONSES: {results}\n\n"
            f"SAMPLING PARAMS: {params}\n\n"
            f"TIME FOR RESPONSE: {timer.elapsed_time}"
        )

        return results if len(results) > 1 else results[0]

    def set_random_seeds(self, seed=42):
        """
        Set random seeds for reproducibility across all relevant libraries
        """
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False


if __name__ == "__main__":
    llm = ReproducibleHF(model_id="Qwen/Qwen2-0.5B", tensor_parallel_size=1, seed=42)
    llm.generate("Hello, world!")
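    # Illustrative only: per-call override of the default sampling parameters.
    # Keys that are not in the model's generation_config are filtered out inside
    # generate(), so unsupported settings are dropped rather than raising.
    llm.generate(
        "Hello, world!",
        sampling_params={"temperature": 0.2, "top_p": 0.9, "max_new_tokens": 64, "do_sample": True},
    )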