Remove inference batching #111

Merged: 23 commits, Mar 17, 2025
Commits
bc915db  readme updates (Dec 11, 2024)
e290211  fix links (Dec 12, 2024)
72b9b9e  Merge branch 'main' of https://github.com/microsoft/eureka-ml-insights (Jan 8, 2025)
3a20480  Merge branch 'main' of https://github.com/microsoft/eureka-ml-insights (Jan 17, 2025)
f58154e  Merge branch 'main' of https://github.com/microsoft/eureka-ml-insights (Jan 22, 2025)
9b9c5c0  Merge branch 'main' of https://github.com/microsoft/eureka-ml-insights (Jan 24, 2025)
ce1b2fe  Merge branch 'main' of https://github.com/microsoft/eureka-ml-insights (Jan 29, 2025)
b2a8376  Merge branch 'main' of https://github.com/microsoft/eureka-ml-insights (Jan 29, 2025)
645eefa  Merge branch 'main' of https://github.com/microsoft/eureka-ml-insights (Jan 31, 2025)
89daef5  Merge branch 'main' of https://github.com/microsoft/eureka-ml-insights (Feb 11, 2025)
d9988df  Merge branch 'main' of https://github.com/microsoft/eureka-ml-insights (Feb 27, 2025)
12dffdf  Merge branch 'main' of https://github.com/microsoft/eureka-ml-insights (Mar 1, 2025)
173538a  resolve concurrency issues (Mar 11, 2025)
d77dfad  formatting (Mar 12, 2025)
382f452  thread safety (Mar 13, 2025)
8fa0931  Merge branch 'main' of https://github.com/microsoft/eureka-ml-insights (Mar 14, 2025)
03c848c  merge with main (Mar 14, 2025)
5c790b1  bug fixes (Mar 15, 2025)
2c45d8b  revert to single model uinstance (Mar 15, 2025)
1751905  llava model thread safety (Mar 15, 2025)
5d391b7  refactored inference.py to not batch requests (Mar 15, 2025)
e397fe3  Merge branch 'main' into mharrison/remove-inference-batching (michaelharrisonmai, Mar 15, 2025)
114e446  remove unused vars (Mar 17, 2025)
141 changes: 52 additions & 89 deletions eureka_ml_insights/core/inference.py
@@ -1,12 +1,16 @@
import logging
import os
import threading
import time
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import as_completed

from tqdm import tqdm

from eureka_ml_insights.configs.config import DataSetConfig, ModelConfig
from eureka_ml_insights.data_utils.data import DataReader, JsonLinesWriter
from eureka_ml_insights.models.models import Model

from .pipeline import Component
from .reserved_names import INFERENCE_RESERVED_NAMES
@@ -17,8 +21,8 @@
class Inference(Component):
def __init__(
self,
model_config,
data_config,
model_config: ModelConfig,
data_config: DataSetConfig,
output_dir,
resume_from=None,
new_columns=None,
@@ -39,14 +43,16 @@ def __init__(
chat_mode (bool): optional. If True, the model will be used in chat mode, where a history of messages will be maintained in the "previous_messages" column.
"""
super().__init__(output_dir)
self.model = model_config.class_name(**model_config.init_args)
self.model: Model = model_config.class_name(**model_config.init_args)
self.data_loader = data_config.class_name(**data_config.init_args)
self.writer = JsonLinesWriter(os.path.join(output_dir, "inference_result.jsonl"))
self.appender = JsonLinesWriter(os.path.join(output_dir, "inference_result.jsonl"), mode="a")

self.resume_from = resume_from
if resume_from and not os.path.exists(resume_from):
raise FileNotFoundError(f"File {resume_from} not found.")
self.new_columns = new_columns
self.pre_inf_results_df = None
self.last_uid = None

# rate limiting parameters
self.requests_per_minute = requests_per_minute
@@ -57,6 +63,8 @@
self.max_concurrent = max_concurrent
self.chat_mode = chat_mode
self.model.chat_mode = self.chat_mode
self.output_dir = output_dir
self.writer_lock = threading.Lock()

@classmethod
def from_config(cls, config):
@@ -168,89 +176,44 @@ def retrieve_exisiting_result(self, data, pre_inf_results_df):
return data

def run(self):
if self.max_concurrent > 1:
self._run_par()
else:
self._run()

def _run(self):
"""sequential inference"""
if self.resume_from:
pre_inf_results_df, last_uid = self.fetch_previous_inference_results()
with self.data_loader as loader:
with self.writer as writer:
for data, model_args, model_kwargs in tqdm(loader, desc="Inference Progress:"):
if self.chat_mode and data.get("is_valid", True) is False:
continue
if self.resume_from and (data["uid"] <= last_uid):
prev_result = self.retrieve_exisiting_result(data, pre_inf_results_df)
if prev_result:
writer.write(prev_result)
continue

# generate text from model (optionally at a limited rate)
if self.requests_per_minute:
while len(self.request_times) >= self.requests_per_minute:
# remove the oldest request time if it is older than the rate limit period
if time.time() - self.request_times[0] > self.period:
self.request_times.popleft()
else:
# rate limit is reached, wait for a second
time.sleep(1)
self.request_times.append(time.time())
response_dict = self.model.generate(*model_args, **model_kwargs)
self.validate_response_dict(response_dict)
# write results
data.update(response_dict)
writer.write(data)

def _run_par(self):
"""parallel inference"""
concurrent_inputs = []
concurrent_metadata = []
if self.resume_from:
pre_inf_results_df, last_uid = self.fetch_previous_inference_results()
with self.data_loader as loader:
with self.writer as writer:
for data, model_args, model_kwargs in tqdm(loader, desc="Inference Progress:"):
if self.chat_mode and data.get("is_valid", True) is False:
continue
if self.resume_from and (data["uid"] <= last_uid):
prev_result = self.retrieve_exisiting_result(data, pre_inf_results_df)
if prev_result:
writer.write(prev_result)
continue

# if batch is ready for concurrent inference
elif len(concurrent_inputs) >= self.max_concurrent:
with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
self.run_batch(concurrent_inputs, concurrent_metadata, writer, executor)
concurrent_inputs = []
concurrent_metadata = []
# add data to batch for concurrent inference
concurrent_inputs.append((model_args, model_kwargs))
concurrent_metadata.append(data)
# if data loader is exhausted but there are remaining data points that did not form a full batch
if concurrent_inputs:
with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
self.run_batch(concurrent_inputs, concurrent_metadata, writer, executor)

def run_batch(self, concurrent_inputs, concurrent_metadata, writer, executor):
"""Run a batch of inferences concurrently using ThreadPoolExecutor.
args:
concurrent_inputs (list): list of inputs to the model.generate function.
concurrent_metadata (list): list of metadata corresponding to the inputs.
writer (JsonLinesWriter): JsonLinesWriter instance to write the results.
executor (ThreadPoolExecutor): ThreadPoolExecutor instance.
"""

def sub_func(model_inputs):
return self.model.generate(*model_inputs[0], **model_inputs[1])

results = executor.map(sub_func, concurrent_inputs)
for i, result in enumerate(results):
data, response_dict = concurrent_metadata[i], result
self.validate_response_dict(response_dict)
# prepare results for writing
data.update(response_dict)
writer.write(data)
self.pre_inf_results_df, self.last_uid = self.fetch_previous_inference_results()
with self.data_loader as loader, ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
futures = [executor.submit(self._run_single, record) for record in loader]
for future in tqdm(as_completed(futures), total=len(loader), mininterval=2.0, desc="Inference Progress: "):
result = future.result()
if result:
self._append_threadsafe(result)

def _append_threadsafe(self, data):
with self.writer_lock:
with self.appender as appender:
appender.write(data)

def _run_single(self, record: tuple[dict, tuple, dict]):
"""Runs model.generate() with respect to a single element of the dataloader."""

data, model_args, model_kwargs = record
if self.chat_mode and data.get("is_valid", True) is False:
return None
if self.resume_from and (data["uid"] <= self.last_uid):
prev_result = self.retrieve_exisiting_result(data, self.pre_inf_results_df)
if prev_result:
return prev_result

# Rate limiter -- only for sequential inference
if self.requests_per_minute and self.max_concurrent == 1:
while len(self.request_times) >= self.requests_per_minute:
# remove the oldest request time if it is older than the rate limit period
if time.time() - self.request_times[0] > self.period:
self.request_times.popleft()
else:
# rate limit is reached, wait for a second
time.sleep(1)
self.request_times.append(time.time())

response_dict = self.model.generate(*model_args, **model_kwargs)
self.validate_response_dict(response_dict)
data.update(response_dict)
return data
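
For context, a minimal, self-contained sketch of the per-record concurrency pattern the refactored run() adopts: each record is submitted individually to a ThreadPoolExecutor, results are gathered with as_completed, and each finished result is appended to a JSONL file under a lock. The records list and fake_generate below are hypothetical stand-ins for the data loader and model.generate; only the executor/lock structure mirrors the code above.

# Hypothetical stand-ins; only the executor/lock structure mirrors Inference.run,
# _run_single, and _append_threadsafe.
import json
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed

writer_lock = threading.Lock()

def fake_generate(prompt: str) -> dict:
    # Stand-in for self.model.generate(*model_args, **model_kwargs).
    return {"model_output": prompt.upper(), "is_valid": True}

def run_single(record: dict) -> dict:
    # One unit of work per data point; no batching.
    response = fake_generate(record["prompt"])
    record.update(response)
    return record

def append_threadsafe(path: str, data: dict) -> None:
    # Serialize writes so concurrent workers never interleave lines.
    with writer_lock:
        with open(path, "a", encoding="utf-8") as f:
            f.write(json.dumps(data) + "\n")

records = [{"uid": i, "prompt": f"question {i}"} for i in range(8)]
with ThreadPoolExecutor(max_workers=4) as executor:
    futures = [executor.submit(run_single, r) for r in records]
    for future in as_completed(futures):
        result = future.result()
        if result:
            append_threadsafe("inference_result.jsonl", result)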

5 changes: 3 additions & 2 deletions eureka_ml_insights/data_utils/data.py
@@ -282,16 +282,17 @@ def load_image(self, image_file_name):


class JsonLinesWriter:
def __init__(self, out_path):
def __init__(self, out_path, mode="w"):
self.out_path = out_path
# if the directory does not exist, create it
directory = os.path.dirname(out_path)
if not os.path.exists(directory):
os.makedirs(directory)
self.writer = None
self.mode = mode

def __enter__(self):
self.writer = jsonlines.open(self.out_path, mode="w", dumps=NumpyEncoder().encode)
self.writer = jsonlines.open(self.out_path, mode=self.mode, dumps=NumpyEncoder().encode)
return self.writer

def __exit__(self, exc_type, exc_value, traceback):
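
For illustration, a minimal usage sketch of the new append mode (the output path and payload below are hypothetical): reopening the same file with mode="a" lets each completed result be appended as it finishes, rather than rewriting the file once per batch.

# Hypothetical usage of JsonLinesWriter's new append mode; path and payload are illustrative.
from eureka_ml_insights.data_utils.data import JsonLinesWriter

appender = JsonLinesWriter("output/inference_result.jsonl", mode="a")
with appender as writer:
    # __enter__ opens the file with jsonlines in the configured mode and returns the writer.
    writer.write({"uid": 0, "model_output": "example response"})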