
Commit f2c7243

Author: richwardle (committed)
WIP: inference time needs improving
1 parent c421826 commit f2c7243

File tree

6 files changed: +437 −160 lines changed


poetry.lock

Lines changed: 145 additions & 132 deletions
Generated file; diff not rendered by default.

prompting/llms/hf_llm.py

Lines changed: 246 additions & 0 deletions
@@ -0,0 +1,246 @@
import time
from typing import Optional, Any
from prompting.utils.cleaners import CleanerPipeline
from prompting.llms.base_llm import BaseLLM
from transformers import AutoModelForCausalLM, AutoTokenizer, AwqConfig, pipeline
from loguru import logger
import random
import numpy as np
import torch
from prompting.utils.timer import Timer


class HF_LLM(BaseLLM):
    def __init__(
        self,
        llm: Any,
        system_prompt,
        max_new_tokens=256,
        temperature=0.7,
        top_p=0.95,
    ):
        model_kwargs = {
            "temperature": temperature,
            "top_p": top_p,
            "max_tokens": max_new_tokens,
        }
        super().__init__(llm, system_prompt, model_kwargs)

        # Keep track of generation data using messages and times
        self.system_prompt = system_prompt
        self.messages = [{"content": self.system_prompt, "role": "system"}] if self.system_prompt else []
        self.times: list[float] = [0]
        self._role_template = {
            "system": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{{{{ {} }}}}<|eot_id|>",
            "user": "<|start_header_id|>user<|end_header_id|>\n{{{{ {} }}}}<|eot_id|>",
            "assistant": "<|start_header_id|>assistant<|end_header_id|>\n{{{{ {} }}}}<|eot_id|>",
            "end": "<|start_header_id|>assistant<|end_header_id|>",
        }

    def query_conversation(
        self,
        messages: list[str],
        roles: list[str],
        cleaner: Optional[CleanerPipeline] = None,
    ):
        """Query the LLM with the given lists of conversation history and roles.

        Args:
            messages (list[str]): List of messages in the conversation.
            roles (list[str]): List of roles for each message.
            cleaner (Optional[CleanerPipeline], optional): Cleaner pipeline to use, if any.
        """
        assert len(messages) == len(roles), "Length of messages and roles must be the same"
        inputs: list[dict[str, Any]] = [{"content": self.system_prompt, "role": "system"}]
        for role, message in zip(roles, messages):
            inputs.append({"content": message, "role": role})

        t0 = time.perf_counter()
        response = self.forward(messages=inputs)
        response = self.clean_response(cleaner, response)
        self.times.extend((0, time.perf_counter() - t0))
        return response

    def query(
        self,
        message: list[str],
        role: str = "user",
        cleaner: CleanerPipeline = CleanerPipeline(),
    ):
        # Add the message to the tracked history, even though it's not used downstream
        messages = self.messages + [{"content": message, "role": role}]

        t0 = time.time()
        response = self._forward(messages=messages)
        response = self.clean_response(cleaner, response)

        self.messages = messages
        self.messages.append({"content": response, "role": "assistant"})
        self.times.extend((0, time.time() - t0))

        return response

    def _make_prompt(self, messages: list[dict[str, str]]) -> str:
        composed_prompt: list[str] = []

        for message in messages:
            role = message["role"]
            if role not in self._role_template:
                continue
            content = message["content"]
            composed_prompt.append(self._role_template[role].format(content))

        # Add the final tag indicating the assistant's turn
        composed_prompt.append(self._role_template["end"])
        return "".join(composed_prompt)

    def _forward(self, messages: list[dict[str, str]]):
        # Compose a single prompt string from the message history
        composed_prompt = self._make_prompt(messages)
        response = self.llm.generate(
            composed_prompt,
            max_length=self.model_kwargs["max_tokens"],
            temperature=self.model_kwargs["temperature"],
            top_p=self.model_kwargs["top_p"],
        )[0]

        try:
            logger.info(
                f"{self.__class__.__name__} generated the following output:\n{response['generated_text'].strip()}"
            )
        except Exception as e:
            logger.info(f"Response: {response}")
            logger.error(f"Error logging the response: {e}")

        return response["generated_text"].strip()


def set_random_seeds(seed=42):
    """Set random seeds for reproducibility across all relevant libraries."""
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False


class ReproducibleHF:
    def __init__(self, model_id="Qwen/Qwen2-0.5B", tensor_parallel_size=0, seed=42, **kwargs):
        """Initialize a Hugging Face model with reproducible settings and optimizations."""
        self.set_random_seeds(seed)

        # Load model and tokenizer with optimizations
        model_kwargs = {
            "device_map": "auto",
        }

        # Get valid generation params from the model's generation config
        self.valid_generation_params = set(
            AutoModelForCausalLM.from_pretrained(model_id).generation_config.to_dict().keys()
        )

        for k, v in kwargs.items():
            if k not in ["sampling_params"]:  # exclude sampling_params and any other generation-only args
                model_kwargs[k] = v

        quantization_config = AwqConfig(
            bits=4,
            fuse_max_seq_len=512,
            do_fuse=True,
        )

        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            device_map="auto",
            quantization_config=quantization_config,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(model_id)

        # self.model.generation_config.cache_implementation = "static"
        # self.model.forward = torch.compile(self.model.forward, mode="reduce-overhead", fullgraph=True)
        # self.valid_generation_params = set(self.model.generation_config.to_dict().keys())

        # Enable model optimizations
        self.model.eval()

        if tensor_parallel_size > 1:
            self.model = torch.nn.DataParallel(self.model, device_ids=list(range(tensor_parallel_size)))

        # Create pipeline with optimized settings
        self.llm = pipeline("text-generation", model=self.model, tokenizer=self.tokenizer)

        # Default sampling parameters
        self.sampling_params = {
            "temperature": 0.7,
            "top_p": 0.95,
            "top_k": 50,
            "max_new_tokens": 256,
            "presence_penalty": 0,
            "frequency_penalty": 0,
            "seed": seed,
            "do_sample": True,
            "early_stopping": True,  # Enable early stopping
            "num_beams": 1,  # Use greedy decoding by default
        }

    @torch.inference_mode()
    def generate(self, prompts, sampling_params=None):
        """Generate text with optimized performance."""
        # Convert a single prompt to a list
        if isinstance(prompts, str):
            prompts = [prompts]

        inputs = self.tokenizer(prompts, truncation=True, return_tensors="pt").to(self.model.device)

        params = sampling_params if sampling_params else self.sampling_params
        filtered_params = {k: v for k, v in params.items() if k in self.valid_generation_params}

        with Timer() as timer:
            # Generate with optimized settings
            outputs = self.model.generate(
                **inputs,
                **filtered_params,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        results = self.tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        results = [text.strip() for text in results]

        logger.debug(
            f"PROMPT: {prompts}\n\nRESPONSES: {results}\n\n"
            f"SAMPLING PARAMS: {params}\n\n"
            f"TIME FOR RESPONSE: {timer.elapsed_time}"
        )

        return results if len(results) > 1 else results[0]

    def set_random_seeds(self, seed=42):
        """Set random seeds for reproducibility across all relevant libraries."""
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)
                torch.backends.cudnn.deterministic = True
                torch.backends.cudnn.benchmark = False


if __name__ == "__main__":
    llm = ReproducibleHF(model_id="Qwen/Qwen2-0.5B", tensor_parallel_size=1, seed=42)
    llm.generate("Hello, world!")
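
A minimal usage sketch of the new ReproducibleHF wrapper, mirroring the __main__ block above. The prompt and sampling values are illustrative, not part of the diff, and assume a CUDA-capable machine with the AWQ dependencies installed:

from prompting.llms.hf_llm import ReproducibleHF

# Same model id as the __main__ block above; any compatible HF model id could be substituted.
llm = ReproducibleHF(model_id="Qwen/Qwen2-0.5B", tensor_parallel_size=1, seed=42)

# Keys are filtered against the model's generation_config, so unsupported
# entries are dropped before model.generate() is called.
response = llm.generate(
    "Summarize tensor parallelism in one sentence.",
    sampling_params={"max_new_tokens": 64, "temperature": 0.3, "top_p": 0.9, "do_sample": True},
)
print(response)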

prompting/llms/model_manager.py

Lines changed: 29 additions & 14 deletions
@@ -11,6 +11,7 @@
 from prompting.settings import settings
 from vllm.sampling_params import SamplingParams
 from prompting.llms.vllm_llm import ReproducibleVLLM
+from prompting.llms.hf_llm import ReproducibleHF

 # This maintains a list of tasks for which we need to generate references. Since
 # we can only generate the references, when the correct model is loaded, we work
@@ -67,16 +68,21 @@ def load_model(self, model_config: ModelConfig, force: bool = True):
             f"Loading model... {model_config.llm_model_id} with GPU Utilization: {model_config.min_ram / GPUInfo.free_memory}"
         )
         GPUInfo.log_gpu_info()
-        # model = vllm.LLM(
-        #     model_config.llm_model_id,
-        #     max_model_len=8_000,
-        #     gpu_memory_utilization=model_config.min_ram / GPUInfo.free_memory,
-        # )
-        model = ReproducibleVLLM(
-            model=model_config.llm_model_id,
-            gpu_memory_utilization=model_config.min_ram / GPUInfo.free_memory,
-            max_model_len=settings.LLM_MAX_MODEL_LEN,
-        )
+
+        if settings.LLM_TYPE == "hf":
+            model = ReproducibleHF(
+                model=model_config.llm_model_id,
+                gpu_memory_utilization=model_config.min_ram / GPUInfo.free_memory,
+                max_model_len=settings.LLM_MAX_MODEL_LEN,
+            )
+        elif settings.LLM_TYPE == "vllm":
+            model = ReproducibleVLLM(
+                model=model_config.llm_model_id,
+                gpu_memory_utilization=model_config.min_ram / GPUInfo.free_memory,
+                max_model_len=settings.LLM_MAX_MODEL_LEN,
+            )
+        else:
+            raise ValueError(f"Unknown LLM_TYPE: {settings.LLM_TYPE}")

         self.active_models[model_config] = model
         self.used_ram += model_config.min_ram
@@ -101,13 +107,13 @@ def unload_model(self, model_config: ModelConfig):
         self.used_ram -= model_config.min_ram
         torch.cuda.empty_cache()

-    def get_or_load_model(self, llm_model_id: str) -> ReproducibleVLLM:
+    def get_or_load_model(self, llm_model_id: str) -> ReproducibleVLLM | ReproducibleHF:
         model_config = ModelZoo.get_model_by_id(llm_model_id)
         if model_config not in self.active_models:
             self.load_model(model_config)
         return self.active_models[model_config]

-    def get_model(self, llm_model: ModelConfig | str) -> ReproducibleVLLM:
+    def get_model(self, llm_model: ModelConfig | str) -> ReproducibleVLLM | ReproducibleHF:
         if not llm_model:
             llm_model = list(self.active_models.keys())[0] if self.active_models else ModelZoo.get_random()
         if isinstance(llm_model, str):
@@ -144,7 +150,7 @@ def generate(
         messages: list[str],
         roles: list[str],
         model: ModelConfig | str | None = None,
-        sampling_params: SamplingParams | None = SamplingParams(max_tokens=settings.NEURON_MAX_TOKENS),
+        sampling_params: SamplingParams | None = None,
     ) -> str:
         dict_messages = [{"content": message, "role": role} for message, role in zip(messages, roles)]
         composed_prompt = self._make_prompt(dict_messages)
@@ -155,7 +161,16 @@
         if not model:
             model = ModelZoo.get_random(max_ram=self.total_ram)

-        model_instance: ReproducibleVLLM = self.get_model(model)
+        model_instance: ReproducibleVLLM | ReproducibleHF = self.get_model(model)
+
+        # Adjust sampling_params based on LLM_TYPE
+        if settings.LLM_TYPE == "hf":
+            valid_args = {"max_length", "temperature", "top_p", "min_length", "do_sample", "num_return_sequences"}
+            if sampling_params:
+                sampling_params = {k: v for k, v in sampling_params.items() if k in valid_args}
+            else:
+                sampling_params = {"max_length": settings.NEURON_MAX_TOKENS}
+
         responses = model_instance.generate(prompts=[composed_prompt], sampling_params=sampling_params)

         return responses
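
To illustrate the new backend switch, a hypothetical call path through the model manager; the class name ModelManager and its no-argument constructor are assumptions, as neither is shown in this diff:

from prompting.settings import settings
from prompting.llms.model_manager import ModelManager  # assumed class name for the manager in this file

manager = ModelManager()  # hypothetical constructor; arguments not shown in this diff

# With settings.LLM_TYPE == "vllm", load_model() builds a ReproducibleVLLM;
# with "hf" it builds a ReproducibleHF, and generate() filters sampling_params
# down to HF-compatible kwargs, falling back to {"max_length": settings.NEURON_MAX_TOKENS}.
reply = manager.generate(
    messages=["What is the capital of France?"],
    roles=["user"],
    model="Qwen/Qwen2-0.5B",  # illustrative model id; must exist in the ModelZoo
)
print(reply)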

0 commit comments
