Commit 04cc86c

Authored by michaelharrisonmai, Safoora Yousefi, and Michael Harrison
Remove inference batching (#111)
Removed batching, swapped some logic of the dataloader / ThreadPoolExecutor / writer in inference.py, and refactored the _run methods.

Current flow: when run is called, load partial results if appropriate, then kick off the ThreadPoolExecutor, which calls "_run_single" for each element of the dataloader and then appends the results to the file (note that appending vs. writing is new -- let me know if this is not preferred for any reason). _run_single is a combination of the previous _run and _run_par: it checks for previous results, checks for rate limiting, and finally calls the model's generate().

---------

Co-authored-by: Safoora Yousefi <[email protected]>
Co-authored-by: Michael Harrison <[email protected]>
1 parent bd1a02b commit 04cc86c
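For context, here is the flow described above reduced to a standalone, minimal sketch (the function names, record shape, and output path are illustrative placeholders, not the project's API; resume and rate-limiting logic are omitted): one task is submitted per dataloader record, and each finished result is appended to a JSONL file under a lock so concurrent futures cannot interleave writes.

import json
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed

write_lock = threading.Lock()

def append_threadsafe(path, record):
    # Serialize appends so results from concurrent futures don't interleave.
    with write_lock:
        with open(path, "a", encoding="utf-8") as f:
            f.write(json.dumps(record) + "\n")

def run_single(record):
    # Placeholder for the real per-record work (the model's generate() call).
    return {"uid": record["uid"], "response": "..."}

def run(records, out_path, max_concurrent=4):
    with ThreadPoolExecutor(max_workers=max_concurrent) as executor:
        futures = [executor.submit(run_single, r) for r in records]
        for future in as_completed(futures):
            result = future.result()
            if result:
                append_threadsafe(out_path, result)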

File tree

2 files changed, +55 −91 lines


eureka_ml_insights/core/inference.py (+52 −89)
@@ -1,12 +1,16 @@
 import logging
 import os
+import threading
 import time
 from collections import deque
 from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import as_completed
 
 from tqdm import tqdm
 
+from eureka_ml_insights.configs.config import DataSetConfig, ModelConfig
 from eureka_ml_insights.data_utils.data import DataReader, JsonLinesWriter
+from eureka_ml_insights.models.models import Model
 
 from .pipeline import Component
 from .reserved_names import INFERENCE_RESERVED_NAMES
@@ -17,8 +21,8 @@
 class Inference(Component):
     def __init__(
         self,
-        model_config,
-        data_config,
+        model_config: ModelConfig,
+        data_config: DataSetConfig,
         output_dir,
         resume_from=None,
         new_columns=None,
@@ -39,14 +43,16 @@ def __init__(
             chat_mode (bool): optional. If True, the model will be used in chat mode, where a history of messages will be maintained in "previous_messages" column.
         """
         super().__init__(output_dir)
-        self.model = model_config.class_name(**model_config.init_args)
+        self.model: Model = model_config.class_name(**model_config.init_args)
         self.data_loader = data_config.class_name(**data_config.init_args)
-        self.writer = JsonLinesWriter(os.path.join(output_dir, "inference_result.jsonl"))
+        self.appender = JsonLinesWriter(os.path.join(output_dir, "inference_result.jsonl"), mode="a")
 
         self.resume_from = resume_from
         if resume_from and not os.path.exists(resume_from):
             raise FileNotFoundError(f"File {resume_from} not found.")
         self.new_columns = new_columns
+        self.pre_inf_results_df = None
+        self.last_uid = None
 
         # rate limiting parameters
         self.requests_per_minute = requests_per_minute
@@ -57,6 +63,8 @@ def __init__(
         self.max_concurrent = max_concurrent
         self.chat_mode = chat_mode
         self.model.chat_mode = self.chat_mode
+        self.output_dir = output_dir
+        self.writer_lock = threading.Lock()
 
     @classmethod
     def from_config(cls, config):
@@ -168,89 +176,44 @@ def retrieve_exisiting_result(self, data, pre_inf_results_df):
         return data
 
     def run(self):
-        if self.max_concurrent > 1:
-            self._run_par()
-        else:
-            self._run()
-
-    def _run(self):
-        """sequential inference"""
         if self.resume_from:
-            pre_inf_results_df, last_uid = self.fetch_previous_inference_results()
-        with self.data_loader as loader:
-            with self.writer as writer:
-                for data, model_args, model_kwargs in tqdm(loader, desc="Inference Progress:"):
-                    if self.chat_mode and data.get("is_valid", True) is False:
-                        continue
-                    if self.resume_from and (data["uid"] <= last_uid):
-                        prev_result = self.retrieve_exisiting_result(data, pre_inf_results_df)
-                        if prev_result:
-                            writer.write(prev_result)
-                            continue
-
-                    # generate text from model (optionally at a limited rate)
-                    if self.requests_per_minute:
-                        while len(self.request_times) >= self.requests_per_minute:
-                            # remove the oldest request time if it is older than the rate limit period
-                            if time.time() - self.request_times[0] > self.period:
-                                self.request_times.popleft()
-                            else:
-                                # rate limit is reached, wait for a second
-                                time.sleep(1)
-                        self.request_times.append(time.time())
-                    response_dict = self.model.generate(*model_args, **model_kwargs)
-                    self.validate_response_dict(response_dict)
-                    # write results
-                    data.update(response_dict)
-                    writer.write(data)
-
-    def _run_par(self):
-        """parallel inference"""
-        concurrent_inputs = []
-        concurrent_metadata = []
-        if self.resume_from:
-            pre_inf_results_df, last_uid = self.fetch_previous_inference_results()
-        with self.data_loader as loader:
-            with self.writer as writer:
-                for data, model_args, model_kwargs in tqdm(loader, desc="Inference Progress:"):
-                    if self.chat_mode and data.get("is_valid", True) is False:
-                        continue
-                    if self.resume_from and (data["uid"] <= last_uid):
-                        prev_result = self.retrieve_exisiting_result(data, pre_inf_results_df)
-                        if prev_result:
-                            writer.write(prev_result)
-                            continue
-
-                    # if batch is ready for concurrent inference
-                    elif len(concurrent_inputs) >= self.max_concurrent:
-                        with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
-                            self.run_batch(concurrent_inputs, concurrent_metadata, writer, executor)
-                        concurrent_inputs = []
-                        concurrent_metadata = []
-                    # add data to batch for concurrent inference
-                    concurrent_inputs.append((model_args, model_kwargs))
-                    concurrent_metadata.append(data)
-                # if data loader is exhausted but there are remaining data points that did not form a full batch
-                if concurrent_inputs:
-                    with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
-                        self.run_batch(concurrent_inputs, concurrent_metadata, writer, executor)
-
-    def run_batch(self, concurrent_inputs, concurrent_metadata, writer, executor):
-        """Run a batch of inferences concurrently using ThreadPoolExecutor.
-        args:
-            concurrent_inputs (list): list of inputs to the model.generate function.
-            concurrent_metadata (list): list of metadata corresponding to the inputs.
-            writer (JsonLinesWriter): JsonLinesWriter instance to write the results.
-            executor (ThreadPoolExecutor): ThreadPoolExecutor instance.
-        """
-
-        def sub_func(model_inputs):
-            return self.model.generate(*model_inputs[0], **model_inputs[1])
-
-        results = executor.map(sub_func, concurrent_inputs)
-        for i, result in enumerate(results):
-            data, response_dict = concurrent_metadata[i], result
-            self.validate_response_dict(response_dict)
-            # prepare results for writing
-            data.update(response_dict)
-            writer.write(data)
+            self.pre_inf_results_df, self.last_uid = self.fetch_previous_inference_results()
+        with self.data_loader as loader, ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
+            futures = [executor.submit(self._run_single, record) for record in loader]
+            for future in tqdm(as_completed(futures), total=len(loader), mininterval=2.0, desc="Inference Progress: "):
+                result = future.result()
+                if result:
+                    self._append_threadsafe(result)
+
+    def _append_threadsafe(self, data):
+        with self.writer_lock:
+            with self.appender as appender:
+                appender.write(data)
+
+    def _run_single(self, record: tuple[dict, tuple, dict]):
+        """Runs model.generate() with respect to a single element of the dataloader."""
+
+        data, model_args, model_kwargs = record
+        if self.chat_mode and data.get("is_valid", True) is False:
+            return None
+        if self.resume_from and (data["uid"] <= self.last_uid):
+            prev_result = self.retrieve_exisiting_result(data, self.pre_inf_results_df)
+            if prev_result:
+                return prev_result
+
+        # Rate limiter -- only for sequential inference
+        if self.requests_per_minute and self.max_concurrent == 1:
+            while len(self.request_times) >= self.requests_per_minute:
+                # remove the oldest request time if it is older than the rate limit period
+                if time.time() - self.request_times[0] > self.period:
+                    self.request_times.popleft()
+                else:
+                    # rate limit is reached, wait for a second
+                    time.sleep(1)
+            self.request_times.append(time.time())
+
+        response_dict = self.model.generate(*model_args, **model_kwargs)
+        self.validate_response_dict(response_dict)
+        data.update(response_dict)
+        return data
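The sliding-window rate limiter kept in _run_single above is carried over from the old _run and now only applies when max_concurrent == 1. As a standalone sketch of the same technique (the names, the limit of 10, and the 60-second window are illustrative, mirroring the deque-based logic in the diff):

import time
from collections import deque

request_times = deque()
requests_per_minute = 10
period = 60  # rate-limit window in seconds

def wait_for_slot():
    # Block until issuing one more request stays under requests_per_minute.
    while len(request_times) >= requests_per_minute:
        if time.time() - request_times[0] > period:
            # The oldest request has aged out of the window; drop it.
            request_times.popleft()
        else:
            # The window is still full; wait a second and re-check.
            time.sleep(1)
    request_times.append(time.time())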

eureka_ml_insights/data_utils/data.py (+3 −2)
@@ -282,16 +282,17 @@ def load_image(self, image_file_name):
 
 
 class JsonLinesWriter:
-    def __init__(self, out_path):
+    def __init__(self, out_path, mode="w"):
         self.out_path = out_path
         # if the directory does not exist, create it
         directory = os.path.dirname(out_path)
         if not os.path.exists(directory):
             os.makedirs(directory)
         self.writer = None
+        self.mode = mode
 
     def __enter__(self):
-        self.writer = jsonlines.open(self.out_path, mode="w", dumps=NumpyEncoder().encode)
+        self.writer = jsonlines.open(self.out_path, mode=self.mode, dumps=NumpyEncoder().encode)
         return self.writer
 
     def __exit__(self, exc_type, exc_value, traceback):
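A minimal usage sketch of the new mode parameter (the output path here is illustrative): with mode="a", re-entering the writer appends to the existing file rather than truncating it, which is what the Inference component relies on when it reopens the output file once per result.

writer = JsonLinesWriter("outputs/inference_result.jsonl", mode="a")

with writer as w:
    w.write({"uid": 1, "response": "first"})

# Entering the same writer again later adds lines instead of overwriting the file.
with writer as w:
    w.write({"uid": 2, "response": "second"})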
