import logging
import os
+ import threading
import time
from collections import deque
from concurrent.futures import ThreadPoolExecutor
+ from concurrent.futures import as_completed

from tqdm import tqdm

+ from eureka_ml_insights.configs.config import DataSetConfig, ModelConfig
from eureka_ml_insights.data_utils.data import DataReader, JsonLinesWriter
+ from eureka_ml_insights.models.models import Model

from .pipeline import Component
from .reserved_names import INFERENCE_RESERVED_NAMES

class Inference(Component):
    def __init__(
        self,
-         model_config,
-         data_config,
+         model_config: ModelConfig,
+         data_config: DataSetConfig,
        output_dir,
        resume_from=None,
        new_columns=None,
@@ -39,14 +43,16 @@ def __init__(
            chat_mode (bool): optional. If True, the model will be used in chat mode, where a history of messages will be maintained in "previous_messages" column.
        """
        super().__init__(output_dir)
-         self.model = model_config.class_name(**model_config.init_args)
+         self.model: Model = model_config.class_name(**model_config.init_args)
        self.data_loader = data_config.class_name(**data_config.init_args)
-         self.writer = JsonLinesWriter(os.path.join(output_dir, "inference_result.jsonl"))
+         self.appender = JsonLinesWriter(os.path.join(output_dir, "inference_result.jsonl"), mode="a")

        self.resume_from = resume_from
        if resume_from and not os.path.exists(resume_from):
            raise FileNotFoundError(f"File {resume_from} not found.")
        self.new_columns = new_columns
+         self.pre_inf_results_df = None
+         self.last_uid = None

        # rate limiting parameters
        self.requests_per_minute = requests_per_minute
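The `requests_per_minute`, `request_times`, and `period` fields back a simple sliding-window rate limiter; the loop that consumes them now lives in `_run_single` (see the last hunk below). A minimal standalone sketch of the same idea, where `make_request` is a hypothetical stand-in for the model call and the constants stand in for the instance attributes:

```python
import time
from collections import deque

REQUESTS_PER_MINUTE = 10   # stands in for self.requests_per_minute
PERIOD = 60                # rate-limit window in seconds, stands in for self.period
request_times = deque()    # timestamps of recent requests, like self.request_times


def rate_limited(make_request):
    # Block until the window has room: drop timestamps that fell out of the
    # window, otherwise sleep a second and re-check.
    while len(request_times) >= REQUESTS_PER_MINUTE:
        if time.time() - request_times[0] > PERIOD:
            request_times.popleft()
        else:
            time.sleep(1)
    request_times.append(time.time())
    return make_request()
```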
@@ -57,6 +63,8 @@ def __init__(
57
63
self .max_concurrent = max_concurrent
58
64
self .chat_mode = chat_mode
59
65
self .model .chat_mode = self .chat_mode
66
+ self .output_dir = output_dir
67
+ self .writer_lock = threading .Lock ()
60
68
61
69
@classmethod
62
70
def from_config (cls , config ):
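The `appender` opened in append mode together with `writer_lock` is what keeps incremental result writes safe once several worker threads produce outputs concurrently. A minimal sketch of that pattern, assuming a plain JSONL file rather than the project's `JsonLinesWriter`:

```python
import json
import threading

writer_lock = threading.Lock()


def append_threadsafe(path, record):
    # Hold the lock for the whole write so lines from different threads never interleave.
    with writer_lock:
        with open(path, "a", encoding="utf-8") as f:
            f.write(json.dumps(record) + "\n")
```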
@@ -168,89 +176,44 @@ def retrieve_exisiting_result(self, data, pre_inf_results_df):
        return data

    def run(self):
-         if self.max_concurrent > 1:
-             self._run_par()
-         else:
-             self._run()
-
-     def _run(self):
-         """sequential inference"""
        if self.resume_from:
-             pre_inf_results_df, last_uid = self.fetch_previous_inference_results()
-         with self.data_loader as loader:
-             with self.writer as writer:
-                 for data, model_args, model_kwargs in tqdm(loader, desc="Inference Progress:"):
-                     if self.chat_mode and data.get("is_valid", True) is False:
-                         continue
-                     if self.resume_from and (data["uid"] <= last_uid):
-                         prev_result = self.retrieve_exisiting_result(data, pre_inf_results_df)
-                         if prev_result:
-                             writer.write(prev_result)
-                         continue
-
-                     # generate text from model (optionally at a limited rate)
-                     if self.requests_per_minute:
-                         while len(self.request_times) >= self.requests_per_minute:
-                             # remove the oldest request time if it is older than the rate limit period
-                             if time.time() - self.request_times[0] > self.period:
-                                 self.request_times.popleft()
-                             else:
-                                 # rate limit is reached, wait for a second
-                                 time.sleep(1)
-                         self.request_times.append(time.time())
-                     response_dict = self.model.generate(*model_args, **model_kwargs)
-                     self.validate_response_dict(response_dict)
-                     # write results
-                     data.update(response_dict)
-                     writer.write(data)
-
-     def _run_par(self):
-         """parallel inference"""
-         concurrent_inputs = []
-         concurrent_metadata = []
-         if self.resume_from:
-             pre_inf_results_df, last_uid = self.fetch_previous_inference_results()
-         with self.data_loader as loader:
-             with self.writer as writer:
-                 for data, model_args, model_kwargs in tqdm(loader, desc="Inference Progress:"):
-                     if self.chat_mode and data.get("is_valid", True) is False:
-                         continue
-                     if self.resume_from and (data["uid"] <= last_uid):
-                         prev_result = self.retrieve_exisiting_result(data, pre_inf_results_df)
-                         if prev_result:
-                             writer.write(prev_result)
-                         continue
-
-                     # if batch is ready for concurrent inference
-                     elif len(concurrent_inputs) >= self.max_concurrent:
-                         with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
-                             self.run_batch(concurrent_inputs, concurrent_metadata, writer, executor)
-                         concurrent_inputs = []
-                         concurrent_metadata = []
-                     # add data to batch for concurrent inference
-                     concurrent_inputs.append((model_args, model_kwargs))
-                     concurrent_metadata.append(data)
-                 # if data loader is exhausted but there are remaining data points that did not form a full batch
-                 if concurrent_inputs:
-                     with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
-                         self.run_batch(concurrent_inputs, concurrent_metadata, writer, executor)
-
-     def run_batch(self, concurrent_inputs, concurrent_metadata, writer, executor):
-         """Run a batch of inferences concurrently using ThreadPoolExecutor.
-         args:
-             concurrent_inputs (list): list of inputs to the model.generate function.
-             concurrent_metadata (list): list of metadata corresponding to the inputs.
-             writer (JsonLinesWriter): JsonLinesWriter instance to write the results.
-             executor (ThreadPoolExecutor): ThreadPoolExecutor instance.
-         """
-
-         def sub_func(model_inputs):
-             return self.model.generate(*model_inputs[0], **model_inputs[1])
-
-         results = executor.map(sub_func, concurrent_inputs)
-         for i, result in enumerate(results):
-             data, response_dict = concurrent_metadata[i], result
-             self.validate_response_dict(response_dict)
-             # prepare results for writing
-             data.update(response_dict)
-             writer.write(data)
+             self.pre_inf_results_df, self.last_uid = self.fetch_previous_inference_results()
+         with self.data_loader as loader, ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
+             futures = [executor.submit(self._run_single, record) for record in loader]
+             for future in tqdm(as_completed(futures), total=len(loader), mininterval=2.0, desc="Inference Progress: "):
+                 result = future.result()
+                 if result:
+                     self._append_threadsafe(result)
+
+     def _append_threadsafe(self, data):
+         with self.writer_lock:
+             with self.appender as appender:
+                 appender.write(data)
+
+     def _run_single(self, record: tuple[dict, tuple, dict]):
+         """Runs model.generate() with respect to a single element of the dataloader."""
+
+         data, model_args, model_kwargs = record
+         if self.chat_mode and data.get("is_valid", True) is False:
+             return None
+         if self.resume_from and (data["uid"] <= self.last_uid):
+             prev_result = self.retrieve_exisiting_result(data, self.pre_inf_results_df)
+             if prev_result:
+                 return prev_result
+
+         # Rate limiter -- only for sequential inference
+         if self.requests_per_minute and self.max_concurrent == 1:
+             while len(self.request_times) >= self.requests_per_minute:
+                 # remove the oldest request time if it is older than the rate limit period
+                 if time.time() - self.request_times[0] > self.period:
+                     self.request_times.popleft()
+                 else:
+                     # rate limit is reached, wait for a second
+                     time.sleep(1)
+             self.request_times.append(time.time())
+
+         response_dict = self.model.generate(*model_args, **model_kwargs)
+         self.validate_response_dict(response_dict)
+         data.update(response_dict)
+         return data
+
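The rewritten `run()` drops the old fixed-size batching (`_run_par` / `run_batch`): every record is submitted to a single `ThreadPoolExecutor` up front and results are appended as their futures complete, so one slow request no longer holds back the rest of its batch. Note that the `requests_per_minute` limiter now only takes effect when `max_concurrent == 1`, presumably because the deque-based sliding window is not coordinated across worker threads. A minimal sketch of the submit / `as_completed` / `tqdm` pattern, where `generate` is a hypothetical stand-in for `_run_single`:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

from tqdm import tqdm


def generate(record):
    # Hypothetical stand-in for Inference._run_single (model call + validation).
    return {"uid": record["uid"], "model_output": "..."}


records = [{"uid": i} for i in range(100)]

with ThreadPoolExecutor(max_workers=8) as executor:
    futures = [executor.submit(generate, r) for r in records]
    for future in tqdm(as_completed(futures), total=len(futures), desc="Inference Progress: "):
        result = future.result()
        if result:
            pass  # in the component, results go through self._append_threadsafe(result)
```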