1
- import asyncio
2
1
import logging
3
2
import os
4
3
import time
def run(self):
    """Run inference over the data set.

    Chooses the parallel path when more than one concurrent request is
    allowed (``self.max_concurrent > 1``); otherwise runs sequentially.
    """
    # guard clause: sequential mode is the simple case
    if self.max_concurrent <= 1:
        self._run()
        return
    self._run_par()
@@ -205,23 +204,7 @@ def _run(self):
205
204
data .update (response_dict )
206
205
writer .write (data )
207
206
208
- from functools import partial
209
-
210
- async def run_in_excutor (self , model_inputs , executor ):
211
- """Run model.generate in a ThreadPoolExecutor.
212
- args:
213
- model_inputs (tuple): args and kwargs to be passed to the model.generate function.
214
- executor (ThreadPoolExecutor): ThreadPoolExecutor instance.
215
- """
216
- loop = asyncio .get_event_loop ()
217
-
218
- # function to run in executor with args and kwargs
219
- def sub_func (model_inputs ):
220
- return self .model .generate (* model_inputs [0 ], ** model_inputs [1 ])
221
-
222
- return await loop .run_in_executor (executor , sub_func , model_inputs )
223
-
224
- async def _run_par (self ):
207
+ def _run_par (self ):
225
208
"""parallel inference"""
226
209
concurrent_inputs = []
227
210
concurrent_metadata = []
@@ -240,30 +223,33 @@ async def _run_par(self):
240
223
241
224
# if batch is ready for concurrent inference
242
225
elif len (concurrent_inputs ) >= self .max_concurrent :
243
- with ThreadPoolExecutor () as executor :
244
- await self .run_batch (concurrent_inputs , concurrent_metadata , writer , executor )
226
+ with ThreadPoolExecutor (max_workers = self . max_concurrent ) as executor :
227
+ self .run_batch (concurrent_inputs , concurrent_metadata , writer , executor )
245
228
concurrent_inputs = []
246
229
concurrent_metadata = []
247
230
# add data to batch for concurrent inference
248
231
concurrent_inputs .append ((model_args , model_kwargs ))
249
232
concurrent_metadata .append (data )
250
233
# if data loader is exhausted but there are remaining data points that did not form a full batch
251
234
if concurrent_inputs :
252
- with ThreadPoolExecutor () as executor :
253
- await self .run_batch (concurrent_inputs , concurrent_metadata , writer , executor )
235
+ with ThreadPoolExecutor (max_workers = self . max_concurrent ) as executor :
236
+ self .run_batch (concurrent_inputs , concurrent_metadata , writer , executor )
254
237
255
def run_batch(self, concurrent_inputs, concurrent_metadata, writer, executor):
    """Run a batch of inferences concurrently using ThreadPoolExecutor.

    args:
        concurrent_inputs (list): list of inputs to the model.generate function.
        concurrent_metadata (list): list of metadata corresponding to the inputs.
        writer (JsonLinesWriter): JsonLinesWriter instance to write the results.
        executor (ThreadPoolExecutor): ThreadPoolExecutor instance.
    """

    def sub_func(model_inputs):
        # model_inputs is an (args, kwargs) pair destined for model.generate
        return self.model.generate(*model_inputs[0], **model_inputs[1])

    # executor.map yields results in submission order, so pairing the
    # results with concurrent_metadata via zip keeps them aligned
    results = executor.map(sub_func, concurrent_inputs)
    for data, response_dict in zip(concurrent_metadata, results):
        self.validate_response_dict(response_dict)
        # prepare results for writing
        data.update(response_dict)
0 commit comments