Commit 07e6d9c

WIP: removed vllm, llm base class.
1 parent 1ae5e2b commit 07e6d9c

File tree

6 files changed (+36, -256 lines)


prompting/llms/base_llm.py

Lines changed: 0 additions & 55 deletions
This file was deleted.

prompting/llms/hf_llm.py

Lines changed: 25 additions & 184 deletions
@@ -1,209 +1,50 @@
-import time
-from typing import Optional, Any
-from prompting.utils.cleaners import CleanerPipeline
-from prompting.llms.base_llm import BaseLLM
-from transformers import AutoModelForCausalLM, AutoTokenizer, AwqConfig, pipeline
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 from loguru import logger
 import random
 import numpy as np
 import torch
 from prompting.utils.timer import Timer
+from prompting.settings import settings
 
-
-class HF_LLM(BaseLLM):
-    def __init__(
-        self,
-        llm: Any,
-        system_prompt,
-        max_new_tokens=256,
-        temperature=0.7,
-        top_p=0.95,
-    ):
-        model_kwargs = {
-            "temperature": temperature,
-            "top_p": top_p,
-            "max_tokens": max_new_tokens,
-        }
-        super().__init__(llm, system_prompt, model_kwargs)
-
-        # Keep track of generation data using messages and times
-        self.system_prompt = system_prompt
-        self.messages = [{"content": self.system_prompt, "role": "system"}] if self.system_prompt else []
-        self.times: list[float] = [0]
-        self._role_template = {
-            "system": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{{{{ {} }}}}<|eot_id|>",
-            "user": "<|start_header_id|>user<|end_header_id|>\n{{{{ {} }}}}<|eot_id|>",
-            "assistant": "<|start_header_id|>assistant<|end_header_id|>\n{{{{ {} }}}}<|eot_id|>",
-            "end": "<|start_header_id|>assistant<|end_header_id|>",
-        }
-
-    def query_conversation(
-        self,
-        messages: list[str],
-        roles: list[str],
-        cleaner: Optional[CleanerPipeline] = None,
-    ):
-        """Query LLM with the given lists of conversation history and roles
-
-        Args:
-            messages (list[str]): List of messages in the conversation.
-            roles (list[str]): List of roles for each message.
-            cleaner (Optional[CleanerPipeline], optional): Cleaner pipeline to use, if any.
-        """
-        assert len(messages) == len(roles), "Length of messages and roles must be the same"
-        inputs: list[dict[str, Any]] = [{"content": self.system_prompt, "role": "system"}]
-        for role, message in zip(roles, messages):
-            inputs.append({"content": message, "role": role})
-
-        t0 = time.perf_counter()
-        response = self.forward(messages=inputs)
-        response = self.clean_response(cleaner, response)
-        self.times.extend((0, time.perf_counter() - t0))
-        return response
-
-    def query(
-        self,
-        message: list[str],
-        role: str = "user",
-        cleaner: CleanerPipeline = CleanerPipeline(),
-    ):
-        # Adds the message to the list of messages for tracking purposes, even though it's not used downstream
-        messages = self.messages + [{"content": message, "role": role}]
-
-        t0 = time.time()
-        response = self._forward(messages=messages)
-        response = self.clean_response(cleaner, response)
-
-        self.messages = messages
-        self.messages.append({"content": response, "role": "assistant"})
-        self.times.extend((0, time.time() - t0))
-
-        return response
-
-    def _make_prompt(self, messages: list[dict[str, str]]) -> str:
-        composed_prompt: list[str] = []
-
-        for message in messages:
-            role = message["role"]
-            if role not in self._role_template:
-                continue
-            content = message["content"]
-            composed_prompt.append(self._role_template[role].format(content))
-
-        # Adds final tag indicating the assistant's turn
-        composed_prompt.append(self._role_template["end"])
-        return "".join(composed_prompt)
-
-    def _forward(self, messages: list[dict[str, str]]):
-        # make composed prompt from messages
-        composed_prompt = self._make_prompt(messages)
-        response = self.llm.generate(
-            composed_prompt,
-            max_length=self.model_kwargs["max_tokens"],
-            temperature=self.model_kwargs["temperature"],
-            top_p=self.model_kwargs["top_p"],
-        )[0]
-
-        try:
-            logger.info(
-                f"{self.__class__.__name__} generated the following output:\n{response['generated_text'].strip()}"
-            )
-        except Exception as e:
-            logger.info(f"Response: {response}")
-            logger.error(f"Error logging the response: {e}")
-
-        return response["generated_text"].strip()
-
-
-def set_random_seeds(seed=42):
-    """
-    Set random seeds for reproducibility across all relevant libraries
-    """
-    if seed is not None:
-        random.seed(seed)
-        np.random.seed(seed)
-        torch.manual_seed(seed)
-        torch.manual_seed(seed)
-        torch.cuda.manual_seed_all(seed)
-        torch.backends.cudnn.deterministic = True
-        torch.backends.cudnn.benchmark = False
-
-
-class ReproducibleHF:
-    def __init__(self, model_id="Qwen/Qwen2-0.5B", tensor_parallel_size=0, seed=42, **kwargs):
+class ReproducibleHF():
+    def __init__(self, model_id="hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4", settings=None, **kwargs):
         """
         Initialize Hugging Face model with reproducible settings and optimizations
         """
-        self.set_random_seeds(seed)
-
-        # Load model and tokenizer with optimizations
-        model_kwargs = {
-            "device_map": "auto",
-        }
-
-        # get valid params for generation from model config
-        self.valid_generation_params = set(
-            AutoModelForCausalLM.from_pretrained(model_id).generation_config.to_dict().keys()
-        )
-
-        for k, v in kwargs.items():
-            if k not in ["sampling_params"]:  # exclude sampling_params and any other generation-only args
-                model_kwargs[k] = v
-
-        quantization_config = AwqConfig(
-            bits=4,
-            fuse_max_seq_len=512,
-            do_fuse=True,
-        )
-
+        self.seed = self.set_random_seeds(42)
+        quantization_config = settings.QUANTIZATION_CONFIG.get(model_id, None)
+
         self.model = AutoModelForCausalLM.from_pretrained(
             model_id,
             torch_dtype=torch.float16,
             low_cpu_mem_usage=True,
-            device_map="auto",
+            device_map="cuda:0",
             quantization_config=quantization_config,
         )
-
+
         self.tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        self.valid_generation_params = set(
+            AutoModelForCausalLM.from_pretrained(model_id).generation_config.to_dict().keys()
+        )
 
-        # self.model.generation_config.cache_implementation = "static"
-        # self.model.forward = torch.compile(self.model.forward, mode="reduce-overhead", fullgraph=True)
-        # self.valid_generation_params = set(self.model.generation_config.to_dict().keys())
-
-        # Enable model optimizations
-        self.model.eval()
-
-        if tensor_parallel_size > 1:
-            self.model = torch.nn.DataParallel(self.model, device_ids=list(range(tensor_parallel_size)))
-
-        # Create pipeline with optimized settings
         self.llm = pipeline("text-generation", model=self.model, tokenizer=self.tokenizer)
 
-        # Default sampling parameters
-        self.sampling_params = {
-            "temperature": 0.7,
-            "top_p": 0.95,
-            "top_k": 50,
-            "max_new_tokens": 256,
-            "presence_penalty": 0,
-            "frequency_penalty": 0,
-            "seed": seed,
-            "do_sample": True,
-            "early_stopping": True,  # Enable early stopping
-            "num_beams": 1,  # Use greedy decoding by default
-        }
+        self.sampling_params = settings.SAMPLING_PARAMS
 
     @torch.inference_mode()
     def generate(self, prompts, sampling_params=None):
         """
         Generate text with optimized performance
         """
-
-        # Convert single prompt to list
-        if isinstance(prompts, str):
-            prompts = [prompts]
-
-        inputs = self.tokenizer(prompts, truncation=True, return_tensors="pt").to(self.model.device)
+
+        inputs = self.tokenizer.apply_chat_template(
+            prompts,
+            tokenize=True,
+            add_generation_prompt=True,
+            return_tensors="pt",
+            return_dict=True,
+        ).to(settings.NEURON_DEVICE)
 
         params = sampling_params if sampling_params else self.sampling_params
         filtered_params = {k: v for k, v in params.items() if k in self.valid_generation_params}
@@ -215,9 +56,9 @@ def generate(self, prompts, sampling_params=None):
             **filtered_params,
             eos_token_id=self.tokenizer.eos_token_id,
         )
-
-        results = self.tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)
-        results = [text.strip() for text in results]
+
+        outputs = self.model.generate(**inputs, **filtered_params, eos_token_id=self.tokenizer.eos_token_id,)
+        results = self.tokenizer.batch_decode(outputs[:, inputs['input_ids'].shape[1]:], skip_special_tokens=True, )[0]
 
         logger.debug(
             f"PROMPT: {prompts}\n\nRESPONSES: {results}\n\n"

prompting/llms/model_manager.py

Lines changed: 1 addition & 10 deletions
@@ -15,7 +15,6 @@
 # through the tasks based on the currently loaded model.
 open_tasks = []
 
-
 class ModelManager(BaseModel):
     always_active_models: list[ModelConfig] = []
     total_ram: float = settings.LLM_MODEL_RAM
@@ -75,7 +74,6 @@ def load_model(self, model_config: ModelConfig, force: bool = True):
             self.active_models[model_config] = model
             self.used_ram += model_config.min_ram
             logger.info(f"Model {model_config.llm_model_id} loaded. Current used RAM: {self.used_ram} GB")
-
             return model
         except Exception as e:
             logger.exception(f"Failed to load model {model_config.llm_model_id}. Error: {str(e)}")
@@ -148,14 +146,7 @@ def generate(
             model = ModelZoo.get_random(max_ram=self.total_ram)
 
         model_instance: ReproducibleHF = self.get_model(model)
-
-        valid_args = {"max_length", "temperature", "top_p", "min_length", "do_sample", "num_return_sequences"}
-        if sampling_params:
-            sampling_params = {k: v for k, v in sampling_params.items() if k in valid_args}
-        else:
-            sampling_params = {"max_length": settings.NEURON_MAX_TOKENS}
-
-        responses = model_instance.generate(prompts=[composed_prompt], sampling_params=sampling_params)
+        responses = model_instance.generate(prompts=[composed_prompt])
 
         return responses
 
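
With the local argument filtering removed, sampling-parameter validation now happens inside ReproducibleHF.generate (see hf_llm.py above). A small sketch of that filtering logic, with hypothetical parameter names:

# Keys not accepted by the model's generation_config are dropped before being
# passed to model.generate(), mirroring the dict comprehension in hf_llm.py.
valid_generation_params = {"temperature", "top_p", "top_k", "max_new_tokens", "do_sample"}
params = {"temperature": 0.7, "top_p": 0.95, "unsupported_flag": True}
filtered_params = {k: v for k, v in params.items() if k in valid_generation_params}
assert filtered_params == {"temperature": 0.7, "top_p": 0.95}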

prompting/settings.py

Lines changed: 8 additions & 0 deletions
@@ -8,6 +8,7 @@
 from loguru import logger
 from pydantic import Field, model_validator
 from pydantic_settings import BaseSettings
+from transformers import AwqConfig
 
 from prompting.utils.config import config
 
@@ -81,6 +82,7 @@ class Settings(BaseSettings):
     MAX_ALLOWED_VRAM_GB: int = Field(62, env="MAX_ALLOWED_VRAM_GB")
     LLM_MAX_MODEL_LEN: int = Field(4096, env="LLM_MAX_MODEL_LEN")
     LLM_MODEL: str = Field("hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4", env="LLM_MODEL")
+    SAMPLING_PARAMS: dict[str, Any] = {"temperature": 0.7, "top_p": 0.95, "top_k": 50, "max_new_tokens": 256, "do_sample": True, "seed": None}
     MINER_LLM_MODEL: Optional[str] = Field(None, env="MINER_LLM_MODEL")
     LLM_MODEL_RAM: float = Field(70, env="LLM_MODEL_RAM")
     OPENAI_API_KEY: str | None = Field(None, env="OPENAI_API_KEY")
@@ -178,6 +180,11 @@ def complete_settings(cls, values: dict[str, Any]) -> dict[str, Any]:
                 "You must provide an OpenAI API key as a backup. It is recommended to also provide an SN19 API key + url to avoid incurring API costs."
             )
         return values
+
+    @cached_property
+    def QUANTIZATION_CONFIG(self) -> AwqConfig:
+        configs = {"hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4": AwqConfig(bits=4, fuse_max_seq_len=512, do_fuse=True)}
+        return configs
 
     @cached_property
     def WALLET(self) -> bt.wallet:
@@ -206,6 +213,7 @@ def METAGRAPH(self) -> bt.metagraph:
     def DENDRITE(self) -> bt.dendrite:
         logger.info(f"Instantiating dendrite with wallet: {self.WALLET}")
         return bt.dendrite(wallet=self.WALLET)
+
 
 
 settings: Optional[Settings] = None
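
A short sketch of how the two new settings are consumed by ReproducibleHF (mirroring the hf_llm.py changes above), assuming an instantiated Settings object:

# Sketch, assuming `settings` is an instantiated Settings object.
# QUANTIZATION_CONFIG maps model ids to an AwqConfig; models without an entry
# fall back to None, and SAMPLING_PARAMS supplies the default generation kwargs.
quantization_config = settings.QUANTIZATION_CONFIG.get(settings.LLM_MODEL, None)
sampling_params = settings.SAMPLING_PARAMS  # {"temperature": 0.7, "top_p": 0.95, ...}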

prompting/tasks/base_task.py

Lines changed: 2 additions & 5 deletions
@@ -3,7 +3,6 @@
 from loguru import logger
 from abc import ABC
 from pydantic import BaseModel, Field, ConfigDict, model_validator
-from prompting.llms.vllm_llm import vLLM_LLM
 from prompting.utils.cleaners import CleanerPipeline
 from typing import ClassVar
 from prompting.datasets.base import DatasetEntry
@@ -78,12 +77,10 @@ def make_reference(self, dataset_entry: DatasetEntry) -> str:
     def generate_reference(self, messages: list[str]) -> str:
         """Generates a reference answer to be used for scoring miner completions"""
         logger.info("🤖 Generating reference...")
-        self.reference = vLLM_LLM(
-            llm=model_manager.get_model(self.llm_model).llm, system_prompt=self.reference_system_prompt or ""
-        ).query(cleaner=self.cleaner, message=messages)
-        # self.reference = model_manager.get_model(self.llm_model).generate(prompts=messages)
+        self.reference = model_manager.get_model(settings.LLM_MODEL).generate(prompts=messages)
         if self.reference is None:
             raise Exception("Reference generation failed")
+
         return self.reference
 
     def generate_query(
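
A minimal sketch of the new reference path (not part of the commit), assuming a module-level model_manager instance and a loaded settings object; the message content is hypothetical:

# Hypothetical sketch: the task asks the model manager for the configured LLM and
# calls generate() on it directly, instead of going through the removed vLLM_LLM wrapper.
from prompting.settings import settings
from prompting.llms.model_manager import model_manager  # assumed module-level instance

messages = [{"role": "user", "content": "Write a reference answer for this query."}]
reference = model_manager.get_model(settings.LLM_MODEL).generate(prompts=messages)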
