-import time
-from typing import Optional, Any
-from prompting.utils.cleaners import CleanerPipeline
-from prompting.llms.base_llm import BaseLLM
-from transformers import AutoModelForCausalLM, AutoTokenizer, AwqConfig, pipeline
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 from loguru import logger
 import random
 import numpy as np
 import torch
 from prompting.utils.timer import Timer
+from prompting.settings import settings

-
-class HF_LLM(BaseLLM):
-    def __init__(
-        self,
-        llm: Any,
-        system_prompt,
-        max_new_tokens=256,
-        temperature=0.7,
-        top_p=0.95,
-    ):
-        model_kwargs = {
-            "temperature": temperature,
-            "top_p": top_p,
-            "max_tokens": max_new_tokens,
-        }
-        super().__init__(llm, system_prompt, model_kwargs)
-
-        # Keep track of generation data using messages and times
-        self.system_prompt = system_prompt
-        self.messages = [{"content": self.system_prompt, "role": "system"}] if self.system_prompt else []
-        self.times: list[float] = [0]
-        self._role_template = {
-            "system": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{{{{ {} }}}}<|eot_id|>",
-            "user": "<|start_header_id|>user<|end_header_id|>\n{{{{ {} }}}}<|eot_id|>",
-            "assistant": "<|start_header_id|>assistant<|end_header_id|>\n{{{{ {} }}}}<|eot_id|>",
-            "end": "<|start_header_id|>assistant<|end_header_id|>",
-        }
-
-    def query_conversation(
-        self,
-        messages: list[str],
-        roles: list[str],
-        cleaner: Optional[CleanerPipeline] = None,
-    ):
-        """Query LLM with the given lists of conversation history and roles
-
-        Args:
-            messages (list[str]): List of messages in the conversation.
-            roles (list[str]): List of roles for each message.
-            cleaner (Optional[CleanerPipeline], optional): Cleaner pipeline to use, if any.
-        """
-        assert len(messages) == len(roles), "Length of messages and roles must be the same"
-        inputs: list[dict[str, Any]] = [{"content": self.system_prompt, "role": "system"}]
-        for role, message in zip(roles, messages):
-            inputs.append({"content": message, "role": role})
-
-        t0 = time.perf_counter()
-        response = self.forward(messages=inputs)
-        response = self.clean_response(cleaner, response)
-        self.times.extend((0, time.perf_counter() - t0))
-        return response
-
-    def query(
-        self,
-        message: list[str],
-        role: str = "user",
-        cleaner: CleanerPipeline = CleanerPipeline(),
-    ):
-        # Adds the message to the list of messages for tracking purposes, even though it's not used downstream
-        messages = self.messages + [{"content": message, "role": role}]
-
-        t0 = time.time()
-        response = self._forward(messages=messages)
-        response = self.clean_response(cleaner, response)
-
-        self.messages = messages
-        self.messages.append({"content": response, "role": "assistant"})
-        self.times.extend((0, time.time() - t0))
-
-        return response
-
-    def _make_prompt(self, messages: list[dict[str, str]]) -> str:
-        composed_prompt: list[str] = []
-
-        for message in messages:
-            role = message["role"]
-            if role not in self._role_template:
-                continue
-            content = message["content"]
-            composed_prompt.append(self._role_template[role].format(content))
-
-        # Adds final tag indicating the assistant's turn
-        composed_prompt.append(self._role_template["end"])
-        return "".join(composed_prompt)
-
-    def _forward(self, messages: list[dict[str, str]]):
-        # make composed prompt from messages
-        composed_prompt = self._make_prompt(messages)
-        response = self.llm.generate(
-            composed_prompt,
-            max_length=self.model_kwargs["max_tokens"],
-            temperature=self.model_kwargs["temperature"],
-            top_p=self.model_kwargs["top_p"],
-        )[0]
-
-        try:
-            logger.info(
-                f"{self.__class__.__name__} generated the following output:\n{response['generated_text'].strip()}"
-            )
-        except Exception as e:
-            logger.info(f"Response: {response}")
-            logger.error(f"Error logging the response: {e}")
-
-        return response["generated_text"].strip()
-
-
-def set_random_seeds(seed=42):
-    """
-    Set random seeds for reproducibility across all relevant libraries
-    """
-    if seed is not None:
-        random.seed(seed)
-        np.random.seed(seed)
-        torch.manual_seed(seed)
-        torch.manual_seed(seed)
-        torch.cuda.manual_seed_all(seed)
-        torch.backends.cudnn.deterministic = True
-        torch.backends.cudnn.benchmark = False
-
-
-class ReproducibleHF:
-    def __init__(self, model_id="Qwen/Qwen2-0.5B", tensor_parallel_size=0, seed=42, **kwargs):
+class ReproducibleHF():
+    def __init__(self, model_id="hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4", settings=None, **kwargs):
         """
         Initialize Hugging Face model with reproducible settings and optimizations
         """
-        self.set_random_seeds(seed)
-
-        # Load model and tokenizer with optimizations
-        model_kwargs = {
-            "device_map": "auto",
-        }
-
-        # get valid params for generation from model config
-        self.valid_generation_params = set(
-            AutoModelForCausalLM.from_pretrained(model_id).generation_config.to_dict().keys()
-        )
-
-        for k, v in kwargs.items():
-            if k not in ["sampling_params"]:  # exclude sampling_params and any other generation-only args
-                model_kwargs[k] = v
-
-        quantization_config = AwqConfig(
-            bits=4,
-            fuse_max_seq_len=512,
-            do_fuse=True,
-        )
-
+        self.seed = self.set_random_seeds(42)
+        quantization_config = settings.QUANTIZATION_CONFIG.get(model_id, None)
+
         self.model = AutoModelForCausalLM.from_pretrained(
             model_id,
             torch_dtype=torch.float16,
             low_cpu_mem_usage=True,
-            device_map="auto",
+            device_map="cuda:0",
             quantization_config=quantization_config,
         )
-
+
         self.tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        self.valid_generation_params = set(
+            AutoModelForCausalLM.from_pretrained(model_id).generation_config.to_dict().keys()
+        )

-        # self.model.generation_config.cache_implementation = "static"
-        # self.model.forward = torch.compile(self.model.forward, mode="reduce-overhead", fullgraph=True)
-        # self.valid_generation_params = set(self.model.generation_config.to_dict().keys())
-
-        # Enable model optimizations
-        self.model.eval()
-
-        if tensor_parallel_size > 1:
-            self.model = torch.nn.DataParallel(self.model, device_ids=list(range(tensor_parallel_size)))
-
-        # Create pipeline with optimized settings
         self.llm = pipeline("text-generation", model=self.model, tokenizer=self.tokenizer)

-        # Default sampling parameters
-        self.sampling_params = {
-            "temperature": 0.7,
-            "top_p": 0.95,
-            "top_k": 50,
-            "max_new_tokens": 256,
-            "presence_penalty": 0,
-            "frequency_penalty": 0,
-            "seed": seed,
-            "do_sample": True,
-            "early_stopping": True,  # Enable early stopping
-            "num_beams": 1,  # Use greedy decoding by default
-        }
+        self.sampling_params = settings.SAMPLING_PARAMS

     @torch.inference_mode()
     def generate(self, prompts, sampling_params=None):
         """
         Generate text with optimized performance
         """
-
-        # Convert single prompt to list
-        if isinstance(prompts, str):
-            prompts = [prompts]
-
-        inputs = self.tokenizer(prompts, truncation=True, return_tensors="pt").to(self.model.device)
+
+        inputs = self.tokenizer.apply_chat_template(
+            prompts,
+            tokenize=True,
+            add_generation_prompt=True,
+            return_tensors="pt",
+            return_dict=True,
+        ).to(settings.NEURON_DEVICE)

         params = sampling_params if sampling_params else self.sampling_params
         filtered_params = {k: v for k, v in params.items() if k in self.valid_generation_params}
@@ -215,9 +56,9 @@ def generate(self, prompts, sampling_params=None):
             **filtered_params,
             eos_token_id=self.tokenizer.eos_token_id,
         )
-
-        results = self.tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)
-        results = [text.strip() for text in results]
+
+        outputs = self.model.generate(**inputs, **filtered_params, eos_token_id=self.tokenizer.eos_token_id)
+        results = self.tokenizer.batch_decode(outputs[:, inputs['input_ids'].shape[1]:], skip_special_tokens=True)[0]

         logger.debug(
             f"PROMPT: {prompts}\n\nRESPONSES: {results}\n\n"
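For orientation, a minimal usage sketch of the refactored class (not part of the commit): the module path prompting.llms.hf_llm and the override values below are assumptions, while the chat-style input to generate() and the settings-driven sampling parameters follow from the diff above.

# Hypothetical usage sketch, not taken from the commit.
from prompting.llms.hf_llm import ReproducibleHF  # assumed module path

llm = ReproducibleHF(model_id="hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4")

# generate() now expects chat-style message dicts, since prompts are run through
# tokenizer.apply_chat_template() rather than plain tokenization.
messages = [
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": "Summarise this change in one sentence."},
]

# Sampling parameters default to settings.SAMPLING_PARAMS; any override is filtered
# against the model's generation_config keys before being passed to model.generate().
response = llm.generate(messages, sampling_params={"temperature": 0.7, "max_new_tokens": 256})
print(response)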