
Commit 9bfdcd3

safooray and Safoora Yousefi authored
Addressing concurrency issues (#105)

1. Passes the max_workers argument to ThreadPoolExecutor to enable a large number of threads.
2. Removes asyncio usage and relies solely on ThreadPoolExecutor, for simplicity.
3. Passes a separate instance of Model to each thread for data safety, even though extensive testing did not surface any data corruption issues.

---------

Co-authored-by: Safoora Yousefi <[email protected]>
1 parent 148cc01 commit 9bfdcd3
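The diff below boils down to two mechanical changes: the asyncio event-loop plumbing is removed, and the ThreadPoolExecutor is capped at max_concurrent workers, with each batch dispatched through executor.map. A minimal standalone sketch of that pattern, where the Model class and its generate() signature are hypothetical stand-ins rather than the repository's actual API:

from concurrent.futures import ThreadPoolExecutor

# Hypothetical stand-in for the repository's Model class.
class Model:
    def generate(self, prompt, temperature=0.0):
        return {"model_output": f"echo: {prompt}", "temperature": temperature}

def run_batch(model, concurrent_inputs, max_concurrent):
    """Run one batch of generate() calls on a bounded thread pool."""

    def sub_func(model_inputs):
        # Each input is an (args, kwargs) pair for generate(), as in the diff.
        args, kwargs = model_inputs
        return model.generate(*args, **kwargs)

    # Passing max_workers ties the pool size to the batch size; without it,
    # the executor picks a default size unrelated to max_concurrent.
    with ThreadPoolExecutor(max_workers=max_concurrent) as executor:
        # map yields results in input order, so they can be paired back
        # with the metadata list by index.
        return list(executor.map(sub_func, concurrent_inputs))

if __name__ == "__main__":
    inputs = [((f"prompt {i}",), {"temperature": 0.2}) for i in range(8)]
    for response in run_batch(Model(), inputs, max_concurrent=4):
        print(response)

The order-preserving behavior of executor.map is what lets run_batch in the diff pair each result with its metadata entry by index.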

File tree

2 files changed (+243, -182 lines)


eureka_ml_insights/core/inference.py (+14, -28)
@@ -1,4 +1,3 @@
-import asyncio
 import logging
 import os
 import time
@@ -170,7 +169,7 @@ def retrieve_exisiting_result(self, data, pre_inf_results_df):
 
     def run(self):
         if self.max_concurrent > 1:
-            asyncio.run(self._run_par())
+            self._run_par()
         else:
             self._run()
 
@@ -205,23 +204,7 @@ def _run(self):
                 data.update(response_dict)
                 writer.write(data)
 
-    from functools import partial
-
-    async def run_in_excutor(self, model_inputs, executor):
-        """Run model.generate in a ThreadPoolExecutor.
-        args:
-            model_inputs (tuple): args and kwargs to be passed to the model.generate function.
-            executor (ThreadPoolExecutor): ThreadPoolExecutor instance.
-        """
-        loop = asyncio.get_event_loop()
-
-        # function to run in executor with args and kwargs
-        def sub_func(model_inputs):
-            return self.model.generate(*model_inputs[0], **model_inputs[1])
-
-        return await loop.run_in_executor(executor, sub_func, model_inputs)
-
-    async def _run_par(self):
+    def _run_par(self):
         """parallel inference"""
         concurrent_inputs = []
         concurrent_metadata = []
@@ -240,30 +223,33 @@ async def _run_par(self):
 
                 # if batch is ready for concurrent inference
                 elif len(concurrent_inputs) >= self.max_concurrent:
-                    with ThreadPoolExecutor() as executor:
-                        await self.run_batch(concurrent_inputs, concurrent_metadata, writer, executor)
+                    with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
+                        self.run_batch(concurrent_inputs, concurrent_metadata, writer, executor)
                     concurrent_inputs = []
                     concurrent_metadata = []
                 # add data to batch for concurrent inference
                 concurrent_inputs.append((model_args, model_kwargs))
                 concurrent_metadata.append(data)
             # if data loader is exhausted but there are remaining data points that did not form a full batch
             if concurrent_inputs:
-                with ThreadPoolExecutor() as executor:
-                    await self.run_batch(concurrent_inputs, concurrent_metadata, writer, executor)
+                with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
+                    self.run_batch(concurrent_inputs, concurrent_metadata, writer, executor)
 
-    async def run_batch(self, concurrent_inputs, concurrent_metadata, writer, executor):
+    def run_batch(self, concurrent_inputs, concurrent_metadata, writer, executor):
         """Run a batch of inferences concurrently using ThreadPoolExecutor.
         args:
             concurrent_inputs (list): list of inputs to the model.generate function.
             concurrent_metadata (list): list of metadata corresponding to the inputs.
            writer (JsonLinesWriter): JsonLinesWriter instance to write the results.
            executor (ThreadPoolExecutor): ThreadPoolExecutor instance.
        """
-        tasks = [asyncio.create_task(self.run_in_excutor(input_data, executor)) for input_data in concurrent_inputs]
-        results = await asyncio.gather(*tasks)
-        for i in range(len(concurrent_inputs)):
-            data, response_dict = concurrent_metadata[i], results[i]
+
+        def sub_func(model_inputs):
+            return self.model.generate(*model_inputs[0], **model_inputs[1])
+
+        results = executor.map(sub_func, concurrent_inputs)
+        for i, result in enumerate(results):
+            data, response_dict = concurrent_metadata[i], result
             self.validate_response_dict(response_dict)
             # prepare results for writing
             data.update(response_dict)
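Note that this diff covers only inference.py; the commit message's third point, giving each thread its own Model instance, lands in the second changed file, which is not shown on this page. A sketch of that idea under an assumed mechanism (copy.deepcopy per task; the actual code may construct its per-thread instances differently):

import copy
from concurrent.futures import ThreadPoolExecutor

# Hypothetical model with mutable per-call state, to motivate per-thread copies.
class Model:
    def __init__(self):
        self.last_prompt = None  # shared mutable state if one instance serves all threads

    def generate(self, prompt):
        self.last_prompt = prompt
        return {"model_output": prompt.upper()}

def run_batch_with_copies(model, prompts, max_concurrent):
    """Give each task its own deep copy of the model so threads share no state."""

    def sub_func(task):
        model_copy, prompt = task
        return model_copy.generate(prompt)

    # One copy per input costs more memory than sharing a single instance,
    # but it rules out races on the model's internal attributes.
    tasks = [(copy.deepcopy(model), p) for p in prompts]
    with ThreadPoolExecutor(max_workers=max_concurrent) as executor:
        return list(executor.map(sub_func, tasks))

if __name__ == "__main__":
    print(run_batch_with_copies(Model(), ["a", "b", "c"], max_concurrent=3))

As the commit message notes, extensive testing produced no corruption even with a shared instance, so the per-thread copy is a defensive measure rather than a fix for an observed bug.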