diff --git a/eureka_ml_insights/core/inference.py b/eureka_ml_insights/core/inference.py
index 1a3bf19..d99e92b 100644
--- a/eureka_ml_insights/core/inference.py
+++ b/eureka_ml_insights/core/inference.py
@@ -1,12 +1,16 @@
 import logging
 import os
+import threading
 import time
 from collections import deque
 from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import as_completed
 
 from tqdm import tqdm
 
+from eureka_ml_insights.configs.config import DataSetConfig, ModelConfig
 from eureka_ml_insights.data_utils.data import DataReader, JsonLinesWriter
+from eureka_ml_insights.models.models import Model
 
 from .pipeline import Component
 from .reserved_names import INFERENCE_RESERVED_NAMES
@@ -17,8 +21,8 @@ class Inference(Component):
 
     def __init__(
         self,
-        model_config,
-        data_config,
+        model_config: ModelConfig,
+        data_config: DataSetConfig,
         output_dir,
         resume_from=None,
         new_columns=None,
@@ -39,14 +43,16 @@ def __init__(
             chat_mode (bool): optional. If True, the model will be used in chat mode, where a history of messages
                 will be maintained in "previous_messages" column.
         """
         super().__init__(output_dir)
-        self.model = model_config.class_name(**model_config.init_args)
+        self.model: Model = model_config.class_name(**model_config.init_args)
         self.data_loader = data_config.class_name(**data_config.init_args)
-        self.writer = JsonLinesWriter(os.path.join(output_dir, "inference_result.jsonl"))
+        self.appender = JsonLinesWriter(os.path.join(output_dir, "inference_result.jsonl"), mode="a")
         self.resume_from = resume_from
         if resume_from and not os.path.exists(resume_from):
             raise FileNotFoundError(f"File {resume_from} not found.")
         self.new_columns = new_columns
+        self.pre_inf_results_df = None
+        self.last_uid = None
 
         # rate limiting parameters
         self.requests_per_minute = requests_per_minute
@@ -57,6 +63,8 @@ def __init__(
         self.max_concurrent = max_concurrent
         self.chat_mode = chat_mode
         self.model.chat_mode = self.chat_mode
+        self.output_dir = output_dir
+        self.writer_lock = threading.Lock()
 
     @classmethod
     def from_config(cls, config):
@@ -168,89 +176,44 @@ def retrieve_exisiting_result(self, data, pre_inf_results_df):
         return data
 
     def run(self):
-        if self.max_concurrent > 1:
-            self._run_par()
-        else:
-            self._run()
-
-    def _run(self):
-        """sequential inference"""
         if self.resume_from:
-            pre_inf_results_df, last_uid = self.fetch_previous_inference_results()
-        with self.data_loader as loader:
-            with self.writer as writer:
-                for data, model_args, model_kwargs in tqdm(loader, desc="Inference Progress:"):
-                    if self.chat_mode and data.get("is_valid", True) is False:
-                        continue
-                    if self.resume_from and (data["uid"] <= last_uid):
-                        prev_result = self.retrieve_exisiting_result(data, pre_inf_results_df)
-                        if prev_result:
-                            writer.write(prev_result)
-                        continue
-
-                    # generate text from model (optionally at a limited rate)
-                    if self.requests_per_minute:
-                        while len(self.request_times) >= self.requests_per_minute:
-                            # remove the oldest request time if it is older than the rate limit period
-                            if time.time() - self.request_times[0] > self.period:
-                                self.request_times.popleft()
-                            else:
-                                # rate limit is reached, wait for a second
-                                time.sleep(1)
-                        self.request_times.append(time.time())
-                    response_dict = self.model.generate(*model_args, **model_kwargs)
-                    self.validate_response_dict(response_dict)
-                    # write results
-                    data.update(response_dict)
-                    writer.write(data)
-
-    def _run_par(self):
-        """parallel inference"""
-        concurrent_inputs = []
-        concurrent_metadata = []
-        if self.resume_from:
-            pre_inf_results_df, last_uid = self.fetch_previous_inference_results()
-        with self.data_loader as loader:
-            with self.writer as writer:
-                for data, model_args, model_kwargs in tqdm(loader, desc="Inference Progress:"):
-                    if self.chat_mode and data.get("is_valid", True) is False:
-                        continue
-                    if self.resume_from and (data["uid"] <= last_uid):
-                        prev_result = self.retrieve_exisiting_result(data, pre_inf_results_df)
-                        if prev_result:
-                            writer.write(prev_result)
-                        continue
-
-                    # if batch is ready for concurrent inference
-                    elif len(concurrent_inputs) >= self.max_concurrent:
-                        with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
-                            self.run_batch(concurrent_inputs, concurrent_metadata, writer, executor)
-                        concurrent_inputs = []
-                        concurrent_metadata = []
-                    # add data to batch for concurrent inference
-                    concurrent_inputs.append((model_args, model_kwargs))
-                    concurrent_metadata.append(data)
-                # if data loader is exhausted but there are remaining data points that did not form a full batch
-                if concurrent_inputs:
-                    with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
-                        self.run_batch(concurrent_inputs, concurrent_metadata, writer, executor)
-
-    def run_batch(self, concurrent_inputs, concurrent_metadata, writer, executor):
-        """Run a batch of inferences concurrently using ThreadPoolExecutor.
-        args:
-            concurrent_inputs (list): list of inputs to the model.generate function.
-            concurrent_metadata (list): list of metadata corresponding to the inputs.
-            writer (JsonLinesWriter): JsonLinesWriter instance to write the results.
-            executor (ThreadPoolExecutor): ThreadPoolExecutor instance.
-        """
-
-        def sub_func(model_inputs):
-            return self.model.generate(*model_inputs[0], **model_inputs[1])
-
-        results = executor.map(sub_func, concurrent_inputs)
-        for i, result in enumerate(results):
-            data, response_dict = concurrent_metadata[i], result
-            self.validate_response_dict(response_dict)
-            # prepare results for writing
-            data.update(response_dict)
-            writer.write(data)
+            self.pre_inf_results_df, self.last_uid = self.fetch_previous_inference_results()
+        with self.data_loader as loader, ThreadPoolExecutor(max_workers=self.max_concurrent) as executor:
+            futures = [executor.submit(self._run_single, record) for record in loader]
+            for future in tqdm(as_completed(futures), total=len(loader), mininterval=2.0, desc="Inference Progress: "):
+                result = future.result()
+                if result:
+                    self._append_threadsafe(result)
+
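+    # NOTE: each append reopens the output file in "a" mode, so every result
+    # is persisted as soon as it is written; the lock serializes writers
+    # because several futures may complete at the same time.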
+    def _append_threadsafe(self, data):
+        with self.writer_lock:
+            with self.appender as appender:
+                appender.write(data)
+
+    def _run_single(self, record: tuple[dict, tuple, dict]):
+        """Runs model.generate() on a single element of the data loader."""
+
+        data, model_args, model_kwargs = record
+        if self.chat_mode and data.get("is_valid", True) is False:
+            return None
+        if self.resume_from and (data["uid"] <= self.last_uid):
+            prev_result = self.retrieve_exisiting_result(data, self.pre_inf_results_df)
+            if prev_result:
+                return prev_result
+
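+        # The sliding-window limiter below keeps a deque of request
+        # timestamps and sleeps until the oldest one falls outside the
+        # period; it only applies when a single worker issues requests.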
+        # Rate limiter -- only for sequential inference
+        if self.requests_per_minute and self.max_concurrent == 1:
+            while len(self.request_times) >= self.requests_per_minute:
+                # remove the oldest request time if it is older than the rate limit period
+                if time.time() - self.request_times[0] > self.period:
+                    self.request_times.popleft()
+                else:
+                    # rate limit is reached, wait for a second
+                    time.sleep(1)
+            self.request_times.append(time.time())
+
+        response_dict = self.model.generate(*model_args, **model_kwargs)
+        self.validate_response_dict(response_dict)
+        data.update(response_dict)
+        return data
+
\ No newline at end of file
diff --git a/eureka_ml_insights/data_utils/data.py b/eureka_ml_insights/data_utils/data.py
index 3a4bd1d..b26a04d 100644
--- a/eureka_ml_insights/data_utils/data.py
+++ b/eureka_ml_insights/data_utils/data.py
@@ -282,16 +282,17 @@ def load_image(self, image_file_name):
 
 
 class JsonLinesWriter:
-    def __init__(self, out_path):
+    def __init__(self, out_path, mode="w"):
         self.out_path = out_path
         # if the directory does not exist, create it
        directory = os.path.dirname(out_path)
         if not os.path.exists(directory):
             os.makedirs(directory)
         self.writer = None
+        self.mode = mode
 
     def __enter__(self):
-        self.writer = jsonlines.open(self.out_path, mode="w", dumps=NumpyEncoder().encode)
+        self.writer = jsonlines.open(self.out_path, mode=self.mode, dumps=NumpyEncoder().encode)
         return self.writer
 
     def __exit__(self, exc_type, exc_value, traceback):
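
Usage sketch: the reason the Inference component switches from a write-mode writer to an append-mode one is that re-entering a JsonLinesWriter opened with mode="a" preserves all previously written lines, so an interrupted run still leaves a valid partial inference_result.jsonl behind. A minimal standalone illustration of the same append-and-lock pattern used in _append_threadsafe; the output path and the sample records here are hypothetical:

import threading

from eureka_ml_insights.data_utils.data import JsonLinesWriter

# Hypothetical output path; the component itself uses
# os.path.join(output_dir, "inference_result.jsonl").
appender = JsonLinesWriter("outputs/results.jsonl", mode="a")
lock = threading.Lock()

def append_threadsafe(record: dict) -> None:
    # Reopening in "a" mode keeps the lines from earlier calls, so each
    # record is durable on disk as soon as write() returns.
    with lock:
        with appender as writer:
            writer.write(record)

for record in [{"uid": 0, "model_output": "hello"}, {"uid": 1, "model_output": "world"}]:
    append_threadsafe(record)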