From 56e925c353904f58403e7abcd763765bf9c7cf00 Mon Sep 17 00:00:00 2001 From: legstar67 Date: Wed, 11 Mar 2026 04:59:02 +0100 Subject: [PATCH 01/45] Draft : implementation OpenAI support --- .../processors/llm/api_batch_client.py | 83 +++++ .../core/process/processors/llm/api_utils.py | 118 +++++++ .../processors/llm/claude_batch_client.py | 18 + .../core/process/processors/llm/config.py | 5 + .../process/processors/llm/llm_processor.py | 28 +- .../processors/llm/openai_batch_client.py | 333 ++++++++++++++++++ 6 files changed, 582 insertions(+), 3 deletions(-) create mode 100644 src/mmirage/core/process/processors/llm/api_batch_client.py create mode 100644 src/mmirage/core/process/processors/llm/api_utils.py create mode 100644 src/mmirage/core/process/processors/llm/claude_batch_client.py create mode 100644 src/mmirage/core/process/processors/llm/openai_batch_client.py diff --git a/src/mmirage/core/process/processors/llm/api_batch_client.py b/src/mmirage/core/process/processors/llm/api_batch_client.py new file mode 100644 index 0000000..5b88d09 --- /dev/null +++ b/src/mmirage/core/process/processors/llm/api_batch_client.py @@ -0,0 +1,83 @@ + + +from typing import List, Optional, Dict, Any, Type +from pydantic import BaseModel +from pathlib import Path + + +from mmirage.core.process.variables import VariableEnvironment +from abc import ABC, abstractmethod + +class APIBatchClient(ABC): + + def __init__(self, model_name: str, api_key: str, provider: str): + self.model_name = model_name + self.api_key = api_key + self.provider = provider + + @abstractmethod + def build_request( + self, + *, + prompt: str, + image_b64: str = None, + media_type: str = None, + request_id: int, + system_prompt: str = None, + output_schema: Optional[Type[BaseModel]] = None, + ) -> dict: + """ + Build a single API request object based on the provider. + + Args: + text: The input text to send to the LLM. + image_b64: Optional base64-encoded image string for multimodal models. + request_id: Unique identifier for this request. + + Returns: + A dict representing the API request payload. + """ + pass + + + @abstractmethod + def submit_batches(self, output_dir: Path) -> None: + """ + Submit batches of requests to the LLM API and save responses. + + Args: + batches_dir: Directory containing batch request files. + output_dir: Directory to save API responses. + """ + pass + + + + @abstractmethod + def process_dataset(self, + *, + nb_samples: Optional[int] = None, + ) -> None: + """ + Build batch JSONL files + + Writes one or more files: part_1.jsonl, part_2.jsonl, ... + Splits by MAX_PART_SIZE_BYTES. + """ + + + + @abstractmethod + def await_and_collect_batch_outputs(self, batches_dir: Path, output_dir: Path) -> None: + """ + Wait for API responses and collect outputs into VariableEnvironments. + + Args: + batches_dir: Directory containing batch request files. + output_dir: Directory where API responses are saved. + """ + pass + + + + diff --git a/src/mmirage/core/process/processors/llm/api_utils.py b/src/mmirage/core/process/processors/llm/api_utils.py new file mode 100644 index 0000000..5eae6c4 --- /dev/null +++ b/src/mmirage/core/process/processors/llm/api_utils.py @@ -0,0 +1,118 @@ +from pathlib import Path +import base64, json +from typing import List, Optional, Tuple + +import tqdm + + +def encode_image_to_base64(path: Path) -> str: + """ + Read an image from disk and return base64-encoded string. + """ + if not path.exists(): + raise FileNotFoundError(f"Image not found: {path}") + + with path.open("rb") as f: + return base64.b64encode(f.read()).decode("utf-8") + +def get_media_type(path: Path) -> str: + """ + Get the media type (MIME type) of a file based on its extension. + """ + ext = path.suffix.lower() + if ext in [".jpg", ".jpeg"]: + return "image/jpeg" + elif ext == ".png": + return "image/png" + elif ext == ".webp": + return "image/webp" + else: + raise ValueError(f"Unsupported file extension: {ext}") + + + +def load_data_raw(manifest_path: Path) -> List[dict]: + """ + Load the raw JSONL manifest as a list of dicts. + """ + if not manifest_path.exists(): + raise FileNotFoundError( + f"Manifest not found: {manifest_path}" + ) + + records: List[dict] = [] + with manifest_path.open("r", encoding="utf-8") as f: + for line_num, line in enumerate(f, 1): + line = line.strip() + if not line: + continue + try: + records.append(json.loads(line)) + except json.JSONDecodeError as e: + raise ValueError( + f"Invalid JSON on line {line_num} of {manifest_path}" + ) from e + + if not records: + raise RuntimeError("Manifest loaded but contains no records.") + + return records + + +def resolve_image_path(image_root: Path, value: str) -> Path: + """ + Resolve image paths safely, handling leading slashes. + """ + rel = value.lstrip("/") + return image_root / rel + + + +def load_data( + nb_samples: Optional[int] = None, + max_images_per_sample: int = 1, + ) -> Tuple[List[Tuple[str, Tuple[str, ...]]], List[str]]: + """ + Load dataset examples and encode images. + + Returns: + examples: List of (text, (img_b64, ...)) + paths: List of absolute image paths used + """ + raw_records = load_data_raw() + records = raw_records[:nb_samples] if nb_samples else raw_records + + examples: List[Tuple[str, Tuple[str, ...]]] = [] + used_paths: List[str] = [] + + for rec in tqdm(records, desc="Loading dataset"): + text = str(rec.get("text", "")).strip() + if not text: + continue + + image_paths: List[Path] = [] + for m in rec.get("modalities", []): + if m.get("type") == "image" and m.get("value"): + image_paths.append(resolve_image_path(m["value"])) + if len(image_paths) >= max_images_per_sample: + break + + if not image_paths: + continue + + try: + encoded_images = tuple( + encode_image_to_base64(p) for p in image_paths + ) + except Exception as e: + # Skip corrupted or unreadable images + print(f"[WARN] Skipping sample due to image error: {e}") + continue + + examples.append((text, encoded_images)) + used_paths.extend(str(p) for p in image_paths) + + if not examples: + raise RuntimeError("No valid examples loaded.") + + return examples, used_paths diff --git a/src/mmirage/core/process/processors/llm/claude_batch_client.py b/src/mmirage/core/process/processors/llm/claude_batch_client.py new file mode 100644 index 0000000..ebd88b6 --- /dev/null +++ b/src/mmirage/core/process/processors/llm/claude_batch_client.py @@ -0,0 +1,18 @@ +from pathlib import Path + +import anthropic + +from mmirage.core.process.processors.llm.api_batch_client import APIBatchClient + + + +class AnthropicBatchClient(APIBatchClient): + def __init__(self, model_name: str, api_key: str, output_dir: Path): + super().__init__(model_name=model_name, api_key=api_key, provider="anthropic") + + if not self.api_key: + raise SystemExit( + "ANTHROPIC_API_KEY is not set. Please export it before running." + ) + self.client = anthropic.Anthropic(api_key=self.api_key) + self.output_dir = output_dir \ No newline at end of file diff --git a/src/mmirage/core/process/processors/llm/config.py b/src/mmirage/core/process/processors/llm/config.py index 4c195af..fe99ea2 100644 --- a/src/mmirage/core/process/processors/llm/config.py +++ b/src/mmirage/core/process/processors/llm/config.py @@ -49,6 +49,11 @@ class SGLangLLMConfig(BaseProcessorConfig): default_sampling_params: Dict[str, Any] = field(default_factory=dict) chat_template: str = "" # Empty means use tokenizer's default + provider: str = "sglang" # options: "sglang", "anthropic", "openai". Used for routing to the correct LLM provider + api_model_name: str = "gpt-4o" # model name to use when provider is API-based (e.g., OpenAI, Anthropic) + api_key: str = "" # API key for API-based providers + out + @dataclass class LLMOutputVar(OutputVar): diff --git a/src/mmirage/core/process/processors/llm/llm_processor.py b/src/mmirage/core/process/processors/llm/llm_processor.py index 56afd17..91cbf2b 100644 --- a/src/mmirage/core/process/processors/llm/llm_processor.py +++ b/src/mmirage/core/process/processors/llm/llm_processor.py @@ -13,6 +13,7 @@ from mmirage.core.process.base import BaseProcessor, ProcessorRegistry from mmirage.core.process.processors.llm.config import LLMOutputVar, SGLangLLMConfig +from mmirage.core.process.processors.llm.openai_batch_client import OpenAIBatchClient from mmirage.core.process.variables import VariableEnvironment try: @@ -57,7 +58,18 @@ def __init__(self, engine_args: SGLangLLMConfig, **kwargs) -> None: engine_args: Configuration for SGLang server and sampling parameters. **kwargs: Additional arguments passed to base class. """ + + + + super().__init__(engine_args, **kwargs) + + if self.engine_args.provider == "openai": + self.llm = OpenAIBatchClient(self.engine_args.api_model_name, self.engine_args.api_key) + elif self.engine_args.provider == "anthropic": + pass + + # Default to SGLang Engine self.llm = sgl.Engine(**asdict(engine_args.server_args)) self.tokenizer = AutoTokenizer.from_pretrained( engine_args.server_args.model_path, @@ -130,7 +142,7 @@ def _get_image_token(self) -> str: return IMAGE_TOKENS.get(self.chat_template, "") @override - def batch_process_sample( + def batch_process_sample( self, batch: List[VariableEnvironment], output_var: LLMOutputVar ) -> List[VariableEnvironment]: """Process a batch of variable environments to generate LLM outputs. @@ -147,6 +159,18 @@ def batch_process_sample( RuntimeError: If output batch size doesn't match input batch size. """ nb_samples = len(batch) + results: dict[int, VariableEnvironment] = {} + + # ---- For API-based providers ---- + if self.provider in ["openai", "anthropic"]: + + prompts = self.llm.build_prompt(output_var.prompt, batch) + requests_payloads = [self.llm.build_request(text=prompts[i], request_id=i) for i in range(nb_samples)] + responses = self.llm.submit_and_wait(requests_payloads) + + + + # ---- For SGLang Engine provider ---- # Prepare sampling params sampling_params_output = self.sampling_params.copy() @@ -170,8 +194,6 @@ def batch_process_sample( else: text_only_indices.append(i) - results: dict[int, VariableEnvironment] = {} - # Text-only batch if text_only_indices: text_only_envs = [batch[i] for i in text_only_indices] diff --git a/src/mmirage/core/process/processors/llm/openai_batch_client.py b/src/mmirage/core/process/processors/llm/openai_batch_client.py new file mode 100644 index 0000000..5cc7ea5 --- /dev/null +++ b/src/mmirage/core/process/processors/llm/openai_batch_client.py @@ -0,0 +1,333 @@ +from openai import OpenAI +import json, time, re +from tqdm import tqdm +from typing import List, Tuple, Optional, Type +import json +from pathlib import Path +from pydantic import BaseModel + +from mmirage.core.process.processors.llm.api_batch_client import APIBatchClient +from mmirage.core.process.processors.llm.api_utils import get_media_type, load_data + + + +# --------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------- + +MAX_TOKENS = 1000 +# 50 MB per batch part is well below OpenAI limits and avoids failures +MAX_PART_SIZE_BYTES = int(0.05 * 1024 ** 3) + + + + +class OpenAIBatchClient(APIBatchClient): + def __init__(self, model_name: str, api_key: str, output_dir: Path): + super().__init__(model_name=model_name, api_key=api_key, provider="openai") + + if not self.api_key: + raise SystemExit( + "OPENAI_API_KEY is not set. Please export it before running." + ) + self.client = OpenAI(api_key=self.api_key) + self.output_dir = output_dir + + + # --------------------------------------------------------------------- + # Request builder + # --------------------------------------------------------------------- + + + def build_request( + self, + *, + prompt: str, + image_b64: str = None, + media_type: str = None, + request_id: int, + system_prompt: str = None, + output_schema: Optional[Type[BaseModel]] = None, + ) -> dict: + """Build a single OpenAI Batch API request object. + + Args: + prompt: The fully-rendered user prompt (Jinja2 already applied). + image_b64: Optional base64-encoded image for multimodal requests. + media_type: MIME type of the image (e.g., "image/jpeg"). + request_id: Unique identifier used as custom_id. + system_prompt: Optional system message prepended to the conversation. + output_schema: Optional Pydantic model used to enforce a JSON response + via OpenAI's structured-output ``response_format``. + """ + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + + if image_b64 is not None and media_type is not None: + user_content = [ + {"type": "text", "text": prompt}, + { + "type": "image_url", + "image_url": {"url": f"data:{media_type};base64,{image_b64}"}, + }, + ] + else: + user_content = prompt + + messages.append({"role": "user", "content": user_content}) + + body: dict = { + "model": self.model_name, + "messages": messages, + "max_tokens": MAX_TOKENS, + } + + if output_schema is not None: + body["response_format"] = { + "type": "json_schema", + "json_schema": { + "name": output_schema.__name__, + "strict": True, + "schema": output_schema.model_json_schema(), + }, + } + + return { + "custom_id": f"request-{request_id}", + "method": "POST", + "url": "/v1/chat/completions", + "body": body, + } + + + + + # --------------------------------------------------------------------- + # Batch construction TODO + # --------------------------------------------------------------------- + + + + + def process_dataset(self, + *, + nb_samples: Optional[int] = None, + ) -> None: + """ + Build batch JSONL files for OpenAI Batch API. + + Writes one or more files: part_1.jsonl, part_2.jsonl, ... + Splits by MAX_PART_SIZE_BYTES. + """ + self.output_dir.mkdir(parents=True, exist_ok=True) + + examples, _ = load_data(nb_samples=nb_samples) + if not examples: + raise RuntimeError("No examples loaded; aborting batch creation.") + + mode = "ALL" if nb_samples is None else f"{nb_samples}" + print(f"[INFO] NB_SAMPLES={mode} | Output={self.output_dir}") + print(f"[INFO] Building batches for {len(examples)} samples") + + part_idx = 1 + bytes_in_part = 0 + + part_path = self.output_dir / f"part_{part_idx}.jsonl" + part_file = part_path.open("w", encoding="utf-8") + + for i, (text, encoded_images) in tqdm( + enumerate(examples, start=1), + total=len(examples), + desc="Building batch requests", + ): + if not encoded_images: + continue + + # Enforce one image per request + image_b64 = encoded_images[0] + media_type = get_media_type(Path(encoded_images[0])) + req = self.build_request( + prompt=text, + image_b64=image_b64, + media_type=media_type, + request_id=i, + system_prompt=None,# TODO + output_schema=None,# TODO + ) + + line = json.dumps(req, ensure_ascii=False) + "\n" + size = len(line.encode("utf-8")) + + if bytes_in_part + size > MAX_PART_SIZE_BYTES: + part_file.close() + part_idx += 1 + bytes_in_part = 0 + part_path = self.output_dir / f"part_{part_idx}.jsonl" + part_file = part_path.open("w", encoding="utf-8") + + part_file.write(line) + bytes_in_part += size + + part_file.close() + print(f"[DONE] Created {part_idx} batch file(s) in {self.output_dir}") + + + + + + # --------------------------------------------------------------------- + # Batch submission + # --------------------------------------------------------------------- + + + def submit_batches(self, batches_dir: Path) -> None: + """Submit batch files to OpenAI Batch API.""" + + parts = sorted(batches_dir.glob("*.jsonl")) + + if not parts: + raise SystemExit(f"No batch files found in {batches_dir}") + + + # submit batches sequentially with progress bar; persist batch IDs for reproducibility + for part in tqdm(parts, desc="Submitting batches"): + batch_id_file = self.output_dir / f"batch_id_{part.name}.txt" + + # Skip if already submitted + if batch_id_file.exists(): + print(f"[SKIP] {part.name} already submitted") + continue + + # Upload batch input file + with part.open("rb") as fh: + uploaded = self.client.files.create( + file=fh, + purpose="batch", + ) + + # Create batch job + batch = self.client.batches.create( + input_file_id=uploaded.id, + endpoint="/v1/chat/completions", + completion_window="24h", + metadata={ + "description": f"Dataset augmentation - {part.name}", + }, + ) + + # Persist batch ID (critical for reproducibility) + batch_id_file.write_text(batch.id) + + + print(f"[SUBMITTED] {part.name} → batch_id={batch.id}") + + print("All batches submitted.") + + + + + + # --------------------------------------------------------------------- + # Collect Outputs + # --------------------------------------------------------------------- + + + def __wait_for_output(self, batch_id: str, max_wait_s: int = 86400, poll_s: int = 30): + waited = 0 + while True: + b = self.client.batches.retrieve(batch_id) + print(f"[{batch_id}] status={b.status} out={b.output_file_id} err={b.error_file_id}") + if b.output_file_id: + return b + if b.status in ("failed", "cancelled", "expired"): + raise SystemExit(f"Batch ended with status: {b.status}") + time.sleep(poll_s) + waited += poll_s + if waited >= max_wait_s: + raise SystemExit("Timed out waiting for output_file_id") + + def __part_number_from_filename(self, p: Path) -> int: + m = re.search(r"batch_id_part_(\d+)\.jsonl\.txt$", p.name) + return int(m.group(1)) if m else 0 + + def __extract_messages(api_responses: List[dict]) -> List[str]: + return [ + rec["response"]["body"]["choices"][0]["message"]["content"].strip() + for rec in api_responses + ] + + def __save_part_output(self, b, part_num: int, output_dir : Path) -> List[dict]: + text = self.client.files.content(b.output_file_id).text + part_path = output_dir / f"api_response_part_{part_num}.jsonl" + part_path.write_text(text) + print(f"[saved] {part_path}") + return [json.loads(line) for line in text.splitlines() if line.strip()] + + + def await_and_collect_batch_outputs(self, batches_dir: Path, output_dir: Path) -> None: + """Wait for batch completions and download outputs.""" + batch_id_files = sorted(output_dir.glob("batch_id_*.txt")) + + if not batch_id_files: + raise SystemExit(f"No batch ID files found in {output_dir}") + + + all_records = [] + total_prompt = total_completion = 0 + + + for id_file in batch_id_files: + part_num = self.__part_number_from_filename(id_file) + batch_id = id_file.read_text().strip() + + b = self.__wait_for_output(batch_id) + records = self.__save_part_output(b, part_num) + all_records.extend(records) + + # accumulate actual usage + for rec in records: + usage = rec.get("response", {}).get("body", {}).get("usage", {}) + total_prompt += int(usage.get("prompt_tokens", 0)) + total_completion += int(usage.get("completion_tokens", 0)) + + + # merged outputs + all_path = output_dir / "api_response_all.jsonl" + with all_path.open("w", encoding="utf-8") as fout: + for rec in all_records: + fout.write(json.dumps(rec, ensure_ascii=False) + "\n") + print(f"[merged] {all_path} ({len(all_records)} responses)") + + + # optional: also save the plain texts + texts_path = output_dir / "messages_all.txt" + with texts_path.open("w", encoding="utf-8") as ftxt: + for msg in self.__extract_messages(all_records): + ftxt.write(msg + "\n\n") + print(f"[texts] {texts_path}") + + + + + + + + + + + + + + for batch_id_file in tqdm(batch_id_files, desc="Waiting for batches"): + batch_id = batch_id_file.read_text().strip() + batch = self.__wait_for_output(batch_id) + + # Download output file + output_path = output_dir / f"output_{batch_id}.jsonl" + with output_path.open("wb") as fh: + self.client.files.download(batch.output_file_id, fh) + + print(f"[DOWNLOADED] Batch {batch_id} output to {output_path}") + + From f99a6907d85b1b8770948f14e2297fc131188e42 Mon Sep 17 00:00:00 2001 From: legstar67 Date: Mon, 16 Mar 2026 15:16:54 +0100 Subject: [PATCH 02/45] continuation and correction of the OpenAI API support --- .../processors/llm/api_batch_client.py | 1 + .../core/process/processors/llm/api_utils.py | 7 ++- .../core/process/processors/llm/config.py | 1 - .../process/processors/llm/llm_processor.py | 22 +++++-- .../processors/llm/openai_batch_client.py | 60 +++++++------------ 5 files changed, 44 insertions(+), 47 deletions(-) diff --git a/src/mmirage/core/process/processors/llm/api_batch_client.py b/src/mmirage/core/process/processors/llm/api_batch_client.py index 5b88d09..2a3ac0d 100644 --- a/src/mmirage/core/process/processors/llm/api_batch_client.py +++ b/src/mmirage/core/process/processors/llm/api_batch_client.py @@ -64,6 +64,7 @@ def process_dataset(self, Writes one or more files: part_1.jsonl, part_2.jsonl, ... Splits by MAX_PART_SIZE_BYTES. """ + pass diff --git a/src/mmirage/core/process/processors/llm/api_utils.py b/src/mmirage/core/process/processors/llm/api_utils.py index 5eae6c4..540aeae 100644 --- a/src/mmirage/core/process/processors/llm/api_utils.py +++ b/src/mmirage/core/process/processors/llm/api_utils.py @@ -31,7 +31,7 @@ def get_media_type(path: Path) -> str: -def load_data_raw(manifest_path: Path) -> List[dict]: +def load_data_raw(manifest_path: Path) -> List[dict]: #TODO not useful, remove """ Load the raw JSONL manifest as a list of dicts. """ @@ -69,9 +69,10 @@ def resolve_image_path(image_root: Path, value: str) -> Path: def load_data( + manifest_path: Path, nb_samples: Optional[int] = None, max_images_per_sample: int = 1, - ) -> Tuple[List[Tuple[str, Tuple[str, ...]]], List[str]]: + ) -> Tuple[List[Tuple[str, Tuple[str, ...]]], List[str]]: #TODO Look if it is useful, if not remove """ Load dataset examples and encode images. @@ -79,7 +80,7 @@ def load_data( examples: List of (text, (img_b64, ...)) paths: List of absolute image paths used """ - raw_records = load_data_raw() + raw_records = load_data_raw(manifest_path) records = raw_records[:nb_samples] if nb_samples else raw_records examples: List[Tuple[str, Tuple[str, ...]]] = [] diff --git a/src/mmirage/core/process/processors/llm/config.py b/src/mmirage/core/process/processors/llm/config.py index fe99ea2..ab769f1 100644 --- a/src/mmirage/core/process/processors/llm/config.py +++ b/src/mmirage/core/process/processors/llm/config.py @@ -52,7 +52,6 @@ class SGLangLLMConfig(BaseProcessorConfig): provider: str = "sglang" # options: "sglang", "anthropic", "openai". Used for routing to the correct LLM provider api_model_name: str = "gpt-4o" # model name to use when provider is API-based (e.g., OpenAI, Anthropic) api_key: str = "" # API key for API-based providers - out @dataclass diff --git a/src/mmirage/core/process/processors/llm/llm_processor.py b/src/mmirage/core/process/processors/llm/llm_processor.py index 91cbf2b..44f03b1 100644 --- a/src/mmirage/core/process/processors/llm/llm_processor.py +++ b/src/mmirage/core/process/processors/llm/llm_processor.py @@ -5,6 +5,7 @@ from dataclasses import asdict import json import logging +from pathlib import Path from typing import Any, List, Tuple import jinja2 @@ -15,6 +16,8 @@ from mmirage.core.process.processors.llm.config import LLMOutputVar, SGLangLLMConfig from mmirage.core.process.processors.llm.openai_batch_client import OpenAIBatchClient from mmirage.core.process.variables import VariableEnvironment +from mmirage.core.process.processors.llm.api_utils import encode_image_to_base64, get_media_type + try: from typing import override # Python 3.12+ @@ -163,12 +166,21 @@ def batch_process_sample( # ---- For API-based providers ---- if self.provider in ["openai", "anthropic"]: + # dataset_examples = List of (prompt_text, ((encoded_image1, media_type1), (encoded_image2, media_type2), ...)) + # prompt_text is generated with jinja rendering and images are encoded to base64 with their media types + batch_prompts: List[Tuple[str, Tuple[Tuple[str, str], ...]]] = [] + for var_env in batch: + jinja_template = jinja2.Template(output_var.prompt) + base_prompt = jinja_template.render(**var_env.to_dict()) - prompts = self.llm.build_prompt(output_var.prompt, batch) - requests_payloads = [self.llm.build_request(text=prompts[i], request_id=i) for i in range(nb_samples)] - responses = self.llm.submit_and_wait(requests_payloads) - - + image_paths = var_env.get_images() + encoded_images = tuple((encode_image_to_base64(p), get_media_type(Path(p))) for p in image_paths) if image_paths else () + batch_prompts.append((base_prompt, encoded_images)) + + self.llm.process_dataset(batch_prompts) + + self.llm.submit_batches(self.llm.output_dir, nb_samples=nb_samples) + self.llm.await_and_collect_batch_outputs(self.llm.output_dir) # ---- For SGLang Engine provider ---- diff --git a/src/mmirage/core/process/processors/llm/openai_batch_client.py b/src/mmirage/core/process/processors/llm/openai_batch_client.py index 5cc7ea5..a4b3e02 100644 --- a/src/mmirage/core/process/processors/llm/openai_batch_client.py +++ b/src/mmirage/core/process/processors/llm/openai_batch_client.py @@ -7,7 +7,9 @@ from pydantic import BaseModel from mmirage.core.process.processors.llm.api_batch_client import APIBatchClient -from mmirage.core.process.processors.llm.api_utils import get_media_type, load_data +from mmirage.core.process.processors.llm.api_utils import get_media_type, load_data, encode_image_to_base64 +from mmirage.core.process.processors.llm.config import LLMOutputVar +from mmirage.core.process.variables import VariableEnvironment @@ -32,6 +34,7 @@ def __init__(self, model_name: str, api_key: str, output_dir: Path): ) self.client = OpenAI(api_key=self.api_key) self.output_dir = output_dir + self.batches_dir = output_dir / "batches" #TODO to implement in process_dataset and submit_batches # --------------------------------------------------------------------- @@ -111,25 +114,17 @@ def build_request( def process_dataset(self, - *, - nb_samples: Optional[int] = None, + batch: List[Tuple[str, Tuple[Tuple[str, str], ...]]], ) -> None: """ - Build batch JSONL files for OpenAI Batch API. - - Writes one or more files: part_1.jsonl, part_2.jsonl, ... + Build batch JSONL files for OpenAI Batch API. Writes one or more files: part_1.jsonl, part_2.jsonl, ... Splits by MAX_PART_SIZE_BYTES. + + Args: + batch: List of (prompt_text, ((encoded_image1, media_type1), (encoded_image2, media_type2), ...)) """ self.output_dir.mkdir(parents=True, exist_ok=True) - examples, _ = load_data(nb_samples=nb_samples) - if not examples: - raise RuntimeError("No examples loaded; aborting batch creation.") - - mode = "ALL" if nb_samples is None else f"{nb_samples}" - print(f"[INFO] NB_SAMPLES={mode} | Output={self.output_dir}") - print(f"[INFO] Building batches for {len(examples)} samples") - part_idx = 1 bytes_in_part = 0 @@ -137,18 +132,16 @@ def process_dataset(self, part_file = part_path.open("w", encoding="utf-8") for i, (text, encoded_images) in tqdm( - enumerate(examples, start=1), - total=len(examples), + enumerate(batch, start=1), + total=len(batch), desc="Building batch requests", ): - if not encoded_images: - continue - # Enforce one image per request - image_b64 = encoded_images[0] - media_type = get_media_type(Path(encoded_images[0])) + # Enforce one image per request TODO : allow more than one image + image_b64 = encoded_images[0][0] + media_type = encoded_images[0][1] req = self.build_request( - prompt=text, + prompt=text, image_b64=image_b64, media_type=media_type, request_id=i, @@ -222,7 +215,7 @@ def submit_batches(self, batches_dir: Path) -> None: print(f"[SUBMITTED] {part.name} → batch_id={batch.id}") - print("All batches submitted.") + print("All batches submitted.") @@ -265,12 +258,12 @@ def __save_part_output(self, b, part_num: int, output_dir : Path) -> List[dict]: return [json.loads(line) for line in text.splitlines() if line.strip()] - def await_and_collect_batch_outputs(self, batches_dir: Path, output_dir: Path) -> None: + def await_and_collect_batch_outputs(self) -> None: """Wait for batch completions and download outputs.""" - batch_id_files = sorted(output_dir.glob("batch_id_*.txt")) + batch_id_files = sorted(self.output_dir.glob("batch_id_*.txt")) if not batch_id_files: - raise SystemExit(f"No batch ID files found in {output_dir}") + raise SystemExit(f"No batch ID files found in {self.output_dir}") all_records = [] @@ -293,7 +286,7 @@ def await_and_collect_batch_outputs(self, batches_dir: Path, output_dir: Path) - # merged outputs - all_path = output_dir / "api_response_all.jsonl" + all_path = self.output_dir / "api_response_all.jsonl" with all_path.open("w", encoding="utf-8") as fout: for rec in all_records: fout.write(json.dumps(rec, ensure_ascii=False) + "\n") @@ -301,7 +294,7 @@ def await_and_collect_batch_outputs(self, batches_dir: Path, output_dir: Path) - # optional: also save the plain texts - texts_path = output_dir / "messages_all.txt" + texts_path = self.output_dir / "messages_all.txt" with texts_path.open("w", encoding="utf-8") as ftxt: for msg in self.__extract_messages(all_records): ftxt.write(msg + "\n\n") @@ -310,21 +303,12 @@ def await_and_collect_batch_outputs(self, batches_dir: Path, output_dir: Path) - - - - - - - - - - for batch_id_file in tqdm(batch_id_files, desc="Waiting for batches"): batch_id = batch_id_file.read_text().strip() batch = self.__wait_for_output(batch_id) # Download output file - output_path = output_dir / f"output_{batch_id}.jsonl" + output_path = self.output_dir / f"output_{batch_id}.jsonl" with output_path.open("wb") as fh: self.client.files.download(batch.output_file_id, fh) From 7189b3e2fb489162f4d8d760b12836dcad8022fd Mon Sep 17 00:00:00 2001 From: legstar67 Date: Mon, 16 Mar 2026 16:03:18 +0100 Subject: [PATCH 03/45] cleaning --- .../core/process/processors/llm/api_utils.py | 87 ------------------- .../processors/llm/openai_batch_client.py | 1 - 2 files changed, 88 deletions(-) diff --git a/src/mmirage/core/process/processors/llm/api_utils.py b/src/mmirage/core/process/processors/llm/api_utils.py index 540aeae..e13d359 100644 --- a/src/mmirage/core/process/processors/llm/api_utils.py +++ b/src/mmirage/core/process/processors/llm/api_utils.py @@ -28,92 +28,5 @@ def get_media_type(path: Path) -> str: return "image/webp" else: raise ValueError(f"Unsupported file extension: {ext}") - -def load_data_raw(manifest_path: Path) -> List[dict]: #TODO not useful, remove - """ - Load the raw JSONL manifest as a list of dicts. - """ - if not manifest_path.exists(): - raise FileNotFoundError( - f"Manifest not found: {manifest_path}" - ) - - records: List[dict] = [] - with manifest_path.open("r", encoding="utf-8") as f: - for line_num, line in enumerate(f, 1): - line = line.strip() - if not line: - continue - try: - records.append(json.loads(line)) - except json.JSONDecodeError as e: - raise ValueError( - f"Invalid JSON on line {line_num} of {manifest_path}" - ) from e - - if not records: - raise RuntimeError("Manifest loaded but contains no records.") - - return records - - -def resolve_image_path(image_root: Path, value: str) -> Path: - """ - Resolve image paths safely, handling leading slashes. - """ - rel = value.lstrip("/") - return image_root / rel - - - -def load_data( - manifest_path: Path, - nb_samples: Optional[int] = None, - max_images_per_sample: int = 1, - ) -> Tuple[List[Tuple[str, Tuple[str, ...]]], List[str]]: #TODO Look if it is useful, if not remove - """ - Load dataset examples and encode images. - - Returns: - examples: List of (text, (img_b64, ...)) - paths: List of absolute image paths used - """ - raw_records = load_data_raw(manifest_path) - records = raw_records[:nb_samples] if nb_samples else raw_records - - examples: List[Tuple[str, Tuple[str, ...]]] = [] - used_paths: List[str] = [] - - for rec in tqdm(records, desc="Loading dataset"): - text = str(rec.get("text", "")).strip() - if not text: - continue - - image_paths: List[Path] = [] - for m in rec.get("modalities", []): - if m.get("type") == "image" and m.get("value"): - image_paths.append(resolve_image_path(m["value"])) - if len(image_paths) >= max_images_per_sample: - break - - if not image_paths: - continue - - try: - encoded_images = tuple( - encode_image_to_base64(p) for p in image_paths - ) - except Exception as e: - # Skip corrupted or unreadable images - print(f"[WARN] Skipping sample due to image error: {e}") - continue - - examples.append((text, encoded_images)) - used_paths.extend(str(p) for p in image_paths) - - if not examples: - raise RuntimeError("No valid examples loaded.") - - return examples, used_paths diff --git a/src/mmirage/core/process/processors/llm/openai_batch_client.py b/src/mmirage/core/process/processors/llm/openai_batch_client.py index a4b3e02..e0c82b6 100644 --- a/src/mmirage/core/process/processors/llm/openai_batch_client.py +++ b/src/mmirage/core/process/processors/llm/openai_batch_client.py @@ -7,7 +7,6 @@ from pydantic import BaseModel from mmirage.core.process.processors.llm.api_batch_client import APIBatchClient -from mmirage.core.process.processors.llm.api_utils import get_media_type, load_data, encode_image_to_base64 from mmirage.core.process.processors.llm.config import LLMOutputVar from mmirage.core.process.variables import VariableEnvironment From 141d2149467e477ecd3bc0eac5f94d2802b8ddf3 Mon Sep 17 00:00:00 2001 From: legstar67 Date: Sun, 22 Mar 2026 21:42:40 +0100 Subject: [PATCH 04/45] small modifications and cleaning --- .../processors/llm/api_batch_client.py | 60 ++++--------------- .../processors/llm/openai_batch_client.py | 15 +++-- 2 files changed, 21 insertions(+), 54 deletions(-) diff --git a/src/mmirage/core/process/processors/llm/api_batch_client.py b/src/mmirage/core/process/processors/llm/api_batch_client.py index 2a3ac0d..346f448 100644 --- a/src/mmirage/core/process/processors/llm/api_batch_client.py +++ b/src/mmirage/core/process/processors/llm/api_batch_client.py @@ -1,6 +1,6 @@ -from typing import List, Optional, Dict, Any, Type +from typing import List, Optional, Dict, Any, Tuple, Type from pydantic import BaseModel from pathlib import Path @@ -15,70 +15,30 @@ def __init__(self, model_name: str, api_key: str, provider: str): self.api_key = api_key self.provider = provider - @abstractmethod - def build_request( - self, - *, - prompt: str, - image_b64: str = None, - media_type: str = None, - request_id: int, - system_prompt: str = None, - output_schema: Optional[Type[BaseModel]] = None, - ) -> dict: - """ - Build a single API request object based on the provider. - - Args: - text: The input text to send to the LLM. - image_b64: Optional base64-encoded image string for multimodal models. - request_id: Unique identifier for this request. - - Returns: - A dict representing the API request payload. - """ - pass - @abstractmethod - def submit_batches(self, output_dir: Path) -> None: + def submit_batches(self) -> None: """ Submit batches of requests to the LLM API and save responses. - - Args: - batches_dir: Directory containing batch request files. - output_dir: Directory to save API responses. """ pass - - - + @abstractmethod def process_dataset(self, - *, - nb_samples: Optional[int] = None, + batch: List[Tuple[str, Tuple[Tuple[str, str], ...]]], ) -> None: """ - Build batch JSONL files - - Writes one or more files: part_1.jsonl, part_2.jsonl, ... + Build batch JSONL files for OpenAI Batch API. Writes one or more files: part_1.jsonl, part_2.jsonl, ... Splits by MAX_PART_SIZE_BYTES. + + Args: + batch: List of (prompt_text, ((encoded_image1, media_type1), (encoded_image2, media_type2), ...)) """ pass - - @abstractmethod - def await_and_collect_batch_outputs(self, batches_dir: Path, output_dir: Path) -> None: + def await_and_collect_batch_outputs(self) -> None: """ Wait for API responses and collect outputs into VariableEnvironments. - - Args: - batches_dir: Directory containing batch request files. - output_dir: Directory where API responses are saved. """ - pass - - - - + pass \ No newline at end of file diff --git a/src/mmirage/core/process/processors/llm/openai_batch_client.py b/src/mmirage/core/process/processors/llm/openai_batch_client.py index e0c82b6..b1fe98b 100644 --- a/src/mmirage/core/process/processors/llm/openai_batch_client.py +++ b/src/mmirage/core/process/processors/llm/openai_batch_client.py @@ -24,7 +24,12 @@ class OpenAIBatchClient(APIBatchClient): - def __init__(self, model_name: str, api_key: str, output_dir: Path): + def __init__( + self, + model_name: str, + api_key: str, + output_dir: Path, + ): super().__init__(model_name=model_name, api_key=api_key, provider="openai") if not self.api_key: @@ -106,7 +111,7 @@ def build_request( # --------------------------------------------------------------------- - # Batch construction TODO + # Batch construction # --------------------------------------------------------------------- @@ -137,8 +142,10 @@ def process_dataset(self, ): # Enforce one image per request TODO : allow more than one image - image_b64 = encoded_images[0][0] - media_type = encoded_images[0][1] + image_b64 = media_type = None + if encoded_images: + image_b64 = encoded_images[0][0] + media_type = encoded_images[0][1] req = self.build_request( prompt=text, image_b64=image_b64, From 7811496bb01cd756bb146523ea9fcc2293a46f34 Mon Sep 17 00:00:00 2001 From: legstar67 Date: Sun, 22 Mar 2026 21:53:15 +0100 Subject: [PATCH 05/45] reset from the main --- .../processors/llm/api_batch_client.py | 44 --- .../core/process/processors/llm/api_utils.py | 32 -- .../processors/llm/claude_batch_client.py | 18 - .../core/process/processors/llm/config.py | 40 ++- .../process/processors/llm/llm_processor.py | 40 +-- .../processors/llm/openai_batch_client.py | 323 ------------------ 6 files changed, 38 insertions(+), 459 deletions(-) delete mode 100644 src/mmirage/core/process/processors/llm/api_batch_client.py delete mode 100644 src/mmirage/core/process/processors/llm/api_utils.py delete mode 100644 src/mmirage/core/process/processors/llm/claude_batch_client.py delete mode 100644 src/mmirage/core/process/processors/llm/openai_batch_client.py diff --git a/src/mmirage/core/process/processors/llm/api_batch_client.py b/src/mmirage/core/process/processors/llm/api_batch_client.py deleted file mode 100644 index 346f448..0000000 --- a/src/mmirage/core/process/processors/llm/api_batch_client.py +++ /dev/null @@ -1,44 +0,0 @@ - - -from typing import List, Optional, Dict, Any, Tuple, Type -from pydantic import BaseModel -from pathlib import Path - - -from mmirage.core.process.variables import VariableEnvironment -from abc import ABC, abstractmethod - -class APIBatchClient(ABC): - - def __init__(self, model_name: str, api_key: str, provider: str): - self.model_name = model_name - self.api_key = api_key - self.provider = provider - - - @abstractmethod - def submit_batches(self) -> None: - """ - Submit batches of requests to the LLM API and save responses. - """ - pass - - @abstractmethod - def process_dataset(self, - batch: List[Tuple[str, Tuple[Tuple[str, str], ...]]], - ) -> None: - """ - Build batch JSONL files for OpenAI Batch API. Writes one or more files: part_1.jsonl, part_2.jsonl, ... - Splits by MAX_PART_SIZE_BYTES. - - Args: - batch: List of (prompt_text, ((encoded_image1, media_type1), (encoded_image2, media_type2), ...)) - """ - pass - - @abstractmethod - def await_and_collect_batch_outputs(self) -> None: - """ - Wait for API responses and collect outputs into VariableEnvironments. - """ - pass \ No newline at end of file diff --git a/src/mmirage/core/process/processors/llm/api_utils.py b/src/mmirage/core/process/processors/llm/api_utils.py deleted file mode 100644 index e13d359..0000000 --- a/src/mmirage/core/process/processors/llm/api_utils.py +++ /dev/null @@ -1,32 +0,0 @@ -from pathlib import Path -import base64, json -from typing import List, Optional, Tuple - -import tqdm - - -def encode_image_to_base64(path: Path) -> str: - """ - Read an image from disk and return base64-encoded string. - """ - if not path.exists(): - raise FileNotFoundError(f"Image not found: {path}") - - with path.open("rb") as f: - return base64.b64encode(f.read()).decode("utf-8") - -def get_media_type(path: Path) -> str: - """ - Get the media type (MIME type) of a file based on its extension. - """ - ext = path.suffix.lower() - if ext in [".jpg", ".jpeg"]: - return "image/jpeg" - elif ext == ".png": - return "image/png" - elif ext == ".webp": - return "image/webp" - else: - raise ValueError(f"Unsupported file extension: {ext}") - - diff --git a/src/mmirage/core/process/processors/llm/claude_batch_client.py b/src/mmirage/core/process/processors/llm/claude_batch_client.py deleted file mode 100644 index ebd88b6..0000000 --- a/src/mmirage/core/process/processors/llm/claude_batch_client.py +++ /dev/null @@ -1,18 +0,0 @@ -from pathlib import Path - -import anthropic - -from mmirage.core.process.processors.llm.api_batch_client import APIBatchClient - - - -class AnthropicBatchClient(APIBatchClient): - def __init__(self, model_name: str, api_key: str, output_dir: Path): - super().__init__(model_name=model_name, api_key=api_key, provider="anthropic") - - if not self.api_key: - raise SystemExit( - "ANTHROPIC_API_KEY is not set. Please export it before running." - ) - self.client = anthropic.Anthropic(api_key=self.api_key) - self.output_dir = output_dir \ No newline at end of file diff --git a/src/mmirage/core/process/processors/llm/config.py b/src/mmirage/core/process/processors/llm/config.py index ab769f1..859552d 100644 --- a/src/mmirage/core/process/processors/llm/config.py +++ b/src/mmirage/core/process/processors/llm/config.py @@ -3,6 +3,7 @@ from dataclasses import dataclass, field import logging +import os from typing import Dict, Optional, Sequence, Type, Any, List from pydantic import BaseModel, create_model @@ -15,6 +16,39 @@ env = Environment() +def _parse_tp_size_from_env() -> int: + """Parse tensor parallelism size from SLURM_GPUS_ON_NODE environment variable. + + Defensively parses the environment variable, handling invalid values: + - Returns 1 if the variable is None or empty + - Strips whitespace before parsing + - Returns 1 for non-integer values + - Returns 1 for values <= 0 + + Returns: + Tensor parallelism size (>= 1), defaults to 1 on any parsing error. + """ + env_value = os.environ.get("SLURM_GPUS_ON_NODE") + if not env_value: + return 1 + + try: + tp_size = int(env_value.strip()) + # Ensure tp_size is positive (must be >= 1) + if tp_size <= 0: + logger.warning( + f"Invalid SLURM_GPUS_ON_NODE value '{env_value}' (must be > 0), defaulting tp_size to 1" + ) + return 1 + return tp_size + except ValueError: + # ValueError: invalid integer format + logger.warning( + f"Invalid SLURM_GPUS_ON_NODE value '{env_value}', defaulting tp_size to 1" + ) + return 1 + + @dataclass class SGLangServerArgs: """Server arguments for SGLang engine. @@ -27,7 +61,7 @@ class SGLangServerArgs: """ model_path: str = "none" - tp_size: int = 1 + tp_size: int = field(default_factory=_parse_tp_size_from_env) trust_remote_code: bool = True disable_custom_all_reduce: bool = False @@ -49,10 +83,6 @@ class SGLangLLMConfig(BaseProcessorConfig): default_sampling_params: Dict[str, Any] = field(default_factory=dict) chat_template: str = "" # Empty means use tokenizer's default - provider: str = "sglang" # options: "sglang", "anthropic", "openai". Used for routing to the correct LLM provider - api_model_name: str = "gpt-4o" # model name to use when provider is API-based (e.g., OpenAI, Anthropic) - api_key: str = "" # API key for API-based providers - @dataclass class LLMOutputVar(OutputVar): diff --git a/src/mmirage/core/process/processors/llm/llm_processor.py b/src/mmirage/core/process/processors/llm/llm_processor.py index 44f03b1..56afd17 100644 --- a/src/mmirage/core/process/processors/llm/llm_processor.py +++ b/src/mmirage/core/process/processors/llm/llm_processor.py @@ -5,7 +5,6 @@ from dataclasses import asdict import json import logging -from pathlib import Path from typing import Any, List, Tuple import jinja2 @@ -14,10 +13,7 @@ from mmirage.core.process.base import BaseProcessor, ProcessorRegistry from mmirage.core.process.processors.llm.config import LLMOutputVar, SGLangLLMConfig -from mmirage.core.process.processors.llm.openai_batch_client import OpenAIBatchClient from mmirage.core.process.variables import VariableEnvironment -from mmirage.core.process.processors.llm.api_utils import encode_image_to_base64, get_media_type - try: from typing import override # Python 3.12+ @@ -61,18 +57,7 @@ def __init__(self, engine_args: SGLangLLMConfig, **kwargs) -> None: engine_args: Configuration for SGLang server and sampling parameters. **kwargs: Additional arguments passed to base class. """ - - - - super().__init__(engine_args, **kwargs) - - if self.engine_args.provider == "openai": - self.llm = OpenAIBatchClient(self.engine_args.api_model_name, self.engine_args.api_key) - elif self.engine_args.provider == "anthropic": - pass - - # Default to SGLang Engine self.llm = sgl.Engine(**asdict(engine_args.server_args)) self.tokenizer = AutoTokenizer.from_pretrained( engine_args.server_args.model_path, @@ -145,7 +130,7 @@ def _get_image_token(self) -> str: return IMAGE_TOKENS.get(self.chat_template, "") @override - def batch_process_sample( + def batch_process_sample( self, batch: List[VariableEnvironment], output_var: LLMOutputVar ) -> List[VariableEnvironment]: """Process a batch of variable environments to generate LLM outputs. @@ -162,27 +147,6 @@ def batch_process_sample( RuntimeError: If output batch size doesn't match input batch size. """ nb_samples = len(batch) - results: dict[int, VariableEnvironment] = {} - - # ---- For API-based providers ---- - if self.provider in ["openai", "anthropic"]: - # dataset_examples = List of (prompt_text, ((encoded_image1, media_type1), (encoded_image2, media_type2), ...)) - # prompt_text is generated with jinja rendering and images are encoded to base64 with their media types - batch_prompts: List[Tuple[str, Tuple[Tuple[str, str], ...]]] = [] - for var_env in batch: - jinja_template = jinja2.Template(output_var.prompt) - base_prompt = jinja_template.render(**var_env.to_dict()) - - image_paths = var_env.get_images() - encoded_images = tuple((encode_image_to_base64(p), get_media_type(Path(p))) for p in image_paths) if image_paths else () - batch_prompts.append((base_prompt, encoded_images)) - - self.llm.process_dataset(batch_prompts) - - self.llm.submit_batches(self.llm.output_dir, nb_samples=nb_samples) - self.llm.await_and_collect_batch_outputs(self.llm.output_dir) - - # ---- For SGLang Engine provider ---- # Prepare sampling params sampling_params_output = self.sampling_params.copy() @@ -206,6 +170,8 @@ def batch_process_sample( else: text_only_indices.append(i) + results: dict[int, VariableEnvironment] = {} + # Text-only batch if text_only_indices: text_only_envs = [batch[i] for i in text_only_indices] diff --git a/src/mmirage/core/process/processors/llm/openai_batch_client.py b/src/mmirage/core/process/processors/llm/openai_batch_client.py deleted file mode 100644 index b1fe98b..0000000 --- a/src/mmirage/core/process/processors/llm/openai_batch_client.py +++ /dev/null @@ -1,323 +0,0 @@ -from openai import OpenAI -import json, time, re -from tqdm import tqdm -from typing import List, Tuple, Optional, Type -import json -from pathlib import Path -from pydantic import BaseModel - -from mmirage.core.process.processors.llm.api_batch_client import APIBatchClient -from mmirage.core.process.processors.llm.config import LLMOutputVar -from mmirage.core.process.variables import VariableEnvironment - - - -# --------------------------------------------------------------------- -# Constants -# --------------------------------------------------------------------- - -MAX_TOKENS = 1000 -# 50 MB per batch part is well below OpenAI limits and avoids failures -MAX_PART_SIZE_BYTES = int(0.05 * 1024 ** 3) - - - - -class OpenAIBatchClient(APIBatchClient): - def __init__( - self, - model_name: str, - api_key: str, - output_dir: Path, - ): - super().__init__(model_name=model_name, api_key=api_key, provider="openai") - - if not self.api_key: - raise SystemExit( - "OPENAI_API_KEY is not set. Please export it before running." - ) - self.client = OpenAI(api_key=self.api_key) - self.output_dir = output_dir - self.batches_dir = output_dir / "batches" #TODO to implement in process_dataset and submit_batches - - - # --------------------------------------------------------------------- - # Request builder - # --------------------------------------------------------------------- - - - def build_request( - self, - *, - prompt: str, - image_b64: str = None, - media_type: str = None, - request_id: int, - system_prompt: str = None, - output_schema: Optional[Type[BaseModel]] = None, - ) -> dict: - """Build a single OpenAI Batch API request object. - - Args: - prompt: The fully-rendered user prompt (Jinja2 already applied). - image_b64: Optional base64-encoded image for multimodal requests. - media_type: MIME type of the image (e.g., "image/jpeg"). - request_id: Unique identifier used as custom_id. - system_prompt: Optional system message prepended to the conversation. - output_schema: Optional Pydantic model used to enforce a JSON response - via OpenAI's structured-output ``response_format``. - """ - messages = [] - if system_prompt: - messages.append({"role": "system", "content": system_prompt}) - - if image_b64 is not None and media_type is not None: - user_content = [ - {"type": "text", "text": prompt}, - { - "type": "image_url", - "image_url": {"url": f"data:{media_type};base64,{image_b64}"}, - }, - ] - else: - user_content = prompt - - messages.append({"role": "user", "content": user_content}) - - body: dict = { - "model": self.model_name, - "messages": messages, - "max_tokens": MAX_TOKENS, - } - - if output_schema is not None: - body["response_format"] = { - "type": "json_schema", - "json_schema": { - "name": output_schema.__name__, - "strict": True, - "schema": output_schema.model_json_schema(), - }, - } - - return { - "custom_id": f"request-{request_id}", - "method": "POST", - "url": "/v1/chat/completions", - "body": body, - } - - - - - # --------------------------------------------------------------------- - # Batch construction - # --------------------------------------------------------------------- - - - - - def process_dataset(self, - batch: List[Tuple[str, Tuple[Tuple[str, str], ...]]], - ) -> None: - """ - Build batch JSONL files for OpenAI Batch API. Writes one or more files: part_1.jsonl, part_2.jsonl, ... - Splits by MAX_PART_SIZE_BYTES. - - Args: - batch: List of (prompt_text, ((encoded_image1, media_type1), (encoded_image2, media_type2), ...)) - """ - self.output_dir.mkdir(parents=True, exist_ok=True) - - part_idx = 1 - bytes_in_part = 0 - - part_path = self.output_dir / f"part_{part_idx}.jsonl" - part_file = part_path.open("w", encoding="utf-8") - - for i, (text, encoded_images) in tqdm( - enumerate(batch, start=1), - total=len(batch), - desc="Building batch requests", - ): - - # Enforce one image per request TODO : allow more than one image - image_b64 = media_type = None - if encoded_images: - image_b64 = encoded_images[0][0] - media_type = encoded_images[0][1] - req = self.build_request( - prompt=text, - image_b64=image_b64, - media_type=media_type, - request_id=i, - system_prompt=None,# TODO - output_schema=None,# TODO - ) - - line = json.dumps(req, ensure_ascii=False) + "\n" - size = len(line.encode("utf-8")) - - if bytes_in_part + size > MAX_PART_SIZE_BYTES: - part_file.close() - part_idx += 1 - bytes_in_part = 0 - part_path = self.output_dir / f"part_{part_idx}.jsonl" - part_file = part_path.open("w", encoding="utf-8") - - part_file.write(line) - bytes_in_part += size - - part_file.close() - print(f"[DONE] Created {part_idx} batch file(s) in {self.output_dir}") - - - - - - # --------------------------------------------------------------------- - # Batch submission - # --------------------------------------------------------------------- - - - def submit_batches(self, batches_dir: Path) -> None: - """Submit batch files to OpenAI Batch API.""" - - parts = sorted(batches_dir.glob("*.jsonl")) - - if not parts: - raise SystemExit(f"No batch files found in {batches_dir}") - - - # submit batches sequentially with progress bar; persist batch IDs for reproducibility - for part in tqdm(parts, desc="Submitting batches"): - batch_id_file = self.output_dir / f"batch_id_{part.name}.txt" - - # Skip if already submitted - if batch_id_file.exists(): - print(f"[SKIP] {part.name} already submitted") - continue - - # Upload batch input file - with part.open("rb") as fh: - uploaded = self.client.files.create( - file=fh, - purpose="batch", - ) - - # Create batch job - batch = self.client.batches.create( - input_file_id=uploaded.id, - endpoint="/v1/chat/completions", - completion_window="24h", - metadata={ - "description": f"Dataset augmentation - {part.name}", - }, - ) - - # Persist batch ID (critical for reproducibility) - batch_id_file.write_text(batch.id) - - - print(f"[SUBMITTED] {part.name} → batch_id={batch.id}") - - print("All batches submitted.") - - - - - - # --------------------------------------------------------------------- - # Collect Outputs - # --------------------------------------------------------------------- - - - def __wait_for_output(self, batch_id: str, max_wait_s: int = 86400, poll_s: int = 30): - waited = 0 - while True: - b = self.client.batches.retrieve(batch_id) - print(f"[{batch_id}] status={b.status} out={b.output_file_id} err={b.error_file_id}") - if b.output_file_id: - return b - if b.status in ("failed", "cancelled", "expired"): - raise SystemExit(f"Batch ended with status: {b.status}") - time.sleep(poll_s) - waited += poll_s - if waited >= max_wait_s: - raise SystemExit("Timed out waiting for output_file_id") - - def __part_number_from_filename(self, p: Path) -> int: - m = re.search(r"batch_id_part_(\d+)\.jsonl\.txt$", p.name) - return int(m.group(1)) if m else 0 - - def __extract_messages(api_responses: List[dict]) -> List[str]: - return [ - rec["response"]["body"]["choices"][0]["message"]["content"].strip() - for rec in api_responses - ] - - def __save_part_output(self, b, part_num: int, output_dir : Path) -> List[dict]: - text = self.client.files.content(b.output_file_id).text - part_path = output_dir / f"api_response_part_{part_num}.jsonl" - part_path.write_text(text) - print(f"[saved] {part_path}") - return [json.loads(line) for line in text.splitlines() if line.strip()] - - - def await_and_collect_batch_outputs(self) -> None: - """Wait for batch completions and download outputs.""" - batch_id_files = sorted(self.output_dir.glob("batch_id_*.txt")) - - if not batch_id_files: - raise SystemExit(f"No batch ID files found in {self.output_dir}") - - - all_records = [] - total_prompt = total_completion = 0 - - - for id_file in batch_id_files: - part_num = self.__part_number_from_filename(id_file) - batch_id = id_file.read_text().strip() - - b = self.__wait_for_output(batch_id) - records = self.__save_part_output(b, part_num) - all_records.extend(records) - - # accumulate actual usage - for rec in records: - usage = rec.get("response", {}).get("body", {}).get("usage", {}) - total_prompt += int(usage.get("prompt_tokens", 0)) - total_completion += int(usage.get("completion_tokens", 0)) - - - # merged outputs - all_path = self.output_dir / "api_response_all.jsonl" - with all_path.open("w", encoding="utf-8") as fout: - for rec in all_records: - fout.write(json.dumps(rec, ensure_ascii=False) + "\n") - print(f"[merged] {all_path} ({len(all_records)} responses)") - - - # optional: also save the plain texts - texts_path = self.output_dir / "messages_all.txt" - with texts_path.open("w", encoding="utf-8") as ftxt: - for msg in self.__extract_messages(all_records): - ftxt.write(msg + "\n\n") - print(f"[texts] {texts_path}") - - - - - for batch_id_file in tqdm(batch_id_files, desc="Waiting for batches"): - batch_id = batch_id_file.read_text().strip() - batch = self.__wait_for_output(batch_id) - - # Download output file - output_path = self.output_dir / f"output_{batch_id}.jsonl" - with output_path.open("wb") as fh: - self.client.files.download(batch.output_file_id, fh) - - print(f"[DOWNLOADED] Batch {batch_id} output to {output_path}") - - From 8e085a34f487ab1fc1c6cbcad5ce925a11d971c6 Mon Sep 17 00:00:00 2001 From: legstar67 Date: Mon, 23 Mar 2026 00:54:37 +0100 Subject: [PATCH 06/45] implementation of config classes and agnostic-provider classes --- src/mmirage/config/batch_provider.py | 72 +++++++++++ src/mmirage/core/process/batch/__init__.py | 11 ++ src/mmirage/core/process/batch/adapter.py | 133 +++++++++++++++++++++ src/mmirage/core/process/batch/registry.py | 65 ++++++++++ 4 files changed, 281 insertions(+) create mode 100644 src/mmirage/config/batch_provider.py create mode 100644 src/mmirage/core/process/batch/__init__.py create mode 100644 src/mmirage/core/process/batch/adapter.py create mode 100644 src/mmirage/core/process/batch/registry.py diff --git a/src/mmirage/config/batch_provider.py b/src/mmirage/config/batch_provider.py new file mode 100644 index 0000000..0256787 --- /dev/null +++ b/src/mmirage/config/batch_provider.py @@ -0,0 +1,72 @@ +"""Provider-agnostic batch configuration contracts. + +This module defines the shared configuration shape used by any future batch +submission provider (OpenAI, Anthropic, etc.). +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, Optional + + +@dataclass +class BatchRetryPolicy: + """Retry behavior used by provider-neutral batch submission orchestration. + + Attributes: + max_attempts: Maximum number of submission attempts for retryable errors. + initial_backoff_seconds: Delay before the first retry attempt. + backoff_multiplier: Multiplicative factor for subsequent retry delays. + """ + + max_attempts: int = 3 + initial_backoff_seconds: float = 2.0 + backoff_multiplier: float = 2.0 + + def __post_init__(self) -> None: + if self.max_attempts < 1: + raise ValueError("max_attempts must be >= 1") + if self.initial_backoff_seconds < 0: + raise ValueError("initial_backoff_seconds must be >= 0") + if self.backoff_multiplier < 1: + raise ValueError("backoff_multiplier must be >= 1") + + +@dataclass +class BatchProviderConfig: + """Shared contract for provider-specific batch configuration. + + Concrete provider configs should inherit from this dataclass and extend it + with provider-specific settings. The fields here are intentionally provider + neutral so chunking/submission orchestration can run through one typed path. + + Attributes: + provider: Provider identifier (for example, "openai" or "anthropic"). + enabled: Whether batch submission mode is enabled. + max_chunk_bytes: Maximum serialized request bytes per chunk. + Defaults to 50 MB. + max_requests_per_chunk: Optional hard cap on number of requests in a + chunk. If None, no request-count cap is enforced. + metadata_output_path: Path where submission metadata artifacts are saved. + retry_policy: Retry policy used by the shared batch layer. + extras: Provider-specific knobs that do not belong in the shared fields. + credentials: Provider credentials required to submit chunks. + """ + + provider: str + enabled: bool = True + max_chunk_bytes: int = 50 * 1024 * 1024 + max_requests_per_chunk: Optional[int] = None + metadata_output_path: str = "" + retry_policy: BatchRetryPolicy = field(default_factory=BatchRetryPolicy) + extras: Dict[str, Any] = field(default_factory=dict) + credentials: Dict[str, str] = field(default_factory=dict) + + def __post_init__(self) -> None: + self.provider = self.provider.strip().lower() + + if not self.provider: + raise ValueError("provider must be a non-empty string") + if self.max_chunk_bytes < 1: + raise ValueError("max_chunk_bytes must be >= 1") + if self.max_requests_per_chunk is not None and self.max_requests_per_chunk < 1: + raise ValueError("max_requests_per_chunk must be >= 1 when provided") diff --git a/src/mmirage/core/process/batch/__init__.py b/src/mmirage/core/process/batch/__init__.py new file mode 100644 index 0000000..e36dc7d --- /dev/null +++ b/src/mmirage/core/process/batch/__init__.py @@ -0,0 +1,11 @@ +"""Provider-agnostic batch processing contracts and registry.""" + +from mmirage.core.process.batch.adapter import BatchSubmissionAdapter, BatchSubmissionResult +from mmirage.core.process.batch.registry import BatchAdapterFactory, BatchAdapterRegistry + +__all__ = [ + "BatchSubmissionAdapter", + "BatchSubmissionResult", + "BatchAdapterFactory", + "BatchAdapterRegistry", +] diff --git a/src/mmirage/core/process/batch/adapter.py b/src/mmirage/core/process/batch/adapter.py new file mode 100644 index 0000000..29ff1e1 --- /dev/null +++ b/src/mmirage/core/process/batch/adapter.py @@ -0,0 +1,133 @@ +"""Provider-agnostic batch submission adapter contracts. + +Adapters implement translation from internal request payloads into provider +request formats and normalize submission responses into a shared result shape. +""" + +import abc +from dataclasses import dataclass, field +from typing import Any, Dict, Mapping, Sequence, Tuple + +from mmirage.config.batch_provider import BatchProviderConfig + + +@dataclass +class BatchSubmissionResult: + """Normalized result returned by any provider adapter after chunk submission. + + Attributes: + provider_batch_id: Provider-side identifier for the submitted job/batch. + status: Provider submission status normalized to a short string. + submitted_request_count: Number of requests accepted in this submission. + raw_response: Original provider response payload for traceability. + """ + + provider_batch_id: str + status: str + submitted_request_count: int + raw_response: Mapping[str, Any] = field(default_factory=dict) + + +class BatchSubmissionAdapter(abc.ABC): + """Abstract interface for provider-specific batch submission adapters. + + Implementations should be deterministic for request building and byte + estimation so chunk boundaries can be reproduced across retries. + """ + + required_credentials: Tuple[str, ...] = tuple() + + @property + @abc.abstractmethod + def adapter_name(self) -> str: + """Return a stable adapter identity string. + + The identity should remain stable across code changes that preserve + behavior and should change only when semantics diverge. + """ + raise NotImplementedError() + + @property + @abc.abstractmethod + def adapter_version(self) -> str: + """Return the adapter implementation version. + + This value is persisted in metadata artifacts to support auditing and + replay diagnostics across code revisions. + """ + raise NotImplementedError() + + @abc.abstractmethod + def build_request( + self, + custom_id: str, + payload: Mapping[str, Any], + config: BatchProviderConfig, + ) -> Mapping[str, Any]: + """Build a single provider-ready request object. + + Args: + custom_id: Stable request identifier used to map async results back + to source rows. + payload: Provider-neutral request payload assembled by the core + processing layer. + config: Provider configuration contract that may influence request + shaping. + + Returns: + A provider-specific request object represented as a mapping. + """ + raise NotImplementedError() + + @abc.abstractmethod + def estimate_request_bytes(self, request: Mapping[str, Any]) -> int: + """Estimate serialized UTF-8 bytes for a request payload. + + The estimate must match or safely upper-bound the size produced by the + serializer used for submission so chunk boundaries are enforced + correctly. + + Args: + request: Provider request object returned by ``build_request``. + + Returns: + Estimated byte size for the serialized request. + """ + raise NotImplementedError() + + @abc.abstractmethod + def submit_chunk( + self, + chunk_id: str, + requests: Sequence[Mapping[str, Any]], + config: BatchProviderConfig, + ) -> Mapping[str, Any]: + """Submit one pre-chunked request group to the provider. + + Args: + chunk_id: Internal chunk identifier generated by orchestration. + requests: Provider-ready request objects belonging to this chunk. + config: Provider config containing credentials and submission knobs. + + Returns: + Raw provider response payload as a mapping. + """ + raise NotImplementedError() + + @abc.abstractmethod + def parse_submission_result( + self, + raw_result: Mapping[str, Any], + request_count: int, + ) -> BatchSubmissionResult: + """Normalize provider submission output into a shared result model. + + Args: + raw_result: Raw payload returned by ``submit_chunk``. + request_count: Number of requests submitted in the chunk. + + Returns: + A normalized ``BatchSubmissionResult`` for provider-neutral + orchestration and metadata persistence. + """ + raise NotImplementedError() diff --git a/src/mmirage/core/process/batch/registry.py b/src/mmirage/core/process/batch/registry.py new file mode 100644 index 0000000..84b0691 --- /dev/null +++ b/src/mmirage/core/process/batch/registry.py @@ -0,0 +1,65 @@ +"""Registry and factory for provider batch adapters.""" + +from typing import Dict, Type + +from mmirage.config.batch_provider import BatchProviderConfig +from mmirage.core.process.batch.adapter import BatchSubmissionAdapter + + +class BatchAdapterRegistry: + """Provider-to-adapter registry with factory helpers. + + This class centralizes provider registration and fail-fast adapter + instantiation with credential validation. + """ + + _registry: Dict[str, Type[BatchSubmissionAdapter]] = dict() + + @classmethod + def register(cls, provider: str, adapter_cls: Type[BatchSubmissionAdapter]) -> None: + """Register an adapter class under a provider key.""" + provider_key = provider.strip().lower() + if not provider_key: + raise ValueError("provider must be a non-empty string") + cls._registry[provider_key] = adapter_cls + + @classmethod + def clear(cls) -> None: + """Clear all registered adapters. + + Intended for tests and isolated bootstrapping logic. + """ + cls._registry.clear() + + @classmethod + def resolve(cls, provider: str) -> Type[BatchSubmissionAdapter]: + """Resolve a provider key to a registered adapter class.""" + provider_key = provider.strip().lower() + if provider_key not in cls._registry: + raise ValueError( + f"Unknown batch provider '{provider}'. " + f"Available providers: {list(cls._registry.keys())}" + ) + return cls._registry[provider_key] + + @classmethod + def create(cls, config: BatchProviderConfig) -> BatchSubmissionAdapter: + """Instantiate an adapter for a provider config with credential checks.""" + adapter_cls = cls.resolve(config.provider) + missing_credentials = [ + name for name in adapter_cls.required_credentials if not config.credentials.get(name) + ] + if missing_credentials: + raise ValueError( + f"Missing credentials for provider '{config.provider}': {missing_credentials}" + ) + return adapter_cls() + + +class BatchAdapterFactory: + """Compatibility alias around registry-based adapter creation.""" + + @classmethod + def from_config(cls, config: BatchProviderConfig) -> BatchSubmissionAdapter: + """Create an adapter from provider config via registry resolution.""" + return BatchAdapterRegistry.create(config) From f26f80458e767e1f89e1a7877ed55bb540c578a3 Mon Sep 17 00:00:00 2001 From: legstar67 Date: Mon, 23 Mar 2026 01:09:49 +0100 Subject: [PATCH 07/45] pytests for previous implementations --- tests/test_batch_adapter_contracts.py | 130 ++++++++++++++++++++++++++ 1 file changed, 130 insertions(+) create mode 100644 tests/test_batch_adapter_contracts.py diff --git a/tests/test_batch_adapter_contracts.py b/tests/test_batch_adapter_contracts.py new file mode 100644 index 0000000..ce8a520 --- /dev/null +++ b/tests/test_batch_adapter_contracts.py @@ -0,0 +1,130 @@ +import pytest + +from mmirage.config.batch_provider import BatchProviderConfig +from mmirage.core.process.batch.adapter import BatchSubmissionAdapter, BatchSubmissionResult +from mmirage.core.process.batch.registry import BatchAdapterFactory, BatchAdapterRegistry + + +class CompleteTestAdapter(BatchSubmissionAdapter): + required_credentials = tuple() + + @property + def adapter_name(self) -> str: + return "complete-test-adapter" + + @property + def adapter_version(self) -> str: + return "1.0.0" + + def build_request(self, custom_id, payload, config): + return {"custom_id": custom_id, "payload": dict(payload), "provider": config.provider} + + def estimate_request_bytes(self, request): + # Deterministic approximation for tests. + return len(str(request).encode("utf-8")) + + def submit_chunk(self, chunk_id, requests, config): + return { + "batch_id": f"{config.provider}-{chunk_id}", + "status": "submitted", + "requests": len(requests), + } + + def parse_submission_result(self, raw_result, request_count): + return BatchSubmissionResult( + provider_batch_id=str(raw_result["batch_id"]), + status=str(raw_result["status"]), + submitted_request_count=request_count, + raw_response=raw_result, + ) + + +class CredentialedTestAdapter(CompleteTestAdapter): + required_credentials = ("api_key",) + + +class IncompleteTestAdapter(BatchSubmissionAdapter): + @property + def adapter_name(self) -> str: + return "incomplete" + + @property + def adapter_version(self) -> str: + return "0.0.0" + + def build_request(self, custom_id, payload, config): + return {} + + def estimate_request_bytes(self, request): + return 0 + + def submit_chunk(self, chunk_id, requests, config): + return {} + + +@pytest.fixture(autouse=True) +def clear_batch_adapter_registry(): + BatchAdapterRegistry.clear() + yield + BatchAdapterRegistry.clear() + + +def test_adapter_interface_is_abstract(): + with pytest.raises(TypeError): + BatchSubmissionAdapter() + + +def test_incomplete_adapter_fails_interface_compliance(): + with pytest.raises(TypeError): + IncompleteTestAdapter() + + +def test_complete_adapter_is_interface_compliant(): + adapter = CompleteTestAdapter() + config = BatchProviderConfig(provider="unit") + + request = adapter.build_request(custom_id="req-1", payload={"x": 1}, config=config) + assert request["custom_id"] == "req-1" + + estimated_bytes = adapter.estimate_request_bytes(request) + assert estimated_bytes > 0 + + raw_result = adapter.submit_chunk(chunk_id="chunk-1", requests=[request], config=config) + parsed = adapter.parse_submission_result(raw_result=raw_result, request_count=1) + + assert parsed.provider_batch_id == "unit-chunk-1" + assert parsed.status == "submitted" + assert parsed.submitted_request_count == 1 + + +def test_factory_resolves_registered_provider(): + BatchAdapterRegistry.register("unit", CompleteTestAdapter) + config = BatchProviderConfig(provider="unit") + + adapter = BatchAdapterFactory.from_config(config) + + assert isinstance(adapter, CompleteTestAdapter) + + +def test_factory_raises_for_unknown_provider(): + config = BatchProviderConfig(provider="not-registered") + + with pytest.raises(ValueError, match="Unknown batch provider"): + BatchAdapterFactory.from_config(config) + + +def test_factory_raises_for_missing_required_credentials(): + BatchAdapterRegistry.register("unit", CredentialedTestAdapter) + config = BatchProviderConfig(provider="unit", credentials={}) + + with pytest.raises(ValueError, match="Missing credentials"): + BatchAdapterFactory.from_config(config) + + +def test_factory_creates_adapter_when_credentials_are_present(): + BatchAdapterRegistry.register("unit", CredentialedTestAdapter) + config = BatchProviderConfig(provider="unit", credentials={"api_key": "secret"}) + + adapter = BatchAdapterFactory.from_config(config) + + assert isinstance(adapter, CredentialedTestAdapter) From 2003d34b01575b3ec7b38cf19b7f2ea246a54c45 Mon Sep 17 00:00:00 2001 From: legstar67 Date: Mon, 23 Mar 2026 04:57:33 +0100 Subject: [PATCH 08/45] OpenAI API implementation --- src/mmirage/config/openai_batch.py | 44 +++++++ src/mmirage/core/process/batch/__init__.py | 4 + .../core/process/batch/openai_adapter.py | 114 ++++++++++++++++++ src/mmirage/core/process/batch/registry.py | 14 +++ 4 files changed, 176 insertions(+) create mode 100644 src/mmirage/config/openai_batch.py create mode 100644 src/mmirage/core/process/batch/openai_adapter.py diff --git a/src/mmirage/config/openai_batch.py b/src/mmirage/config/openai_batch.py new file mode 100644 index 0000000..65166af --- /dev/null +++ b/src/mmirage/config/openai_batch.py @@ -0,0 +1,44 @@ +"""OpenAI-specific batch configuration.""" + +from dataclasses import dataclass, field +from typing import Any, Dict, Literal, Optional + +from mmirage.config.batch_provider import BatchProviderConfig + + +@dataclass +class OpenAIBatchConfig(BatchProviderConfig): + """OpenAI Batch API configuration. + + Attributes: + provider: Fixed provider identifier for OpenAI. + model: Model name used in each chat completion request body. + batch_endpoint: Target endpoint used by OpenAI batch jobs. + completion_window: OpenAI completion window value. + base_url: Optional base URL, useful for API-compatible gateways. + metadata: Metadata sent on batch creation. + """ + + provider: str = "openai" + model: str = "gpt-4.1-mini" + batch_endpoint: str = "/v1/chat/completions" + completion_window: Literal["24h"] = "24h" + base_url: Optional[str] = None + metadata: Dict[str, str] = field(default_factory=dict) + + def __post_init__(self) -> None: + super().__post_init__() + + if not self.model.strip(): + raise ValueError("model must be a non-empty string") + if not self.batch_endpoint.startswith("/"): + raise ValueError("batch_endpoint must start with '/'") + + # Mirror OpenAI-specific fields into generic extras for provider-neutral consumers. + self.extras.setdefault("model", self.model) + self.extras.setdefault("batch_endpoint", self.batch_endpoint) + self.extras.setdefault("completion_window", self.completion_window) + if self.base_url: + self.extras.setdefault("base_url", self.base_url) + if self.metadata: + self.extras.setdefault("metadata", dict(self.metadata)) diff --git a/src/mmirage/core/process/batch/__init__.py b/src/mmirage/core/process/batch/__init__.py index e36dc7d..fc4702a 100644 --- a/src/mmirage/core/process/batch/__init__.py +++ b/src/mmirage/core/process/batch/__init__.py @@ -1,11 +1,15 @@ """Provider-agnostic batch processing contracts and registry.""" from mmirage.core.process.batch.adapter import BatchSubmissionAdapter, BatchSubmissionResult +from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter from mmirage.core.process.batch.registry import BatchAdapterFactory, BatchAdapterRegistry +from mmirage.config.openai_batch import OpenAIBatchConfig __all__ = [ "BatchSubmissionAdapter", "BatchSubmissionResult", + "OpenAIBatchAdapter", + "OpenAIBatchConfig", "BatchAdapterFactory", "BatchAdapterRegistry", ] diff --git a/src/mmirage/core/process/batch/openai_adapter.py b/src/mmirage/core/process/batch/openai_adapter.py new file mode 100644 index 0000000..2046d06 --- /dev/null +++ b/src/mmirage/core/process/batch/openai_adapter.py @@ -0,0 +1,114 @@ +"""Concrete OpenAI implementation of batch submission contracts.""" + +import io +import json +from typing import Any, Mapping, Sequence + +from openai import OpenAI + +from mmirage.config.batch_provider import BatchProviderConfig +from mmirage.config.openai_batch import OpenAIBatchConfig +from mmirage.core.process.batch.adapter import BatchSubmissionAdapter, BatchSubmissionResult + + +class OpenAIBatchAdapter(BatchSubmissionAdapter): + """Provider adapter for OpenAI Batch API.""" + + required_credentials = ("api_key",) + + @property + def adapter_name(self) -> str: + return "openai-batch-adapter" + + @property + def adapter_version(self) -> str: + return "1.0.0" + + def build_request( + self, + custom_id: str, + payload: Mapping[str, Any], + config: BatchProviderConfig, + ) -> Mapping[str, Any]: + openai_config = self._as_openai_config(config) + body = dict(payload) + body.setdefault("model", openai_config.model) + + return { + "custom_id": custom_id, + "method": "POST", + "url": openai_config.batch_endpoint, + "body": body, + } + + def estimate_request_bytes(self, request: Mapping[str, Any]) -> int: + serialized = json.dumps(request, ensure_ascii=False, separators=(",", ":")) + return len(serialized.encode("utf-8")) + + def submit_chunk( + self, + chunk_id: str, + requests: Sequence[Mapping[str, Any]], + config: BatchProviderConfig, + ) -> Mapping[str, Any]: + openai_config = self._as_openai_config(config) + + client_kwargs = {"api_key": openai_config.credentials.get("api_key", "")} + if openai_config.base_url: + client_kwargs["base_url"] = openai_config.base_url + client = OpenAI(**client_kwargs) + + jsonl_lines = [ + json.dumps(req, ensure_ascii=False, separators=(",", ":")) for req in requests + ] + jsonl_payload = "\n".join(jsonl_lines).encode("utf-8") + + file_response = client.files.create( + file=(f"batch_chunk-{chunk_id}.jsonl", io.BytesIO(jsonl_payload)), + purpose="batch", + ) + + metadata = dict(openai_config.metadata) + metadata["chunk_id"] = chunk_id + + batch_response = client.batches.create( + input_file_id=self._read_attr(file_response, "id"), + endpoint=openai_config.batch_endpoint, + completion_window=openai_config.completion_window, + metadata=metadata, + ) + + return { + "id": self._read_attr(batch_response, "id"), + "status": self._read_attr(batch_response, "status"), + "endpoint": self._read_attr(batch_response, "endpoint"), + "input_file_id": self._read_attr(file_response, "id"), + "chunk_id": chunk_id, + } + + def parse_submission_result( + self, + raw_result: Mapping[str, Any], + request_count: int, + ) -> BatchSubmissionResult: + batch_id = str(raw_result.get("id") or raw_result.get("batch_id") or "") + status = str(raw_result.get("status") or "unknown") + + return BatchSubmissionResult( + provider_batch_id=batch_id, + status=status, + submitted_request_count=request_count, + raw_response=dict(raw_result), + ) + + @staticmethod + def _as_openai_config(config: BatchProviderConfig) -> OpenAIBatchConfig: + if isinstance(config, OpenAIBatchConfig): + return config + raise TypeError("OpenAIBatchAdapter requires OpenAIBatchConfig") + + @staticmethod + def _read_attr(obj: Any, key: str) -> Any: + if isinstance(obj, Mapping): + return obj.get(key) + return getattr(obj, key) diff --git a/src/mmirage/core/process/batch/registry.py b/src/mmirage/core/process/batch/registry.py index 84b0691..4267375 100644 --- a/src/mmirage/core/process/batch/registry.py +++ b/src/mmirage/core/process/batch/registry.py @@ -14,6 +14,18 @@ class BatchAdapterRegistry: """ _registry: Dict[str, Type[BatchSubmissionAdapter]] = dict() + _bootstrapped: bool = False + + @classmethod + def _bootstrap_builtin_adapters(cls) -> None: + if cls._bootstrapped: + return + + # Local import avoids import cycles while ensuring built-ins are available. + from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter + + cls.register("openai", OpenAIBatchAdapter) + cls._bootstrapped = True @classmethod def register(cls, provider: str, adapter_cls: Type[BatchSubmissionAdapter]) -> None: @@ -30,10 +42,12 @@ def clear(cls) -> None: Intended for tests and isolated bootstrapping logic. """ cls._registry.clear() + cls._bootstrapped = False @classmethod def resolve(cls, provider: str) -> Type[BatchSubmissionAdapter]: """Resolve a provider key to a registered adapter class.""" + cls._bootstrap_builtin_adapters() provider_key = provider.strip().lower() if provider_key not in cls._registry: raise ValueError( From 95ae28cebcfe6e3e96edec5ec89a141c4380fe27 Mon Sep 17 00:00:00 2001 From: legstar67 Date: Mon, 23 Mar 2026 14:06:43 +0100 Subject: [PATCH 09/45] unit tests for OpenAI implementation --- tests/test_openai_batch_adapter.py | 153 +++++++++++++++++++++++++++++ 1 file changed, 153 insertions(+) create mode 100644 tests/test_openai_batch_adapter.py diff --git a/tests/test_openai_batch_adapter.py b/tests/test_openai_batch_adapter.py new file mode 100644 index 0000000..59caeeb --- /dev/null +++ b/tests/test_openai_batch_adapter.py @@ -0,0 +1,153 @@ +import json + +import pytest + +from mmirage.config.openai_batch import OpenAIBatchConfig +from mmirage.core.process.batch.adapter import BatchSubmissionResult +from mmirage.core.process.batch.registry import BatchAdapterFactory, BatchAdapterRegistry + + +def test_openai_build_request_matches_expected_structure(): + from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter + + config = OpenAIBatchConfig(model="gpt-4.1-mini") + adapter = OpenAIBatchAdapter() + payload = { + "messages": [{"role": "user", "content": "hello"}], + "temperature": 0, + } + + request = adapter.build_request(custom_id="row-001", payload=payload, config=config) + + assert request == { + "custom_id": "row-001", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": "gpt-4.1-mini", + "messages": [{"role": "user", "content": "hello"}], + "temperature": 0, + }, + } + + +def test_openai_estimate_request_bytes_matches_utf8_json_size(): + from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter + + adapter = OpenAIBatchAdapter() + request = { + "custom_id": "accented", + "method": "POST", + "url": "/v1/chat/completions", + "body": {"message": "caf\u00e9"}, + } + + expected = len( + json.dumps(request, ensure_ascii=False, separators=(",", ":")).encode("utf-8") + ) + + assert adapter.estimate_request_bytes(request) == expected + + +def test_openai_submit_chunk_uses_mocked_openai_client(monkeypatch): + from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter + + captured = {} + + class FakeFiles: + def create(self, *, file, purpose): + file_name, file_obj = file + assert file_name == "batch_chunk-chunk-01.jsonl" + assert purpose == "batch" + file_content = file_obj.read().decode("utf-8") + captured["jsonl"] = file_content + + class _FileResp: + id = "file_123" + + return _FileResp() + + class FakeBatches: + def create(self, **kwargs): + captured["batch_create_kwargs"] = kwargs + + class _BatchResp: + id = "batch_123" + status = "validating" + endpoint = kwargs["endpoint"] + + return _BatchResp() + + class FakeClient: + def __init__(self, **kwargs): + captured["client_kwargs"] = kwargs + self.files = FakeFiles() + self.batches = FakeBatches() + + monkeypatch.setattr( + "mmirage.core.process.batch.openai_adapter.OpenAI", + FakeClient, + ) + + config = OpenAIBatchConfig( + model="gpt-4.1-mini", + completion_window="24h", + batch_endpoint="/v1/chat/completions", + metadata={"pipeline": "unit"}, + credentials={"api_key": "test-key"}, + ) + adapter = OpenAIBatchAdapter() + requests = [ + adapter.build_request( + custom_id="r1", + payload={"messages": [{"role": "user", "content": "Hi"}]}, + config=config, + ) + ] + + raw_result = adapter.submit_chunk(chunk_id="chunk-01", requests=requests, config=config) + + assert captured["client_kwargs"]["api_key"] == "test-key" + assert captured["batch_create_kwargs"] == { + "input_file_id": "file_123", + "endpoint": "/v1/chat/completions", + "completion_window": "24h", + "metadata": {"pipeline": "unit", "chunk_id": "chunk-01"}, + } + assert raw_result["id"] == "batch_123" + assert raw_result["status"] == "validating" + assert raw_result["input_file_id"] == "file_123" + + jsonl_lines = [line for line in captured["jsonl"].split("\n") if line.strip()] + assert len(jsonl_lines) == 1 + assert json.loads(jsonl_lines[0]) == requests[0] + + +def test_openai_parse_submission_result_normalizes_payload(): + from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter + + adapter = OpenAIBatchAdapter() + raw = { + "id": "batch_123", + "status": "in_progress", + "input_file_id": "file_123", + } + + result = adapter.parse_submission_result(raw_result=raw, request_count=4) + + assert isinstance(result, BatchSubmissionResult) + assert result.provider_batch_id == "batch_123" + assert result.status == "in_progress" + assert result.submitted_request_count == 4 + assert result.raw_response == raw + + +def test_factory_resolves_openai_adapter_from_registry(): + from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter + + BatchAdapterRegistry.clear() + config = OpenAIBatchConfig(model="gpt-4.1-mini", credentials={"api_key": "key"}) + + adapter = BatchAdapterFactory.from_config(config) + + assert isinstance(adapter, OpenAIBatchAdapter) From 4dd199f249f8e4a2f24c8a94af82291199599f50 Mon Sep 17 00:00:00 2001 From: legstar67 Date: Wed, 25 Mar 2026 18:23:04 +0100 Subject: [PATCH 10/45] chunking by byte implementation + its unit tests --- src/mmirage/core/process/batch/chunking.py | 105 +++++++++++++++++ tests/test_batch_chunking.py | 131 +++++++++++++++++++++ 2 files changed, 236 insertions(+) create mode 100644 src/mmirage/core/process/batch/chunking.py create mode 100644 tests/test_batch_chunking.py diff --git a/src/mmirage/core/process/batch/chunking.py b/src/mmirage/core/process/batch/chunking.py new file mode 100644 index 0000000..b490260 --- /dev/null +++ b/src/mmirage/core/process/batch/chunking.py @@ -0,0 +1,105 @@ +"""Provider-agnostic request chunking utilities for batch submission.""" + +import logging +from dataclasses import dataclass +from typing import Any, List, Mapping, Sequence + +from mmirage.config.batch_provider import BatchProviderConfig +from mmirage.core.process.batch.adapter import BatchSubmissionAdapter + +logger = logging.getLogger(__name__) + + +@dataclass +class RequestChunk: + """Chunk of provider-ready requests with aggregate metadata.""" + + requests: List[Mapping[str, Any]] + total_bytes: int + has_oversized_request: bool = False + + @property + def total_requests(self) -> int: + return len(self.requests) + + +class BatchRequestChunker: + """Split request sequences into chunks using serialized-byte limits.""" + + def __init__(self, adapter: BatchSubmissionAdapter, config: BatchProviderConfig) -> None: + self.adapter = adapter + self.config = config + + def chunk_requests(self, requests: Sequence[Mapping[str, Any]]) -> List[RequestChunk]: + """Chunk requests according to max bytes, max requests, and oversize policy.""" + + chunks: List[RequestChunk] = [] + current_requests: List[Mapping[str, Any]] = [] + current_total_bytes = 0 + max_chunk_bytes = self.config.max_chunk_bytes + + for request in requests: + request_size = self.adapter.estimate_request_bytes(request) + + if request_size > max_chunk_bytes: + if self.config.oversized_request_policy == "reject": + raise ValueError( + "Encountered oversized request: " + f"{request_size} bytes exceeds max_chunk_bytes={max_chunk_bytes}" + ) + + logger.warning( + "Encountered oversized request (%s bytes > %s); isolating into its own chunk.", + request_size, + max_chunk_bytes, + ) + + if current_requests: + chunks.append( + RequestChunk( + requests=list(current_requests), + total_bytes=current_total_bytes, + ) + ) + current_requests = [] + current_total_bytes = 0 + + chunks.append( + RequestChunk( + requests=[request], + total_bytes=request_size, + has_oversized_request=True, + ) + ) + continue + + would_exceed_bytes = current_total_bytes + request_size > max_chunk_bytes + would_exceed_count = self._would_exceed_count_limit(current_requests) + + if current_requests and (would_exceed_bytes or would_exceed_count): + chunks.append( + RequestChunk( + requests=list(current_requests), + total_bytes=current_total_bytes, + ) + ) + current_requests = [] + current_total_bytes = 0 + + current_requests.append(request) + current_total_bytes += request_size + + if current_requests: + chunks.append( + RequestChunk( + requests=list(current_requests), + total_bytes=current_total_bytes, + ) + ) + + return chunks + + def _would_exceed_count_limit(self, current_requests: Sequence[Mapping[str, Any]]) -> bool: + if self.config.max_requests_per_chunk is None: + return False + return len(current_requests) >= self.config.max_requests_per_chunk diff --git a/tests/test_batch_chunking.py b/tests/test_batch_chunking.py new file mode 100644 index 0000000..075b97f --- /dev/null +++ b/tests/test_batch_chunking.py @@ -0,0 +1,131 @@ +import pytest + +from mmirage.config.batch_provider import BatchProviderConfig +from mmirage.core.process.batch.adapter import BatchSubmissionAdapter, BatchSubmissionResult + + +class SizeAwareTestAdapter(BatchSubmissionAdapter): + def __init__(self) -> None: + self.estimate_calls = [] + + @property + def adapter_name(self) -> str: + return "size-aware-test-adapter" + + @property + def adapter_version(self) -> str: + return "1.0.0" + + def build_request(self, custom_id, payload, config): + return {"custom_id": custom_id, **dict(payload)} + + def estimate_request_bytes(self, request): + size = int(request["size_bytes"]) + self.estimate_calls.append(size) + return size + + def submit_chunk(self, chunk_id, requests, config): + return {"id": chunk_id, "status": "submitted"} + + def parse_submission_result(self, raw_result, request_count): + return BatchSubmissionResult( + provider_batch_id=str(raw_result["id"]), + status=str(raw_result["status"]), + submitted_request_count=request_count, + raw_response=raw_result, + ) + + +def _sizes_from_chunks(chunks): + return [[request["size_bytes"] for request in chunk.requests] for chunk in chunks] + + +def test_chunker_splits_when_byte_limit_is_reached(): + from mmirage.core.process.batch.chunking import BatchRequestChunker + + adapter = SizeAwareTestAdapter() + config = BatchProviderConfig(provider="unit", max_chunk_bytes=10) + requests = [ + {"custom_id": "r1", "size_bytes": 4}, + {"custom_id": "r2", "size_bytes": 4}, + {"custom_id": "r3", "size_bytes": 4}, + ] + + chunks = BatchRequestChunker(adapter, config).chunk_requests(requests) + + assert _sizes_from_chunks(chunks) == [[4, 4], [4]] + assert [chunk.total_bytes for chunk in chunks] == [8, 4] + assert adapter.estimate_calls == [4, 4, 4] + + +def test_chunker_splits_when_max_requests_per_chunk_is_reached(): + from mmirage.core.process.batch.chunking import BatchRequestChunker + + adapter = SizeAwareTestAdapter() + config = BatchProviderConfig( + provider="unit", + max_chunk_bytes=10_000, + max_requests_per_chunk=2, + ) + requests = [ + {"custom_id": "r1", "size_bytes": 1}, + {"custom_id": "r2", "size_bytes": 1}, + {"custom_id": "r3", "size_bytes": 1}, + {"custom_id": "r4", "size_bytes": 1}, + {"custom_id": "r5", "size_bytes": 1}, + ] + + chunks = BatchRequestChunker(adapter, config).chunk_requests(requests) + + assert _sizes_from_chunks(chunks) == [[1, 1], [1, 1], [1]] + assert [chunk.total_requests for chunk in chunks] == [2, 2, 1] + + +def test_chunker_honors_exact_byte_boundary_without_flushing_early(): + from mmirage.core.process.batch.chunking import BatchRequestChunker + + adapter = SizeAwareTestAdapter() + config = BatchProviderConfig(provider="unit", max_chunk_bytes=10) + requests = [ + {"custom_id": "r1", "size_bytes": 6}, + {"custom_id": "r2", "size_bytes": 4}, + {"custom_id": "r3", "size_bytes": 1}, + ] + + chunks = BatchRequestChunker(adapter, config).chunk_requests(requests) + + assert _sizes_from_chunks(chunks) == [[6, 4], [1]] + assert [chunk.total_bytes for chunk in chunks] == [10, 1] + + +def test_chunker_isolates_oversized_single_request_by_default(caplog): + from mmirage.core.process.batch.chunking import BatchRequestChunker + + adapter = SizeAwareTestAdapter() + config = BatchProviderConfig(provider="unit", max_chunk_bytes=10) + requests = [ + {"custom_id": "r1", "size_bytes": 3}, + {"custom_id": "r2", "size_bytes": 25}, + {"custom_id": "r3", "size_bytes": 3}, + ] + + chunks = BatchRequestChunker(adapter, config).chunk_requests(requests) + + assert _sizes_from_chunks(chunks) == [[3], [25], [3]] + assert [chunk.has_oversized_request for chunk in chunks] == [False, True, False] + assert "oversized request" in caplog.text.lower() + + +def test_chunker_rejects_oversized_single_request_when_policy_is_reject(): + from mmirage.core.process.batch.chunking import BatchRequestChunker + + adapter = SizeAwareTestAdapter() + config = BatchProviderConfig( + provider="unit", + max_chunk_bytes=10, + oversized_request_policy="reject", + ) + requests = [{"custom_id": "r1", "size_bytes": 11}] + + with pytest.raises(ValueError, match="oversized request"): + BatchRequestChunker(adapter, config).chunk_requests(requests) From 6964db8923267dce0da19ad8148472371a28aa12 Mon Sep 17 00:00:00 2001 From: legstar67 Date: Wed, 25 Mar 2026 21:50:33 +0100 Subject: [PATCH 11/45] implementation of the provider-agnostic batch orchestrator and metadata logging + integrate the byte-chunking engine into the main llmprocessor + add a buffer over the iteration of map() to bridge row batches parameter of the function and the API byte limit without send non-full batches to the API + add metadata receipt generation to collect output in a asynchronous way later --- src/mmirage/config/batch_provider.py | 8 +- src/mmirage/core/process/batch/__init__.py | 5 + .../core/process/batch/orchestrator.py | 187 +++++++++++++++++ src/mmirage/core/process/mapper.py | 7 + .../core/process/processors/llm/config.py | 2 + .../process/processors/llm/llm_processor.py | 195 +++++++++++++++++- src/mmirage/shard_process.py | 2 + 7 files changed, 403 insertions(+), 3 deletions(-) create mode 100644 src/mmirage/core/process/batch/orchestrator.py diff --git a/src/mmirage/config/batch_provider.py b/src/mmirage/config/batch_provider.py index 0256787..78c399e 100644 --- a/src/mmirage/config/batch_provider.py +++ b/src/mmirage/config/batch_provider.py @@ -5,7 +5,7 @@ """ from dataclasses import dataclass, field -from typing import Any, Dict, Optional +from typing import Any, Dict, Literal, Optional @dataclass @@ -48,6 +48,9 @@ class BatchProviderConfig: chunk. If None, no request-count cap is enforced. metadata_output_path: Path where submission metadata artifacts are saved. retry_policy: Retry policy used by the shared batch layer. + oversized_request_policy: Handling policy when a single request exceeds + ``max_chunk_bytes``. ``isolate`` creates a dedicated oversized + chunk, while ``reject`` fails fast. extras: Provider-specific knobs that do not belong in the shared fields. credentials: Provider credentials required to submit chunks. """ @@ -58,6 +61,7 @@ class BatchProviderConfig: max_requests_per_chunk: Optional[int] = None metadata_output_path: str = "" retry_policy: BatchRetryPolicy = field(default_factory=BatchRetryPolicy) + oversized_request_policy: Literal["isolate", "reject"] = "isolate" extras: Dict[str, Any] = field(default_factory=dict) credentials: Dict[str, str] = field(default_factory=dict) @@ -70,3 +74,5 @@ def __post_init__(self) -> None: raise ValueError("max_chunk_bytes must be >= 1") if self.max_requests_per_chunk is not None and self.max_requests_per_chunk < 1: raise ValueError("max_requests_per_chunk must be >= 1 when provided") + if self.oversized_request_policy not in {"isolate", "reject"}: + raise ValueError("oversized_request_policy must be either 'isolate' or 'reject'") diff --git a/src/mmirage/core/process/batch/__init__.py b/src/mmirage/core/process/batch/__init__.py index fc4702a..15fe576 100644 --- a/src/mmirage/core/process/batch/__init__.py +++ b/src/mmirage/core/process/batch/__init__.py @@ -1,13 +1,18 @@ """Provider-agnostic batch processing contracts and registry.""" from mmirage.core.process.batch.adapter import BatchSubmissionAdapter, BatchSubmissionResult +from mmirage.core.process.batch.chunking import BatchRequestChunker, RequestChunk from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter +from mmirage.core.process.batch.orchestrator import BatchSubmissionOrchestrator from mmirage.core.process.batch.registry import BatchAdapterFactory, BatchAdapterRegistry from mmirage.config.openai_batch import OpenAIBatchConfig __all__ = [ "BatchSubmissionAdapter", "BatchSubmissionResult", + "BatchRequestChunker", + "RequestChunk", + "BatchSubmissionOrchestrator", "OpenAIBatchAdapter", "OpenAIBatchConfig", "BatchAdapterFactory", diff --git a/src/mmirage/core/process/batch/orchestrator.py b/src/mmirage/core/process/batch/orchestrator.py new file mode 100644 index 0000000..c2a4f88 --- /dev/null +++ b/src/mmirage/core/process/batch/orchestrator.py @@ -0,0 +1,187 @@ +"""Stateful provider-agnostic orchestration for batch submission.""" + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime, timezone +import hashlib +import json +import os +from typing import Any, Dict, List, Mapping, Optional, Sequence + +from mmirage.config.batch_provider import BatchProviderConfig +from mmirage.core.process.batch.adapter import BatchSubmissionAdapter, BatchSubmissionResult +from mmirage.core.process.batch.chunking import BatchRequestChunker, RequestChunk + + +@dataclass +class _PendingRequest: + request: Mapping[str, Any] + source_index: int + + +class BatchSubmissionOrchestrator: + """Accumulate requests across map iterations and submit full-ready chunks.""" + + def __init__(self, adapter: BatchSubmissionAdapter, config: BatchProviderConfig) -> None: + self.adapter = adapter + self.config = config + self.chunker = BatchRequestChunker(adapter=adapter, config=config) + self._pending: List[_PendingRequest] = [] + self._chunk_counter = 0 + + @property + def pending_count(self) -> int: + return len(self._pending) + + def add_requests( + self, + requests: Sequence[Mapping[str, Any]], + source_indices: Sequence[int], + model_params_snapshot: Optional[Mapping[str, Any]] = None, + ) -> List[BatchSubmissionResult]: + """Append requests and submit only chunks that are ready mid-stream.""" + if len(requests) != len(source_indices): + raise ValueError("requests and source_indices must have identical lengths") + + for request, source_index in zip(requests, source_indices): + self._pending.append(_PendingRequest(request=request, source_index=source_index)) + + return self._emit_ready_chunks( + finalize=False, + model_params_snapshot=model_params_snapshot, + ) + + def finalize( + self, + model_params_snapshot: Optional[Mapping[str, Any]] = None, + ) -> List[BatchSubmissionResult]: + """Flush all remaining requests at end-of-dataset lifecycle.""" + return self._emit_ready_chunks( + finalize=True, + model_params_snapshot=model_params_snapshot, + ) + + def _emit_ready_chunks( + self, + finalize: bool, + model_params_snapshot: Optional[Mapping[str, Any]], + ) -> List[BatchSubmissionResult]: + if not self._pending: + return [] + + pending_requests = [entry.request for entry in self._pending] + chunks = self.chunker.chunk_requests(pending_requests) + chunk_groups = self._split_pending_entries_by_chunks(chunks) + + groups_to_submit: List[tuple[List[_PendingRequest], RequestChunk]] = [] + groups_to_keep: List[_PendingRequest] = [] + + if finalize: + groups_to_submit = chunk_groups + elif chunk_groups: + groups_to_submit = chunk_groups[:-1] + tail_entries, tail_chunk = chunk_groups[-1] + if self._is_complete_chunk(tail_chunk): + groups_to_submit.append((tail_entries, tail_chunk)) + else: + groups_to_keep = list(tail_entries) + + self._pending = groups_to_keep + + submission_results: List[BatchSubmissionResult] = [] + for chunk_entries, request_chunk in groups_to_submit: + chunk_id = self._next_chunk_id() + raw_result = self.adapter.submit_chunk( + chunk_id=chunk_id, + requests=[entry.request for entry in chunk_entries], + config=self.config, + ) + parsed_result = self.adapter.parse_submission_result( + raw_result=raw_result, + request_count=len(chunk_entries), + ) + submission_results.append(parsed_result) + + self._persist_metadata( + chunk_id=chunk_id, + chunk_entries=chunk_entries, + chunk=request_chunk, + parsed_result=parsed_result, + model_params_snapshot=model_params_snapshot, + flush_reason="finalize" if finalize else "full_chunk", + ) + + return submission_results + + def _split_pending_entries_by_chunks( + self, + chunks: Sequence[RequestChunk], + ) -> List[tuple[List[_PendingRequest], RequestChunk]]: + grouped: List[tuple[List[_PendingRequest], RequestChunk]] = [] + cursor = 0 + for chunk in chunks: + size = len(chunk.requests) + grouped.append((self._pending[cursor : cursor + size], chunk)) + cursor += size + return grouped + + def _is_complete_chunk(self, chunk: RequestChunk) -> bool: + if chunk.has_oversized_request: + return True + if chunk.total_bytes >= self.config.max_chunk_bytes: + return True + if self.config.max_requests_per_chunk is not None: + return chunk.total_requests >= self.config.max_requests_per_chunk + return False + + def _next_chunk_id(self) -> str: + self._chunk_counter += 1 + return f"chunk-{self._chunk_counter:06d}" + + def _persist_metadata( + self, + chunk_id: str, + chunk_entries: Sequence[_PendingRequest], + chunk: RequestChunk, + parsed_result: BatchSubmissionResult, + model_params_snapshot: Optional[Mapping[str, Any]], + flush_reason: str, + ) -> None: + if not self.config.metadata_output_path: + return + + custom_to_source = { + str(entry.request.get("custom_id", f"idx-{entry.source_index}")): entry.source_index + for entry in chunk_entries + } + + request_hash = hashlib.sha256( + json.dumps( + [entry.request for entry in chunk_entries], + ensure_ascii=False, + sort_keys=True, + separators=(",", ":"), + ).encode("utf-8") + ).hexdigest() + + metadata_record: Dict[str, Any] = { + "provider": self.config.provider, + "adapter_version": self.adapter.adapter_version, + "chunk_id": chunk_id, + "provider_batch_id": parsed_result.provider_batch_id, + "status": parsed_result.status, + "custom_id_to_source_index": custom_to_source, + "request_hash": request_hash, + "model_params_snapshot": dict(model_params_snapshot or {}), + "submitted_request_count": chunk.total_requests, + "total_bytes": chunk.total_bytes, + "has_oversized_request": chunk.has_oversized_request, + "flush_reason": flush_reason, + "submitted_at_utc": datetime.now(timezone.utc).isoformat(), + } + + metadata_path = self.config.metadata_output_path + os.makedirs(os.path.dirname(metadata_path) or ".", exist_ok=True) + with open(metadata_path, "a", encoding="utf-8") as f: + f.write(json.dumps(metadata_record, ensure_ascii=False) + "\n") diff --git a/src/mmirage/core/process/mapper.py b/src/mmirage/core/process/mapper.py index 5310150..69573a3 100644 --- a/src/mmirage/core/process/mapper.py +++ b/src/mmirage/core/process/mapper.py @@ -103,3 +103,10 @@ def rewrite_batch( ) return batch_environment + + def finalize_processors(self) -> None: + """Finalize processors that expose a finalize lifecycle hook.""" + for processor in self.processors.values(): + finalize = getattr(processor, "finalize", None) + if callable(finalize): + finalize() diff --git a/src/mmirage/core/process/processors/llm/config.py b/src/mmirage/core/process/processors/llm/config.py index 859552d..762d095 100644 --- a/src/mmirage/core/process/processors/llm/config.py +++ b/src/mmirage/core/process/processors/llm/config.py @@ -77,11 +77,13 @@ class SGLangLLMConfig(BaseProcessorConfig): server_args: SGLang server arguments including model path and TP size. default_sampling_params: Default sampling parameters for generation. chat_template: Chat template name for vision-language models (e.g., "qwen2-vl"). + batch_provider: Optional provider batch settings for async submission. """ server_args: SGLangServerArgs = field(default_factory=SGLangServerArgs) default_sampling_params: Dict[str, Any] = field(default_factory=dict) chat_template: str = "" # Empty means use tokenizer's default + batch_provider: Dict[str, Any] = field(default_factory=dict) @dataclass diff --git a/src/mmirage/core/process/processors/llm/llm_processor.py b/src/mmirage/core/process/processors/llm/llm_processor.py index 56afd17..08fc7ec 100644 --- a/src/mmirage/core/process/processors/llm/llm_processor.py +++ b/src/mmirage/core/process/processors/llm/llm_processor.py @@ -2,18 +2,21 @@ from __future__ import annotations -from dataclasses import asdict +from dataclasses import asdict, replace import json import logging -from typing import Any, List, Tuple +from typing import Any, Dict, List, Optional, Tuple import jinja2 import sglang as sgl from transformers import AutoTokenizer from mmirage.core.process.base import BaseProcessor, ProcessorRegistry +from mmirage.core.process.batch.orchestrator import BatchSubmissionOrchestrator +from mmirage.core.process.batch.registry import BatchAdapterFactory from mmirage.core.process.processors.llm.config import LLMOutputVar, SGLangLLMConfig from mmirage.core.process.variables import VariableEnvironment +from mmirage.config.openai_batch import OpenAIBatchConfig try: from typing import override # Python 3.12+ @@ -65,6 +68,65 @@ def __init__(self, engine_args: SGLangLLMConfig, **kwargs) -> None: ) self.sampling_params = engine_args.default_sampling_params self.chat_template = engine_args.chat_template + self._batch_adapter = None + self._batch_provider_config = None + self._text_orchestrator: Optional[BatchSubmissionOrchestrator] = None + self._multimodal_orchestrator: Optional[BatchSubmissionOrchestrator] = None + self._batch_request_counter = 0 + self._setup_batch_runtime() + + def _setup_batch_runtime(self) -> None: + provider_cfg_raw = dict(getattr(self.config, "batch_provider", {}) or {}) + if not provider_cfg_raw: + return + + if not provider_cfg_raw.get("enabled", False): + return + + provider = str(provider_cfg_raw.get("provider", "openai")).strip().lower() + if provider != "openai": + raise ValueError( + f"Only provider='openai' is currently supported, got '{provider}'." + ) + + openai_cfg = OpenAIBatchConfig(**provider_cfg_raw) + self._batch_provider_config = openai_cfg + self._batch_adapter = BatchAdapterFactory.from_config(openai_cfg) + + self._text_orchestrator = BatchSubmissionOrchestrator( + adapter=self._batch_adapter, + config=replace( + openai_cfg, + metadata_output_path=self._with_metadata_suffix( + openai_cfg.metadata_output_path, "text" + ), + ), + ) + self._multimodal_orchestrator = BatchSubmissionOrchestrator( + adapter=self._batch_adapter, + config=replace( + openai_cfg, + metadata_output_path=self._with_metadata_suffix( + openai_cfg.metadata_output_path, "multimodal" + ), + ), + ) + + @staticmethod + def _with_metadata_suffix(path: str, suffix: str) -> str: + if not path: + return "" + if path.endswith(".jsonl"): + return path[:-6] + f".{suffix}.jsonl" + return f"{path}.{suffix}.jsonl" + + @property + def batch_mode_enabled(self) -> bool: + return self._text_orchestrator is not None and self._multimodal_orchestrator is not None + + def _next_custom_id(self, output_name: str, global_index: int, modality: str) -> str: + self._batch_request_counter += 1 + return f"{output_name}:{modality}:{self._batch_request_counter}:{global_index}" def build_prompt( self, prompt_template: str, vars_samples: List[VariableEnvironment] @@ -148,6 +210,9 @@ def batch_process_sample( """ nb_samples = len(batch) + if self.batch_mode_enabled: + return self._batch_process_sample(batch=batch, output_var=output_var) + # Prepare sampling params sampling_params_output = self.sampling_params.copy() @@ -271,6 +336,132 @@ def batch_process_sample( return [results[i] for i in range(nb_samples)] + def _batch_process_sample( + self, + batch: List[VariableEnvironment], + output_var: LLMOutputVar, + ) -> List[VariableEnvironment]: + assert self._batch_provider_config is not None + assert self._batch_adapter is not None + assert self._text_orchestrator is not None + assert self._multimodal_orchestrator is not None + + nb_samples = len(batch) + text_only_indices: List[int] = [] + multimodal_indices: List[int] = [] + for i in range(nb_samples): + if batch[i].has_images(): + multimodal_indices.append(i) + else: + text_only_indices.append(i) + + if text_only_indices: + text_only_envs = [batch[i] for i in text_only_indices] + prompts = self.build_prompt(output_var.prompt, text_only_envs) + requests: List[Dict[str, Any]] = [] + source_indices: List[int] = [] + for local_i, global_i in enumerate(text_only_indices): + payload = { + "messages": [ + { + "role": "user", + "content": prompts[local_i], + } + ] + } + custom_id = self._next_custom_id(output_var.name, global_i, "text") + request = self._batch_adapter.build_request( + custom_id=custom_id, + payload=payload, + config=self._batch_provider_config, + ) + requests.append(dict(request)) + source_indices.append(global_i) + + self._text_orchestrator.add_requests( + requests=requests, + source_indices=source_indices, + model_params_snapshot={ + "output_name": output_var.name, + "output_type": output_var.output_type, + "modality": "text", + }, + ) + + if multimodal_indices: + requests = [] + source_indices = [] + for global_i in multimodal_indices: + base_prompt, image_data = self.build_multimodal_prompt(output_var.prompt, batch[global_i]) + content: List[Dict[str, Any]] = [{"type": "text", "text": base_prompt}] + + if image_data is not None: + if isinstance(image_data, list): + images = image_data + else: + images = [image_data] + for image_ref in images: + content.append( + { + "type": "image_url", + "image_url": {"url": str(image_ref)}, + } + ) + + payload = { + "messages": [ + { + "role": "user", + "content": content, + } + ] + } + custom_id = self._next_custom_id(output_var.name, global_i, "multimodal") + request = self._batch_adapter.build_request( + custom_id=custom_id, + payload=payload, + config=self._batch_provider_config, + ) + requests.append(dict(request)) + source_indices.append(global_i) + + self._multimodal_orchestrator.add_requests( + requests=requests, + source_indices=source_indices, + model_params_snapshot={ + "output_name": output_var.name, + "output_type": output_var.output_type, + "modality": "multimodal", + }, + ) + + placeholders: List[VariableEnvironment] = [] + for i in range(nb_samples): + placeholder = f"__BATCH_SUBMITTED__:{output_var.name}:{i}" + placeholders.append(batch[i].with_variable(output_var.name, placeholder)) + + return placeholders + + def finalize(self) -> None: + if not self.batch_mode_enabled: + return + + assert self._text_orchestrator is not None + assert self._multimodal_orchestrator is not None + + self._text_orchestrator.finalize( + model_params_snapshot={ + "modality": "text", + "phase": "finalize", + } + ) + self._multimodal_orchestrator.finalize( + model_params_snapshot={ + "modality": "multimodal", + "phase": "finalize", + } + ) + def shutdown(self) -> None: """Shutdown the LLM engine.""" try: diff --git a/src/mmirage/shard_process.py b/src/mmirage/shard_process.py index f232dc6..a656e79 100644 --- a/src/mmirage/shard_process.py +++ b/src/mmirage/shard_process.py @@ -142,6 +142,8 @@ def main(): fn_kwargs={"mapper": mapper, "renderer": renderer, "image_base_path": ds_config.image_base_path}, remove_columns=remove_columns, ) + # Drain stateful batch accumulators once this dataset map iteration finishes. + mapper.finalize_processors() ds_processed_all.append(ds_processed) for ds_config, ds_processed in zip(datasets_config, ds_processed_all): From 77e39e406f02aa2fe7bbffee644d411ea4be714e Mon Sep 17 00:00:00 2001 From: legstar67 Date: Thu, 26 Mar 2026 02:12:50 +0100 Subject: [PATCH 12/45] unit tests + corrections --- src/mmirage/core/process/batch/adapter.py | 38 +++++ .../core/process/batch/openai_adapter.py | 134 ++++++++++++++- .../process/processors/llm/llm_processor.py | 13 +- tests/test_batch_adapter_contracts.py | 11 ++ tests/test_batch_chunking.py | 11 ++ tests/test_batch_orchestrator.py | 123 ++++++++++++++ tests/test_integration_batch_pipeline.py | 156 ++++++++++++++++++ tests/test_openai_batch_adapter.py | 131 +++++++++++++++ 8 files changed, 611 insertions(+), 6 deletions(-) create mode 100644 tests/test_batch_orchestrator.py create mode 100644 tests/test_integration_batch_pipeline.py diff --git a/src/mmirage/core/process/batch/adapter.py b/src/mmirage/core/process/batch/adapter.py index 29ff1e1..1a71805 100644 --- a/src/mmirage/core/process/batch/adapter.py +++ b/src/mmirage/core/process/batch/adapter.py @@ -131,3 +131,41 @@ def parse_submission_result( orchestration and metadata persistence. """ raise NotImplementedError() + + @abc.abstractmethod + def check_batch_status( + self, + provider_batch_id: str, + config: BatchProviderConfig, + ) -> BatchSubmissionResult: + """Retrieve and normalize the latest status for a provider batch job. + + Args: + provider_batch_id: Provider-side batch/job identifier to query. + config: Provider configuration containing credentials and endpoint + overrides. + + Returns: + A normalized ``BatchSubmissionResult`` where ``status`` reflects the + latest provider-reported lifecycle state for the batch. + """ + raise NotImplementedError() + + @abc.abstractmethod + def retrieve_results( + self, + provider_batch_id: str, + config: BatchProviderConfig, + ) -> Sequence[Dict[str, Any]]: + """Download and parse completed batch results from the provider. + + Args: + provider_batch_id: Provider-side batch/job identifier. + config: Provider configuration containing credentials and endpoint + overrides. + + Returns: + Sequence of parsed result rows (provider JSONL records normalized to + dictionaries) preserving provider output order. + """ + raise NotImplementedError() diff --git a/src/mmirage/core/process/batch/openai_adapter.py b/src/mmirage/core/process/batch/openai_adapter.py index 2046d06..cbc85b9 100644 --- a/src/mmirage/core/process/batch/openai_adapter.py +++ b/src/mmirage/core/process/batch/openai_adapter.py @@ -2,7 +2,7 @@ import io import json -from typing import Any, Mapping, Sequence +from typing import Any, Dict, List, Mapping, Sequence from openai import OpenAI @@ -53,10 +53,7 @@ def submit_chunk( ) -> Mapping[str, Any]: openai_config = self._as_openai_config(config) - client_kwargs = {"api_key": openai_config.credentials.get("api_key", "")} - if openai_config.base_url: - client_kwargs["base_url"] = openai_config.base_url - client = OpenAI(**client_kwargs) + client = self._create_client(openai_config) jsonl_lines = [ json.dumps(req, ensure_ascii=False, separators=(",", ":")) for req in requests @@ -86,6 +83,61 @@ def submit_chunk( "chunk_id": chunk_id, } + def check_batch_status( + self, + provider_batch_id: str, + config: BatchProviderConfig, + ) -> BatchSubmissionResult: + openai_config = self._as_openai_config(config) + client = self._create_client(openai_config) + + retrieved = client.batches.retrieve(provider_batch_id) + raw_result = { + "id": self._read_attr(retrieved, "id"), + "status": self._read_attr(retrieved, "status"), + } + return self.parse_submission_result(raw_result=raw_result, request_count=0) + + def retrieve_results( + self, + provider_batch_id: str, + config: BatchProviderConfig, + ) -> Sequence[Dict[str, Any]]: + openai_config = self._as_openai_config(config) + client = self._create_client(openai_config) + + retrieved = client.batches.retrieve(provider_batch_id) + status = str(self._read_attr(retrieved, "status") or "unknown") + output_file_id = self._read_attr(retrieved, "output_file_id") + + if status != "completed" or not output_file_id: + raise ValueError( + f"Batch '{provider_batch_id}' is not completed or has no output file (status={status})." + ) + + content_response = client.files.content(output_file_id) + jsonl_text = self._extract_content_text(content_response) + + rows: List[Dict[str, Any]] = [] + for line in jsonl_text.splitlines(): + raw = line.strip() + if not raw: + continue + + parsed = dict(json.loads(raw)) + custom_id = str(parsed.get("custom_id", "")).strip() + if not custom_id: + continue + + rows.append( + { + "custom_id": custom_id, + "generated_text": self._extract_generated_text(parsed), + } + ) + + return rows + def parse_submission_result( self, raw_result: Mapping[str, Any], @@ -107,6 +159,78 @@ def _as_openai_config(config: BatchProviderConfig) -> OpenAIBatchConfig: return config raise TypeError("OpenAIBatchAdapter requires OpenAIBatchConfig") + @staticmethod + def _create_client(config: OpenAIBatchConfig) -> OpenAI: + client_kwargs = {"api_key": config.credentials.get("api_key", "")} + if config.base_url: + client_kwargs["base_url"] = config.base_url + return OpenAI(**client_kwargs) + + @staticmethod + def _extract_content_text(content_response: Any) -> str: + text = getattr(content_response, "text", None) + if isinstance(text, str): + return text + + read = getattr(content_response, "read", None) + if callable(read): + data = read() + if isinstance(data, bytes): + return data.decode("utf-8") + if isinstance(data, str): + return data + + content = getattr(content_response, "content", None) + if isinstance(content, bytes): + return content.decode("utf-8") + if isinstance(content, str): + return content + + raise ValueError("Unable to parse OpenAI files.content response payload.") + + @staticmethod + def _extract_generated_text(row: Mapping[str, Any]) -> str: + response = row.get("response") + if not isinstance(response, Mapping): + return "" + + body = response.get("body") + if not isinstance(body, Mapping): + return "" + + # Backward-compatible fallback for simplified fixtures/providers. + body_text = body.get("text") + if isinstance(body_text, str): + return body_text + + choices = body.get("choices") + if not isinstance(choices, list) or not choices: + return "" + + first_choice = choices[0] + if not isinstance(first_choice, Mapping): + return "" + + message = first_choice.get("message") + if not isinstance(message, Mapping): + return "" + + content = message.get("content") + if isinstance(content, str): + return content + + # Some OpenAI-compatible responses return segmented content parts. + if isinstance(content, list): + parts: List[str] = [] + for item in content: + if not isinstance(item, Mapping): + continue + if item.get("type") == "text" and isinstance(item.get("text"), str): + parts.append(item["text"]) + return "".join(parts) + + return "" + @staticmethod def _read_attr(obj: Any, key: str) -> Any: if isinstance(obj, Mapping): diff --git a/src/mmirage/core/process/processors/llm/llm_processor.py b/src/mmirage/core/process/processors/llm/llm_processor.py index 08fc7ec..1108ca3 100644 --- a/src/mmirage/core/process/processors/llm/llm_processor.py +++ b/src/mmirage/core/process/processors/llm/llm_processor.py @@ -61,7 +61,15 @@ def __init__(self, engine_args: SGLangLLMConfig, **kwargs) -> None: **kwargs: Additional arguments passed to base class. """ super().__init__(engine_args, **kwargs) - self.llm = sgl.Engine(**asdict(engine_args.server_args)) + provider_cfg_raw = dict(getattr(engine_args, "batch_provider", {}) or {}) + batch_mode_requested = bool(provider_cfg_raw.get("enabled", False)) + + # In provider-batch mode we only build payloads/metadata and should not + # initialize GPU-backed SGLang runtime. + self.llm = None + if not batch_mode_requested: + self.llm = sgl.Engine(**asdict(engine_args.server_args)) + self.tokenizer = AutoTokenizer.from_pretrained( engine_args.server_args.model_path, trust_remote_code=getattr(engine_args.server_args, "trust_remote_code", False), @@ -464,6 +472,9 @@ def finalize(self) -> None: def shutdown(self) -> None: """Shutdown the LLM engine.""" + if self.llm is None: + return + try: self.llm.shutdown() except Exception as e: diff --git a/tests/test_batch_adapter_contracts.py b/tests/test_batch_adapter_contracts.py index ce8a520..3072344 100644 --- a/tests/test_batch_adapter_contracts.py +++ b/tests/test_batch_adapter_contracts.py @@ -38,6 +38,17 @@ def parse_submission_result(self, raw_result, request_count): raw_response=raw_result, ) + def check_batch_status(self, provider_batch_id, config): + return BatchSubmissionResult( + provider_batch_id=provider_batch_id, + status="submitted", + submitted_request_count=0, + raw_response={"id": provider_batch_id, "status": "submitted"}, + ) + + def retrieve_results(self, provider_batch_id, config): + return [] + class CredentialedTestAdapter(CompleteTestAdapter): required_credentials = ("api_key",) diff --git a/tests/test_batch_chunking.py b/tests/test_batch_chunking.py index 075b97f..753cd9a 100644 --- a/tests/test_batch_chunking.py +++ b/tests/test_batch_chunking.py @@ -35,6 +35,17 @@ def parse_submission_result(self, raw_result, request_count): raw_response=raw_result, ) + def check_batch_status(self, provider_batch_id, config): + return BatchSubmissionResult( + provider_batch_id=provider_batch_id, + status="submitted", + submitted_request_count=0, + raw_response={"id": provider_batch_id, "status": "submitted"}, + ) + + def retrieve_results(self, provider_batch_id, config): + return [] + def _sizes_from_chunks(chunks): return [[request["size_bytes"] for request in chunk.requests] for chunk in chunks] diff --git a/tests/test_batch_orchestrator.py b/tests/test_batch_orchestrator.py new file mode 100644 index 0000000..f52afdf --- /dev/null +++ b/tests/test_batch_orchestrator.py @@ -0,0 +1,123 @@ +import json + +from mmirage.config.batch_provider import BatchProviderConfig +from mmirage.core.process.batch.adapter import BatchSubmissionAdapter, BatchSubmissionResult + + +class RecordingAdapter(BatchSubmissionAdapter): + def __init__(self) -> None: + self.submissions = [] + + @property + def adapter_name(self) -> str: + return "recording-adapter" + + @property + def adapter_version(self) -> str: + return "1.2.3" + + def build_request(self, custom_id, payload, config): + return {"custom_id": custom_id, **dict(payload)} + + def estimate_request_bytes(self, request): + return int(request["size_bytes"]) + + def submit_chunk(self, chunk_id, requests, config): + self.submissions.append( + { + "chunk_id": chunk_id, + "requests": list(requests), + } + ) + return {"id": f"batch-{chunk_id}", "status": "submitted"} + + def parse_submission_result(self, raw_result, request_count): + return BatchSubmissionResult( + provider_batch_id=str(raw_result["id"]), + status=str(raw_result["status"]), + submitted_request_count=request_count, + raw_response=raw_result, + ) + + def check_batch_status(self, provider_batch_id, config): + return BatchSubmissionResult( + provider_batch_id=provider_batch_id, + status="submitted", + submitted_request_count=0, + raw_response={"id": provider_batch_id, "status": "submitted"}, + ) + + def retrieve_results(self, provider_batch_id, config): + return [] + + +def test_orchestrator_buffers_across_iterations_and_avoids_tiny_midstream_flush(tmp_path): + from mmirage.core.process.batch.orchestrator import BatchSubmissionOrchestrator + + adapter = RecordingAdapter() + config = BatchProviderConfig( + provider="unit", + max_chunk_bytes=10, + metadata_output_path=str(tmp_path / "metadata.jsonl"), + ) + orchestrator = BatchSubmissionOrchestrator(adapter=adapter, config=config) + + # Iteration 1: only 9 bytes total, should remain buffered and submit nothing. + r1 = [{"custom_id": "a", "size_bytes": 6}, {"custom_id": "b", "size_bytes": 3}] + out1 = orchestrator.add_requests(r1, [10, 11], {"phase": "iter1"}) + assert out1 == [] + assert len(adapter.submissions) == 0 + + # Iteration 2: appending 2 bytes should emit one full chunk [6,3] and keep [2]. + r2 = [{"custom_id": "c", "size_bytes": 2}] + out2 = orchestrator.add_requests(r2, [12], {"phase": "iter2"}) + assert len(out2) == 1 + assert len(adapter.submissions) == 1 + assert [x["size_bytes"] for x in adapter.submissions[0]["requests"]] == [6, 3] + assert orchestrator.pending_count == 1 + + # Finalize: emits the remaining tiny tail exactly once. + out3 = orchestrator.finalize({"phase": "finalize"}) + assert len(out3) == 1 + assert len(adapter.submissions) == 2 + assert [x["size_bytes"] for x in adapter.submissions[1]["requests"]] == [2] + assert orchestrator.pending_count == 0 + + +def test_orchestrator_writes_provider_neutral_metadata_with_flush_reason(tmp_path): + from mmirage.core.process.batch.orchestrator import BatchSubmissionOrchestrator + + metadata_path = tmp_path / "batch_metadata.jsonl" + adapter = RecordingAdapter() + config = BatchProviderConfig( + provider="unit", + max_chunk_bytes=10, + metadata_output_path=str(metadata_path), + ) + orchestrator = BatchSubmissionOrchestrator(adapter=adapter, config=config) + + orchestrator.add_requests( + requests=[ + {"custom_id": "x1", "size_bytes": 8}, + {"custom_id": "x2", "size_bytes": 8}, + ], + source_indices=[0, 1], + model_params_snapshot={"model": "unit-model"}, + ) + orchestrator.finalize({"model": "unit-model"}) + + lines = metadata_path.read_text(encoding="utf-8").strip().splitlines() + assert len(lines) == 2 + + first = json.loads(lines[0]) + second = json.loads(lines[1]) + + assert first["provider"] == "unit" + assert first["adapter_version"] == "1.2.3" + assert first["flush_reason"] == "full_chunk" + assert first["custom_id_to_source_index"] == {"x1": 0} + assert isinstance(first["request_hash"], str) and len(first["request_hash"]) == 64 + + assert second["flush_reason"] == "finalize" + assert second["custom_id_to_source_index"] == {"x2": 1} + assert second["provider_batch_id"].startswith("batch-chunk-") diff --git a/tests/test_integration_batch_pipeline.py b/tests/test_integration_batch_pipeline.py new file mode 100644 index 0000000..86351d1 --- /dev/null +++ b/tests/test_integration_batch_pipeline.py @@ -0,0 +1,156 @@ +import json +from pathlib import Path + +from datasets import load_dataset + +from mmirage.core.process import LLMProcessor # Ensures processor registration. +from mmirage.core.process.mapper import MMIRAGEMapper +from mmirage.core.process.processors.llm.config import LLMOutputVar, SGLangLLMConfig, SGLangServerArgs +from mmirage.core.process.variables import InputVar +from mmirage.core.writer.renderer import TemplateRenderer + + +def test_integration_batch_pipeline_with_stateful_accumulator(monkeypatch, tmp_path): + captured = { + "file_uploads": [], + "batch_creates": [], + "engine_init_calls": 0, + } + + class FakeFiles: + def create(self, *, file, purpose): + file_name, file_obj = file + payload = file_obj.read().decode("utf-8") + captured["file_uploads"].append( + { + "file_name": file_name, + "purpose": purpose, + "payload": payload, + } + ) + + class _FileResp: + id = f"file_{len(captured['file_uploads'])}" + + return _FileResp() + + class FakeBatches: + def create(self, **kwargs): + captured["batch_creates"].append(kwargs) + + class _BatchResp: + id = f"batch_{len(captured['batch_creates'])}" + status = "validating" + endpoint = kwargs["endpoint"] + + return _BatchResp() + + class FakeOpenAIClient: + def __init__(self, **_kwargs): + self.files = FakeFiles() + self.batches = FakeBatches() + + class FakeEngine: + def __init__(self, **_kwargs): + captured["engine_init_calls"] += 1 + + def generate(self, **_kwargs): + raise AssertionError("Synchronous generation path should not run in batch mode") + + def shutdown(self): + return None + + class FakeTokenizer: + def apply_chat_template(self, user_prompt, tokenize=False, add_generation_prompt=True): + assert tokenize is False + assert add_generation_prompt is True + return user_prompt[0]["content"] + + monkeypatch.setattr( + "mmirage.core.process.batch.openai_adapter.OpenAI", + FakeOpenAIClient, + ) + monkeypatch.setattr( + "mmirage.core.process.processors.llm.llm_processor.sgl.Engine", + FakeEngine, + ) + monkeypatch.setattr( + "mmirage.core.process.processors.llm.llm_processor.AutoTokenizer.from_pretrained", + lambda *args, **kwargs: FakeTokenizer(), + ) + + metadata_base = tmp_path / "batch_receipts.jsonl" + llm_cfg = SGLangLLMConfig( + type="llm", + server_args=SGLangServerArgs(model_path="dummy-model"), + batch_provider={ + "enabled": True, + "provider": "openai", + "model": "gpt-4.1-mini", + "max_chunk_bytes": 500, + "max_requests_per_chunk": None, + "metadata_output_path": str(metadata_base), + "credentials": {"api_key": "test-key"}, + "metadata": {"pipeline": "integration-test"}, + }, + ) + + mapper = MMIRAGEMapper( + processor_configs=[llm_cfg], + input_vars=[InputVar(name="text", key="text")], + output_vars=[ + LLMOutputVar( + name="answer", + type="llm", + prompt="{{ text }}", + output_type="plain", + ) + ], + ) + renderer = TemplateRenderer(output_schema={"answer": "{{ answer }}"}) + + data_path = Path(__file__).parent / "mock_data" / "data.jsonl" + dataset = load_dataset("json", data_files=str(data_path), split="train") + + def rewrite_batch(batch, mapper, renderer): + envs = mapper.rewrite_batch(batch) + return renderer.batch_render(envs) + + ds_out = dataset.map( + rewrite_batch, + batched=True, + batch_size=7, + fn_kwargs={"mapper": mapper, "renderer": renderer}, + load_from_cache_file=False, + ) + + # Explicit lifecycle flush required by the architecture. + mapper.finalize_processors() + + # 1) Multiple provider submissions prove byte-based chunking with carry-over. + assert captured["engine_init_calls"] == 0 + assert len(captured["file_uploads"]) > 1 + assert len(captured["batch_creates"]) > 1 + + # 2) Map output is placeholder-based and does not wait for completion. + answers = ds_out["answer"] + assert len(answers) == len(dataset) + assert all(isinstance(v, str) and v.startswith("__BATCH_SUBMITTED__:answer:") for v in answers) + + # 3) Metadata receipts are written and include both full_chunk and finalize flush reasons. + metadata_text_path = tmp_path / "batch_receipts.text.jsonl" + assert metadata_text_path.exists() + + records = [ + json.loads(line) + for line in metadata_text_path.read_text(encoding="utf-8").splitlines() + if line.strip() + ] + assert len(records) > 1 + + flush_reasons = {record["flush_reason"] for record in records} + assert "full_chunk" in flush_reasons + assert "finalize" in flush_reasons + + assert all(record["provider"] == "openai" for record in records) + assert all("custom_id_to_source_index" in record for record in records) diff --git a/tests/test_openai_batch_adapter.py b/tests/test_openai_batch_adapter.py index 59caeeb..e436471 100644 --- a/tests/test_openai_batch_adapter.py +++ b/tests/test_openai_batch_adapter.py @@ -151,3 +151,134 @@ def test_factory_resolves_openai_adapter_from_registry(): adapter = BatchAdapterFactory.from_config(config) assert isinstance(adapter, OpenAIBatchAdapter) + + +def test_openai_check_batch_status_uses_mocked_openai_client(monkeypatch): + from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter + + captured = {} + + class FakeBatches: + def retrieve(self, provider_batch_id): + captured["retrieved_id"] = provider_batch_id + + class _RetrieveResp: + id = provider_batch_id + status = "completed" + + return _RetrieveResp() + + class FakeClient: + def __init__(self, **kwargs): + captured["client_kwargs"] = kwargs + self.batches = FakeBatches() + + monkeypatch.setattr( + "mmirage.core.process.batch.openai_adapter.OpenAI", + FakeClient, + ) + + config = OpenAIBatchConfig( + credentials={"api_key": "test-key"}, + base_url="https://example.test/v1", + ) + adapter = OpenAIBatchAdapter() + + result = adapter.check_batch_status(provider_batch_id="batch_456", config=config) + + assert captured["client_kwargs"] == { + "api_key": "test-key", + "base_url": "https://example.test/v1", + } + assert captured["retrieved_id"] == "batch_456" + assert isinstance(result, BatchSubmissionResult) + assert result.provider_batch_id == "batch_456" + assert result.status == "completed" + assert result.submitted_request_count == 0 + + +def test_openai_retrieve_results_downloads_and_parses_jsonl(monkeypatch): + from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter + + captured = {} + + class FakeBatches: + def retrieve(self, provider_batch_id): + captured["retrieved_id"] = provider_batch_id + + class _RetrieveResp: + id = provider_batch_id + status = "completed" + output_file_id = "file_output_1" + + return _RetrieveResp() + + class FakeFiles: + def content(self, output_file_id): + captured["output_file_id"] = output_file_id + + class _ContentResp: + text = ( + '{"custom_id":"c1","response":{"body":{"text":"A"}}}\n' + '{"custom_id":"c2","response":{"body":{"text":"B"}}}\n' + ) + + return _ContentResp() + + class FakeClient: + def __init__(self, **kwargs): + captured["client_kwargs"] = kwargs + self.batches = FakeBatches() + self.files = FakeFiles() + + monkeypatch.setattr( + "mmirage.core.process.batch.openai_adapter.OpenAI", + FakeClient, + ) + + config = OpenAIBatchConfig(credentials={"api_key": "test-key"}) + adapter = OpenAIBatchAdapter() + + rows = adapter.retrieve_results(provider_batch_id="batch_abc", config=config) + + assert captured["retrieved_id"] == "batch_abc" + assert captured["output_file_id"] == "file_output_1" + assert len(rows) == 2 + assert rows[0]["custom_id"] == "c1" + assert rows[1]["custom_id"] == "c2" + assert rows[0]["generated_text"] == "A" + assert rows[1]["generated_text"] == "B" + + +def test_openai_retrieve_results_raises_if_batch_not_completed(monkeypatch): + from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter + + class FakeBatches: + def retrieve(self, provider_batch_id): + class _RetrieveResp: + id = provider_batch_id + status = "in_progress" + output_file_id = None + + return _RetrieveResp() + + class FakeClient: + def __init__(self, **kwargs): + self.batches = FakeBatches() + + class _Files: + def content(self, output_file_id): + raise AssertionError("content() should not be called when batch is not completed") + + self.files = _Files() + + monkeypatch.setattr( + "mmirage.core.process.batch.openai_adapter.OpenAI", + FakeClient, + ) + + config = OpenAIBatchConfig(credentials={"api_key": "test-key"}) + adapter = OpenAIBatchAdapter() + + with pytest.raises(ValueError, match="not completed"): + adapter.retrieve_results(provider_batch_id="batch_abc", config=config) From a5f9b623c9d0fde1e9c5bb0287c2ae181e8bd27d Mon Sep 17 00:00:00 2001 From: legstar67 Date: Thu, 26 Mar 2026 05:06:17 +0100 Subject: [PATCH 13/45] implementation of the status checker , and the collector + unit tests + config mock for text and vision + some corrections --- configs/config_mock_openai_batch.yaml | 57 ++++++ configs/config_mock_openai_batch_vision.yaml | 47 +++++ src/mmirage/core/process/batch/__init__.py | 8 + src/mmirage/core/process/batch/collector.py | 189 ++++++++++++++++++ .../core/process/batch/openai_adapter.py | 85 +++----- src/mmirage/core/process/batch/registry.py | 19 +- .../core/process/batch/status_checker.py | 137 +++++++++++++ .../process/processors/llm/llm_processor.py | 4 + tests/test_batch_adapter_contracts.py | 11 + tests/test_batch_collector.py | 144 +++++++++++++ tests/test_batch_status_checker.py | 88 ++++++++ tests/test_integration_receiver.py | 101 ++++++++++ tests/test_openai_batch_adapter.py | 79 +++++++- 13 files changed, 908 insertions(+), 61 deletions(-) create mode 100644 configs/config_mock_openai_batch.yaml create mode 100644 configs/config_mock_openai_batch_vision.yaml create mode 100644 src/mmirage/core/process/batch/collector.py create mode 100644 src/mmirage/core/process/batch/status_checker.py create mode 100644 tests/test_batch_collector.py create mode 100644 tests/test_batch_status_checker.py create mode 100644 tests/test_integration_receiver.py diff --git a/configs/config_mock_openai_batch.yaml b/configs/config_mock_openai_batch.yaml new file mode 100644 index 0000000..80f5bd8 --- /dev/null +++ b/configs/config_mock_openai_batch.yaml @@ -0,0 +1,57 @@ +processors: + - type: llm + server_args: + model_path: Qwen/Qwen3-4B-Instruct-2507 + tp_size: 1 + disable_custom_all_reduce: true + default_sampling_params: + temperature: 0.1 + top_p: 0.9 + max_new_tokens: 1024 + custom_params: + chat_template_kwargs: + enable_thinking: false + batch_provider: + enabled: true + provider: openai + model: gpt-4o-mini + max_chunk_bytes: 52428800 + metadata_output_path: tests/output/batch_metadata.jsonl + credentials: + api_key: + +loading_params: + datasets: + - path: tests/mock_data/data.jsonl + type: JSONL + output_dir: tests/output/data_openai_batch + + num_shards: 1 + shard_id: 0 + batch_size: 64 + +processing_params: + inputs: + - name: text + key: text + + outputs: + - name: formatted_answer + type: llm + output_type: JSON + output_schema: + - question + - answer + prompt: | + Generate one question and its corresponding answer using the following text: + ``` + {{ text }} + ``` + + remove_columns: true + output_schema: + conversations: + - role: "user" + content: "{{ formatted_answer.question }}" + - role: "assistant" + content: "{{ formatted_answer.answer }}" \ No newline at end of file diff --git a/configs/config_mock_openai_batch_vision.yaml b/configs/config_mock_openai_batch_vision.yaml new file mode 100644 index 0000000..e89b571 --- /dev/null +++ b/configs/config_mock_openai_batch_vision.yaml @@ -0,0 +1,47 @@ +processors: + - type: llm + server_args: + model_path: Qwen/Qwen3-VL-8B-Instruct + tp_size: 1 + trust_remote_code: true + chat_template: qwen2-vl + default_sampling_params: + temperature: 0.1 + top_p: 0.9 + max_new_tokens: 512 + batch_provider: + enabled: true + provider: openai + model: gpt-4o-mini + metadata_output_path: tests/output/batch_metadata_vision.jsonl + credentials: + api_key: "" + +loading_params: + datasets: + - path: tests/mock_data_vision/data.jsonl + type: JSONL + output_dir: tests/output/data_openai_batch_vision + image_base_path: tests/mock_data_vision + + num_shards: 1 + shard_id: 0 + batch_size: 1 + +processing_params: + inputs: + - name: image_input + key: image + type: image + + outputs: + - name: caption + type: llm + output_type: plain + prompt: | + Describe what you see in this image in one concise sentence. + + remove_columns: false + output_schema: + image: "{{ image_input }}" + caption: "{{ caption }}" \ No newline at end of file diff --git a/src/mmirage/core/process/batch/__init__.py b/src/mmirage/core/process/batch/__init__.py index 15fe576..31334ba 100644 --- a/src/mmirage/core/process/batch/__init__.py +++ b/src/mmirage/core/process/batch/__init__.py @@ -1,15 +1,21 @@ """Provider-agnostic batch processing contracts and registry.""" from mmirage.core.process.batch.adapter import BatchSubmissionAdapter, BatchSubmissionResult +from mmirage.core.process.batch.collector import collect_and_merge from mmirage.core.process.batch.chunking import BatchRequestChunker, RequestChunk from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter from mmirage.core.process.batch.orchestrator import BatchSubmissionOrchestrator from mmirage.core.process.batch.registry import BatchAdapterFactory, BatchAdapterRegistry +from mmirage.core.process.batch.status_checker import ( + extract_unique_provider_batches, + run_status_checker, +) from mmirage.config.openai_batch import OpenAIBatchConfig __all__ = [ "BatchSubmissionAdapter", "BatchSubmissionResult", + "collect_and_merge", "BatchRequestChunker", "RequestChunk", "BatchSubmissionOrchestrator", @@ -17,4 +23,6 @@ "OpenAIBatchConfig", "BatchAdapterFactory", "BatchAdapterRegistry", + "extract_unique_provider_batches", + "run_status_checker", ] diff --git a/src/mmirage/core/process/batch/collector.py b/src/mmirage/core/process/batch/collector.py new file mode 100644 index 0000000..6e4df8f --- /dev/null +++ b/src/mmirage/core/process/batch/collector.py @@ -0,0 +1,189 @@ +"""Receiver-side utility for collecting provider results and merging by source row index.""" + +from __future__ import annotations + +import argparse +import json +import os +import sys +from typing import Any, Dict, List, Mapping, MutableMapping, Sequence, Tuple + +from mmirage.config.batch_provider import BatchProviderConfig +from mmirage.config.openai_batch import OpenAIBatchConfig +from mmirage.core.process.batch.registry import BatchAdapterFactory + + +def _read_metadata_records(metadata_output_path: str) -> List[Dict[str, Any]]: + records: List[Dict[str, Any]] = [] + with open(metadata_output_path, "r", encoding="utf-8") as f: + for line in f: + raw = line.strip() + if not raw: + continue + try: + parsed = json.loads(raw) + except json.JSONDecodeError: + continue + if isinstance(parsed, dict): + records.append(parsed) + return records + + +def _aggregate_batch_mappings( + records: Sequence[Mapping[str, Any]], +) -> Dict[Tuple[str, str], Dict[str, int]]: + aggregated: Dict[Tuple[str, str], Dict[str, int]] = {} + + for record in records: + provider = str(record.get("provider", "")).strip().lower() + provider_batch_id = str(record.get("provider_batch_id", "")).strip() + mapping = record.get("custom_id_to_source_index", {}) + + if not provider or not provider_batch_id or not isinstance(mapping, dict): + continue + + key = (provider, provider_batch_id) + if key not in aggregated: + aggregated[key] = {} + + for custom_id, source_index in mapping.items(): + try: + aggregated[key][str(custom_id)] = int(source_index) + except (TypeError, ValueError): + continue + + return aggregated + + +def collect_and_merge( + metadata_output_path: str, + provider_configs: Mapping[str, BatchProviderConfig], + output_path: str, +) -> List[Dict[str, Any]]: + """Collect completed results and reconstruct rows in source index order.""" + records = _read_metadata_records(metadata_output_path) + pair_to_mapping = _aggregate_batch_mappings(records) + + adapters: Dict[str, Any] = {} + pair_to_results: Dict[Tuple[str, str], Sequence[Dict[str, Any]]] = {} + + for provider, provider_batch_id in pair_to_mapping.keys(): + if provider not in provider_configs: + raise ValueError(f"No provider config found for '{provider}'.") + + if provider not in adapters: + adapters[provider] = BatchAdapterFactory.from_config(provider_configs[provider]) + + pair = (provider, provider_batch_id) + pair_to_results[pair] = adapters[provider].retrieve_results( + provider_batch_id=provider_batch_id, + config=provider_configs[provider], + ) + + indexed_rows: MutableMapping[int, Dict[str, Any]] = {} + for pair, mapping in pair_to_mapping.items(): + results = pair_to_results.get(pair, []) + for result_row in results: + custom_id = str(result_row.get("custom_id", "")).strip() + if not custom_id or custom_id not in mapping: + continue + source_index = mapping[custom_id] + parsed = _parse_structured_content(result_row) + indexed_rows[source_index] = { + "source_index": source_index, + "custom_id": custom_id, + "conversations": [ + { + "role": "user", + "content": str(parsed.get("question", "")), + }, + { + "role": "assistant", + "content": str(parsed.get("answer", "")), + }, + ], + } + + ordered_rows = [indexed_rows[idx] for idx in sorted(indexed_rows.keys())] + + os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) + with open(output_path, "w", encoding="utf-8") as f: + for row in ordered_rows: + f.write(json.dumps(row, ensure_ascii=False) + "\n") + + return ordered_rows + + +def _parse_structured_content(result_row: Mapping[str, Any]) -> Dict[str, Any]: + # Preferred OpenAI envelope path for Structured Outputs. + response = result_row.get("response") + if isinstance(response, Mapping): + body = response.get("body") + if isinstance(body, Mapping): + choices = body.get("choices") + if isinstance(choices, list) and choices: + first_choice = choices[0] + if isinstance(first_choice, Mapping): + message = first_choice.get("message") + if isinstance(message, Mapping): + content = message.get("content") + if isinstance(content, str): + try: + parsed = json.loads(content) + if isinstance(parsed, dict): + return parsed + except json.JSONDecodeError: + return {} + + # Fallback for normalized adapter payloads carrying generated_text directly. + generated_text = result_row.get("generated_text") + if isinstance(generated_text, str): + try: + parsed = json.loads(generated_text) + if isinstance(parsed, dict): + return parsed + except json.JSONDecodeError: + return {} + + return {} + + +def _build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Collect provider batch outputs and merge rows by source index." + ) + parser.add_argument( + "--metadata-path", + required=True, + help="Path to metadata JSONL receipt file.", + ) + parser.add_argument( + "--output-path", + required=True, + help="Path to write merged JSONL output.", + ) + return parser + + +def main(argv: Sequence[str] | None = None) -> int: + args = _build_arg_parser().parse_args(argv) + + api_key = os.environ.get("OPENAI_API_KEY", "").strip() + if not api_key: + raise ValueError("OPENAI_API_KEY is required for collector execution.") + + provider_configs: Dict[str, BatchProviderConfig] = { + "openai": OpenAIBatchConfig(credentials={"api_key": api_key}) + } + + rows = collect_and_merge(args.metadata_path, provider_configs, args.output_path) + print(f"Merged {len(rows)} rows and saved to {args.output_path}") + return 0 + + +if __name__ == "__main__": + try: + raise SystemExit(main()) + except Exception as exc: + print(f"Collector failed: {exc}", file=sys.stderr) + raise SystemExit(1) diff --git a/src/mmirage/core/process/batch/openai_adapter.py b/src/mmirage/core/process/batch/openai_adapter.py index cbc85b9..5430622 100644 --- a/src/mmirage/core/process/batch/openai_adapter.py +++ b/src/mmirage/core/process/batch/openai_adapter.py @@ -2,6 +2,7 @@ import io import json +import os from typing import Any, Dict, List, Mapping, Sequence from openai import OpenAI @@ -32,8 +33,25 @@ def build_request( ) -> Mapping[str, Any]: openai_config = self._as_openai_config(config) body = dict(payload) + expected_schema = body.pop("expected_schema", None) body.setdefault("model", openai_config.model) + if isinstance(expected_schema, list) and all(isinstance(k, str) for k in expected_schema): + properties = {key: {"type": "string"} for key in expected_schema} + body["response_format"] = { + "type": "json_schema", + "json_schema": { + "name": "structured_output", + "strict": True, + "schema": { + "type": "object", + "properties": properties, + "required": expected_schema, + "additionalProperties": False, + }, + }, + } + return { "custom_id": custom_id, "method": "POST", @@ -123,18 +141,7 @@ def retrieve_results( raw = line.strip() if not raw: continue - - parsed = dict(json.loads(raw)) - custom_id = str(parsed.get("custom_id", "")).strip() - if not custom_id: - continue - - rows.append( - { - "custom_id": custom_id, - "generated_text": self._extract_generated_text(parsed), - } - ) + rows.append(dict(json.loads(raw))) return rows @@ -161,7 +168,16 @@ def _as_openai_config(config: BatchProviderConfig) -> OpenAIBatchConfig: @staticmethod def _create_client(config: OpenAIBatchConfig) -> OpenAI: - client_kwargs = {"api_key": config.credentials.get("api_key", "")} + api_key = (config.credentials.get("api_key", "") or "").strip() + if not api_key: + api_key = os.environ.get("OPENAI_API_KEY", "").strip() + + if not api_key: + raise ValueError( + "OpenAI API key is missing. Provide credentials.api_key or set OPENAI_API_KEY." + ) + + client_kwargs = {"api_key": api_key} if config.base_url: client_kwargs["base_url"] = config.base_url return OpenAI(**client_kwargs) @@ -188,49 +204,6 @@ def _extract_content_text(content_response: Any) -> str: raise ValueError("Unable to parse OpenAI files.content response payload.") - @staticmethod - def _extract_generated_text(row: Mapping[str, Any]) -> str: - response = row.get("response") - if not isinstance(response, Mapping): - return "" - - body = response.get("body") - if not isinstance(body, Mapping): - return "" - - # Backward-compatible fallback for simplified fixtures/providers. - body_text = body.get("text") - if isinstance(body_text, str): - return body_text - - choices = body.get("choices") - if not isinstance(choices, list) or not choices: - return "" - - first_choice = choices[0] - if not isinstance(first_choice, Mapping): - return "" - - message = first_choice.get("message") - if not isinstance(message, Mapping): - return "" - - content = message.get("content") - if isinstance(content, str): - return content - - # Some OpenAI-compatible responses return segmented content parts. - if isinstance(content, list): - parts: List[str] = [] - for item in content: - if not isinstance(item, Mapping): - continue - if item.get("type") == "text" and isinstance(item.get("text"), str): - parts.append(item["text"]) - return "".join(parts) - - return "" - @staticmethod def _read_attr(obj: Any, key: str) -> Any: if isinstance(obj, Mapping): diff --git a/src/mmirage/core/process/batch/registry.py b/src/mmirage/core/process/batch/registry.py index 4267375..862585f 100644 --- a/src/mmirage/core/process/batch/registry.py +++ b/src/mmirage/core/process/batch/registry.py @@ -1,5 +1,6 @@ """Registry and factory for provider batch adapters.""" +import os from typing import Dict, Type from mmirage.config.batch_provider import BatchProviderConfig @@ -60,9 +61,21 @@ def resolve(cls, provider: str) -> Type[BatchSubmissionAdapter]: def create(cls, config: BatchProviderConfig) -> BatchSubmissionAdapter: """Instantiate an adapter for a provider config with credential checks.""" adapter_cls = cls.resolve(config.provider) - missing_credentials = [ - name for name in adapter_cls.required_credentials if not config.credentials.get(name) - ] + + missing_credentials = [] + for req_key in adapter_cls.required_credentials: + credential_value = (config.credentials.get(req_key, "") or "").strip() + if credential_value: + continue + + env_var = f"{config.provider.upper()}_{req_key.upper()}" + env_value = (os.environ.get(env_var, "") or "").strip() + if env_value: + config.credentials[req_key] = env_value + continue + + missing_credentials.append(req_key) + if missing_credentials: raise ValueError( f"Missing credentials for provider '{config.provider}': {missing_credentials}" diff --git a/src/mmirage/core/process/batch/status_checker.py b/src/mmirage/core/process/batch/status_checker.py new file mode 100644 index 0000000..ff08e07 --- /dev/null +++ b/src/mmirage/core/process/batch/status_checker.py @@ -0,0 +1,137 @@ +"""Receiver-side utility for polling provider batch statuses from metadata receipts.""" + +from __future__ import annotations + +import argparse +import json +import os +import sys +from typing import Dict, List, Mapping, Sequence, TextIO, Tuple + +from mmirage.config.batch_provider import BatchProviderConfig +from mmirage.config.openai_batch import OpenAIBatchConfig +from mmirage.core.process.batch.adapter import BatchSubmissionResult +from mmirage.core.process.batch.registry import BatchAdapterFactory + + +def extract_unique_provider_batches(metadata_output_path: str) -> List[Tuple[str, str]]: + """Parse metadata JSONL and return unique ``(provider, provider_batch_id)`` pairs. + + Malformed lines and records missing required keys are skipped safely. + """ + unique_pairs: List[Tuple[str, str]] = [] + seen = set() + + with open(metadata_output_path, "r", encoding="utf-8") as f: + for line in f: + raw = line.strip() + if not raw: + continue + + try: + record = json.loads(raw) + except json.JSONDecodeError: + continue + + provider = str(record.get("provider", "")).strip().lower() + provider_batch_id = str(record.get("provider_batch_id", "")).strip() + + if not provider or not provider_batch_id: + continue + + pair = (provider, provider_batch_id) + if pair in seen: + continue + seen.add(pair) + unique_pairs.append(pair) + + return unique_pairs + + +def run_status_checker( + metadata_output_path: str, + provider_configs: Mapping[str, BatchProviderConfig], + output: TextIO = sys.stdout, +) -> List[BatchSubmissionResult]: + """Check and print statuses for batches referenced in a metadata receipt file.""" + results: List[BatchSubmissionResult] = [] + counter: Dict[str, Dict[str, int]] = {} + + for provider, provider_batch_id in extract_unique_provider_batches(metadata_output_path): + if provider not in provider_configs: + print(f"Skipping batch {provider_batch_id}: no config for provider '{provider}'.", file=output) + provider_counts = counter.setdefault(provider, {}) + provider_counts["skipped"] = provider_counts.get("skipped", 0) + 1 + continue + + config = provider_configs[provider] + adapter = BatchAdapterFactory.from_config(config) + result = adapter.check_batch_status(provider_batch_id=provider_batch_id, config=config) + results.append(result) + + print(f"Batch {provider_batch_id} ({provider}): {result.status}", file=output) + provider_counts = counter.setdefault(provider, {}) + provider_counts[result.status] = provider_counts.get(result.status, 0) + 1 + + print("\n------------ Batch status summary ------------", file=output) + for provider, status_counts in counter.items(): + print(f"Total batches for provider '{provider}':", file=output) + total = sum(status_counts.values()) + for status, count in status_counts.items(): + print(f" {status}: {count}/{total}", file=output) + + return results + + +def _build_provider_configs_from_metadata( + metadata_output_path: str, +) -> Dict[str, BatchProviderConfig]: + provider_names = {provider for provider, _ in extract_unique_provider_batches(metadata_output_path)} + configs: Dict[str, BatchProviderConfig] = {} + + if "openai" in provider_names: + api_key = os.environ.get("OPENAI_API_KEY", "").strip() + if not api_key: + raise ValueError( + "OPENAI_API_KEY is required to check statuses for provider 'openai'." + ) + configs["openai"] = OpenAIBatchConfig(credentials={"api_key": api_key}) + + return configs + + +def _build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Check provider batch statuses from metadata receipts.") + parser.add_argument( + "--metadata-path", + required=True, + help="Path to metadata JSONL receipt file.", + ) + return parser + + +def main(argv: Sequence[str] | None = None) -> int: + args = _build_arg_parser().parse_args(argv) + pairs = extract_unique_provider_batches(args.metadata_path) + if not pairs: + print(f"No provider batch IDs found in metadata file: {args.metadata_path}") + return 0 + + try: + provider_configs = _build_provider_configs_from_metadata(args.metadata_path) + if not provider_configs: + print("No supported provider configurations could be built from metadata.") + return 1 + run_status_checker( + metadata_output_path=args.metadata_path, + provider_configs=provider_configs, + ) + except Exception as exc: + print(f"Status checker failed: {exc}", file=sys.stderr) + return 1 + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/mmirage/core/process/processors/llm/llm_processor.py b/src/mmirage/core/process/processors/llm/llm_processor.py index 1108ca3..59c2687 100644 --- a/src/mmirage/core/process/processors/llm/llm_processor.py +++ b/src/mmirage/core/process/processors/llm/llm_processor.py @@ -377,6 +377,8 @@ def _batch_process_sample( } ] } + if output_var.output_type == "JSON" and output_var.output_schema: + payload["expected_schema"] = list(output_var.output_schema) custom_id = self._next_custom_id(output_var.name, global_i, "text") request = self._batch_adapter.build_request( custom_id=custom_id, @@ -424,6 +426,8 @@ def _batch_process_sample( } ] } + if output_var.output_type == "JSON" and output_var.output_schema: + payload["expected_schema"] = list(output_var.output_schema) custom_id = self._next_custom_id(output_var.name, global_i, "multimodal") request = self._batch_adapter.build_request( custom_id=custom_id, diff --git a/tests/test_batch_adapter_contracts.py b/tests/test_batch_adapter_contracts.py index 3072344..7935df9 100644 --- a/tests/test_batch_adapter_contracts.py +++ b/tests/test_batch_adapter_contracts.py @@ -139,3 +139,14 @@ def test_factory_creates_adapter_when_credentials_are_present(): adapter = BatchAdapterFactory.from_config(config) assert isinstance(adapter, CredentialedTestAdapter) + + +def test_factory_resolves_missing_credentials_from_environment(monkeypatch): + BatchAdapterRegistry.register("unit", CredentialedTestAdapter) + monkeypatch.setenv("UNIT_API_KEY", "from-env") + config = BatchProviderConfig(provider="unit", credentials={}) + + adapter = BatchAdapterFactory.from_config(config) + + assert isinstance(adapter, CredentialedTestAdapter) + assert config.credentials["api_key"] == "from-env" diff --git a/tests/test_batch_collector.py b/tests/test_batch_collector.py new file mode 100644 index 0000000..8fe8f7a --- /dev/null +++ b/tests/test_batch_collector.py @@ -0,0 +1,144 @@ +import json + +from mmirage.config.openai_batch import OpenAIBatchConfig + + +def test_collect_and_merge_reconstructs_rows_deterministically(tmp_path, monkeypatch): + from mmirage.core.process.batch.collector import collect_and_merge + + metadata_path = tmp_path / "receipts.jsonl" + metadata_path.write_text( + "\n".join( + [ + json.dumps( + { + "provider": "openai", + "provider_batch_id": "batch_1", + "custom_id_to_source_index": {"c1": 2, "c2": 0}, + } + ), + json.dumps( + { + "provider": "openai", + "provider_batch_id": "batch_1", + "custom_id_to_source_index": {"c1": 2, "c2": 0}, + } + ), + "malformed-line", + json.dumps( + { + "provider": "openai", + "provider_batch_id": "batch_2", + "custom_id_to_source_index": {"c3": 1}, + } + ), + ] + ), + encoding="utf-8", + ) + + output_path = tmp_path / "merged.jsonl" + + class FakeAdapter: + def __init__(self): + self.calls = [] + + def retrieve_results(self, provider_batch_id, config): + self.calls.append((provider_batch_id, config.provider)) + if provider_batch_id == "batch_1": + return [ + { + "custom_id": "c1", + "response": { + "body": { + "choices": [ + { + "message": { + "content": '{"question":"q2","answer":"a2"}' + } + } + ] + } + }, + }, + { + "custom_id": "c2", + "response": { + "body": { + "choices": [ + { + "message": { + "content": '{"question":"q0","answer":"a0"}' + } + } + ] + } + }, + }, + ] + return [ + { + "custom_id": "c3", + "response": { + "body": { + "choices": [ + { + "message": { + "content": '{"question":"q1","answer":"a1"}' + } + } + ] + } + }, + } + ] + + fake_adapter = FakeAdapter() + monkeypatch.setattr( + "mmirage.core.process.batch.collector.BatchAdapterFactory.from_config", + lambda config: fake_adapter, + ) + + provider_configs = {"openai": OpenAIBatchConfig(credentials={"api_key": "k"})} + rows = collect_and_merge( + metadata_output_path=str(metadata_path), + provider_configs=provider_configs, + output_path=str(output_path), + ) + + assert [r["source_index"] for r in rows] == [0, 1, 2] + assert [r["custom_id"] for r in rows] == ["c2", "c3", "c1"] + assert [r["conversations"][0]["content"] for r in rows] == ["q0", "q1", "q2"] + assert [r["conversations"][1]["content"] for r in rows] == ["a0", "a1", "a2"] + assert fake_adapter.calls == [("batch_1", "openai"), ("batch_2", "openai")] + + written = [json.loads(line) for line in output_path.read_text(encoding="utf-8").splitlines()] + assert [r["source_index"] for r in written] == [0, 1, 2] + assert [r["conversations"][0]["content"] for r in written] == ["q0", "q1", "q2"] + + +def test_collect_and_merge_raises_for_missing_provider_config(tmp_path): + from mmirage.core.process.batch.collector import collect_and_merge + + metadata_path = tmp_path / "receipts.jsonl" + metadata_path.write_text( + json.dumps( + { + "provider": "openai", + "provider_batch_id": "batch_1", + "custom_id_to_source_index": {"c1": 0}, + } + ) + + "\n", + encoding="utf-8", + ) + + try: + collect_and_merge( + metadata_output_path=str(metadata_path), + provider_configs={}, + output_path=str(tmp_path / "out.jsonl"), + ) + assert False, "Expected ValueError" + except ValueError as e: + assert "No provider config" in str(e) diff --git a/tests/test_batch_status_checker.py b/tests/test_batch_status_checker.py new file mode 100644 index 0000000..80ffbbd --- /dev/null +++ b/tests/test_batch_status_checker.py @@ -0,0 +1,88 @@ +from io import StringIO + +from mmirage.config.openai_batch import OpenAIBatchConfig +from mmirage.core.process.batch.adapter import BatchSubmissionResult + + +def test_extract_unique_provider_batches_handles_malformed_and_duplicates(tmp_path): + from mmirage.core.process.batch.status_checker import extract_unique_provider_batches + + metadata_path = tmp_path / "receipts.jsonl" + metadata_path.write_text( + "\n".join( + [ + '{"provider":"openai","provider_batch_id":"batch_1"}', + '{"provider":"openai","provider_batch_id":"batch_1"}', + "not-json", + '{"provider":"openai"}', + '{"provider":"openai","provider_batch_id":"batch_2"}', + "", + ] + ), + encoding="utf-8", + ) + + pairs = extract_unique_provider_batches(str(metadata_path)) + + assert pairs == [("openai", "batch_1"), ("openai", "batch_2")] + + +def test_run_status_checker_prints_summary_with_factory_dispatch(tmp_path, monkeypatch): + from mmirage.core.process.batch.status_checker import run_status_checker + + metadata_path = tmp_path / "receipts.jsonl" + metadata_path.write_text( + "\n".join( + [ + '{"provider":"openai","provider_batch_id":"batch_1"}', + '{"provider":"openai","provider_batch_id":"batch_2"}', + '{"provider":"openai","provider_batch_id":"batch_1"}', + ] + ), + encoding="utf-8", + ) + + class FakeAdapter: + def __init__(self): + self.calls = [] + + def check_batch_status(self, provider_batch_id, config): + self.calls.append((provider_batch_id, config.provider)) + status = "completed" if provider_batch_id == "batch_1" else "in_progress" + return BatchSubmissionResult( + provider_batch_id=provider_batch_id, + status=status, + submitted_request_count=0, + raw_response={"id": provider_batch_id, "status": status}, + ) + + fake_adapter = FakeAdapter() + + monkeypatch.setattr( + "mmirage.core.process.batch.status_checker.BatchAdapterFactory.from_config", + lambda config: fake_adapter, + ) + + output = StringIO() + config_map = { + "openai": OpenAIBatchConfig(credentials={"api_key": "k"}), + } + + results = run_status_checker( + metadata_output_path=str(metadata_path), + provider_configs=config_map, + output=output, + ) + + assert [(r.provider_batch_id, r.status) for r in results] == [ + ("batch_1", "completed"), + ("batch_2", "in_progress"), + ] + assert fake_adapter.calls == [ + ("batch_1", "openai"), + ("batch_2", "openai"), + ] + + printed = output.getvalue() + assert "Batch batch_1 (openai): completed" in printed + assert "Batch batch_2 (openai): in_progress" in printed diff --git a/tests/test_integration_receiver.py b/tests/test_integration_receiver.py new file mode 100644 index 0000000..0deb13f --- /dev/null +++ b/tests/test_integration_receiver.py @@ -0,0 +1,101 @@ +import json + +from mmirage.config.openai_batch import OpenAIBatchConfig + + +def test_integration_receiver_reads_receipt_and_writes_merged_output(tmp_path, monkeypatch): + from mmirage.core.process.batch.collector import collect_and_merge + + metadata_path = tmp_path / "receipt.text.jsonl" + metadata_path.write_text( + "\n".join( + [ + json.dumps( + { + "provider": "openai", + "provider_batch_id": "batch_a", + "custom_id_to_source_index": {"id_a": 1, "id_b": 0}, + } + ), + json.dumps( + { + "provider": "openai", + "provider_batch_id": "batch_b", + "custom_id_to_source_index": {"id_c": 2}, + } + ), + ] + ), + encoding="utf-8", + ) + + output_path = tmp_path / "merged.jsonl" + + class FakeAdapter: + def retrieve_results(self, provider_batch_id, config): + if provider_batch_id == "batch_a": + return [ + { + "custom_id": "id_a", + "response": { + "body": { + "choices": [ + { + "message": { + "content": '{"question":"What is id_a?","answer":"one"}' + } + } + ] + } + }, + }, + { + "custom_id": "id_b", + "response": { + "body": { + "choices": [ + { + "message": { + "content": '{"question":"What is id_b?","answer":"zero"}' + } + } + ] + } + }, + }, + ] + return [ + { + "custom_id": "id_c", + "response": { + "body": { + "choices": [ + { + "message": { + "content": '{"question":"What is id_c?","answer":"two"}' + } + } + ] + } + }, + } + ] + + monkeypatch.setattr( + "mmirage.core.process.batch.collector.BatchAdapterFactory.from_config", + lambda config: FakeAdapter(), + ) + + rows = collect_and_merge( + metadata_output_path=str(metadata_path), + provider_configs={"openai": OpenAIBatchConfig(credentials={"api_key": "test"})}, + output_path=str(output_path), + ) + + assert [r["source_index"] for r in rows] == [0, 1, 2] + assert [r["custom_id"] for r in rows] == ["id_b", "id_a", "id_c"] + assert [r["conversations"][1]["content"] for r in rows] == ["zero", "one", "two"] + + written = [json.loads(line) for line in output_path.read_text(encoding="utf-8").splitlines()] + assert [r["custom_id"] for r in written] == ["id_b", "id_a", "id_c"] + assert [r["conversations"][1]["content"] for r in written] == ["zero", "one", "two"] diff --git a/tests/test_openai_batch_adapter.py b/tests/test_openai_batch_adapter.py index e436471..bd8efcd 100644 --- a/tests/test_openai_batch_adapter.py +++ b/tests/test_openai_batch_adapter.py @@ -31,6 +31,36 @@ def test_openai_build_request_matches_expected_structure(): } +def test_openai_build_request_injects_structured_output_format(): + from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter + + config = OpenAIBatchConfig(model="gpt-4.1-mini") + adapter = OpenAIBatchAdapter() + payload = { + "messages": [{"role": "user", "content": "hello"}], + "expected_schema": ["question", "answer"], + } + + request = adapter.build_request(custom_id="row-002", payload=payload, config=config) + + assert request["body"]["response_format"] == { + "type": "json_schema", + "json_schema": { + "name": "structured_output", + "strict": True, + "schema": { + "type": "object", + "properties": { + "question": {"type": "string"}, + "answer": {"type": "string"}, + }, + "required": ["question", "answer"], + "additionalProperties": False, + }, + }, + } + + def test_openai_estimate_request_bytes_matches_utf8_json_size(): from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter @@ -197,6 +227,51 @@ def __init__(self, **kwargs): assert result.submitted_request_count == 0 +def test_openai_check_batch_status_falls_back_to_env_api_key(monkeypatch): + from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter + + captured = {} + + class FakeBatches: + def retrieve(self, provider_batch_id): + class _RetrieveResp: + id = provider_batch_id + status = "completed" + + return _RetrieveResp() + + class FakeClient: + def __init__(self, **kwargs): + captured["client_kwargs"] = kwargs + self.batches = FakeBatches() + + monkeypatch.setattr( + "mmirage.core.process.batch.openai_adapter.OpenAI", + FakeClient, + ) + monkeypatch.setenv("OPENAI_API_KEY", "env-test-key") + + config = OpenAIBatchConfig(credentials={}) + adapter = OpenAIBatchAdapter() + + result = adapter.check_batch_status(provider_batch_id="batch_env", config=config) + + assert captured["client_kwargs"]["api_key"] == "env-test-key" + assert result.provider_batch_id == "batch_env" + + +def test_openai_check_batch_status_raises_when_no_api_key(monkeypatch): + from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter + + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + + config = OpenAIBatchConfig(credentials={}) + adapter = OpenAIBatchAdapter() + + with pytest.raises(ValueError, match="OpenAI API key is missing"): + adapter.check_batch_status(provider_batch_id="batch_missing", config=config) + + def test_openai_retrieve_results_downloads_and_parses_jsonl(monkeypatch): from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter @@ -246,8 +321,8 @@ def __init__(self, **kwargs): assert len(rows) == 2 assert rows[0]["custom_id"] == "c1" assert rows[1]["custom_id"] == "c2" - assert rows[0]["generated_text"] == "A" - assert rows[1]["generated_text"] == "B" + assert rows[0]["response"]["body"]["text"] == "A" + assert rows[1]["response"]["body"]["text"] == "B" def test_openai_retrieve_results_raises_if_batch_not_completed(monkeypatch): From b32b4ae8960861e3b6b060dd016b7d4b0afd7430 Mon Sep 17 00:00:00 2001 From: legstar67 Date: Thu, 26 Mar 2026 06:24:44 +0100 Subject: [PATCH 14/45] correction on the image handling + tests --- src/mmirage/core/process/batch/collector.py | 71 +++++++++++-------- .../core/process/batch/openai_adapter.py | 52 +++++++++++++- tests/test_batch_collector.py | 65 +++++++++++++++-- tests/test_integration_receiver.py | 8 +-- tests/test_openai_batch_adapter.py | 31 ++++++++ 5 files changed, 187 insertions(+), 40 deletions(-) diff --git a/src/mmirage/core/process/batch/collector.py b/src/mmirage/core/process/batch/collector.py index 6e4df8f..271e50b 100644 --- a/src/mmirage/core/process/batch/collector.py +++ b/src/mmirage/core/process/batch/collector.py @@ -80,31 +80,23 @@ def collect_and_merge( config=provider_configs[provider], ) - indexed_rows: MutableMapping[int, Dict[str, Any]] = {} + indexed_rows: MutableMapping[str, Dict[str, Any]] = {} for pair, mapping in pair_to_mapping.items(): results = pair_to_results.get(pair, []) for result_row in results: custom_id = str(result_row.get("custom_id", "")).strip() if not custom_id or custom_id not in mapping: continue - source_index = mapping[custom_id] - parsed = _parse_structured_content(result_row) - indexed_rows[source_index] = { - "source_index": source_index, + row_payload = _build_output_payload(result_row) + indexed_rows[custom_id] = { + "source_index": 0, "custom_id": custom_id, - "conversations": [ - { - "role": "user", - "content": str(parsed.get("question", "")), - }, - { - "role": "assistant", - "content": str(parsed.get("answer", "")), - }, - ], + **row_payload, } - ordered_rows = [indexed_rows[idx] for idx in sorted(indexed_rows.keys())] + ordered_rows = list(indexed_rows.values()) + for idx, row in enumerate(ordered_rows): + row["source_index"] = idx os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) with open(output_path, "w", encoding="utf-8") as f: @@ -114,8 +106,35 @@ def collect_and_merge( return ordered_rows -def _parse_structured_content(result_row: Mapping[str, Any]) -> Dict[str, Any]: - # Preferred OpenAI envelope path for Structured Outputs. +def _build_output_payload(result_row: Mapping[str, Any]) -> Dict[str, Any]: + raw_content = _extract_content_string(result_row) + if not raw_content: + return {"caption": ""} + + try: + parsed = json.loads(raw_content) + except json.JSONDecodeError: + return {"caption": raw_content} + + if isinstance(parsed, dict) and ("question" in parsed or "answer" in parsed): + return { + "conversations": [ + { + "role": "user", + "content": str(parsed.get("question", "")), + }, + { + "role": "assistant", + "content": str(parsed.get("answer", "")), + }, + ] + } + + return {"caption": raw_content} + + +def _extract_content_string(result_row: Mapping[str, Any]) -> str: + # Preferred OpenAI envelope path for Structured Outputs / plain responses. response = result_row.get("response") if isinstance(response, Mapping): body = response.get("body") @@ -128,24 +147,14 @@ def _parse_structured_content(result_row: Mapping[str, Any]) -> Dict[str, Any]: if isinstance(message, Mapping): content = message.get("content") if isinstance(content, str): - try: - parsed = json.loads(content) - if isinstance(parsed, dict): - return parsed - except json.JSONDecodeError: - return {} + return content # Fallback for normalized adapter payloads carrying generated_text directly. generated_text = result_row.get("generated_text") if isinstance(generated_text, str): - try: - parsed = json.loads(generated_text) - if isinstance(parsed, dict): - return parsed - except json.JSONDecodeError: - return {} + return generated_text - return {} + return "" def _build_arg_parser() -> argparse.ArgumentParser: diff --git a/src/mmirage/core/process/batch/openai_adapter.py b/src/mmirage/core/process/batch/openai_adapter.py index 5430622..8ce32c9 100644 --- a/src/mmirage/core/process/batch/openai_adapter.py +++ b/src/mmirage/core/process/batch/openai_adapter.py @@ -1,7 +1,10 @@ """Concrete OpenAI implementation of batch submission contracts.""" +import base64 +import copy import io import json +import mimetypes import os from typing import Any, Dict, List, Mapping, Sequence @@ -32,9 +35,10 @@ def build_request( config: BatchProviderConfig, ) -> Mapping[str, Any]: openai_config = self._as_openai_config(config) - body = dict(payload) + body = copy.deepcopy(payload) expected_schema = body.pop("expected_schema", None) body.setdefault("model", openai_config.model) + self._convert_local_images_to_data_uris(body) if isinstance(expected_schema, list) and all(isinstance(k, str) for k in expected_schema): properties = {key: {"type": "string"} for key in expected_schema} @@ -59,6 +63,52 @@ def build_request( "body": body, } + @staticmethod + def _convert_local_images_to_data_uris(body: Mapping[str, Any]) -> None: + messages = body.get("messages") + if not isinstance(messages, list): + return + + for message in messages: + if not isinstance(message, dict): + continue + + content = message.get("content") + if not isinstance(content, list): + continue + + for part in content: + if not isinstance(part, dict): + continue + if part.get("type") != "image_url": + continue + + image_url = part.get("image_url") + if not isinstance(image_url, dict): + continue + + url = image_url.get("url") + if not isinstance(url, str): + continue + + # Keep remote/data URLs untouched. + if url.startswith("http://") or url.startswith("https://") or url.startswith("data:"): + continue + + if os.path.exists(url): + image_url["url"] = OpenAIBatchAdapter._local_file_to_data_uri(url) + + @staticmethod + def _local_file_to_data_uri(path: str) -> str: + mime_type, _ = mimetypes.guess_type(path) + if not mime_type: + mime_type = "image/jpeg" + + with open(path, "rb") as f: + encoded = base64.b64encode(f.read()).decode("utf-8") + + return f"data:{mime_type};base64,{encoded}" + def estimate_request_bytes(self, request: Mapping[str, Any]) -> int: serialized = json.dumps(request, ensure_ascii=False, separators=(",", ":")) return len(serialized.encode("utf-8")) diff --git a/tests/test_batch_collector.py b/tests/test_batch_collector.py index 8fe8f7a..d1ed616 100644 --- a/tests/test_batch_collector.py +++ b/tests/test_batch_collector.py @@ -107,14 +107,14 @@ def retrieve_results(self, provider_batch_id, config): ) assert [r["source_index"] for r in rows] == [0, 1, 2] - assert [r["custom_id"] for r in rows] == ["c2", "c3", "c1"] - assert [r["conversations"][0]["content"] for r in rows] == ["q0", "q1", "q2"] - assert [r["conversations"][1]["content"] for r in rows] == ["a0", "a1", "a2"] + assert [r["custom_id"] for r in rows] == ["c1", "c2", "c3"] + assert [r["conversations"][0]["content"] for r in rows] == ["q2", "q0", "q1"] + assert [r["conversations"][1]["content"] for r in rows] == ["a2", "a0", "a1"] assert fake_adapter.calls == [("batch_1", "openai"), ("batch_2", "openai")] written = [json.loads(line) for line in output_path.read_text(encoding="utf-8").splitlines()] assert [r["source_index"] for r in written] == [0, 1, 2] - assert [r["conversations"][0]["content"] for r in written] == ["q0", "q1", "q2"] + assert [r["conversations"][0]["content"] for r in written] == ["q2", "q0", "q1"] def test_collect_and_merge_raises_for_missing_provider_config(tmp_path): @@ -142,3 +142,60 @@ def test_collect_and_merge_raises_for_missing_provider_config(tmp_path): assert False, "Expected ValueError" except ValueError as e: assert "No provider config" in str(e) + + +def test_collect_and_merge_outputs_caption_for_plain_text_content(tmp_path, monkeypatch): + from mmirage.core.process.batch.collector import collect_and_merge + + metadata_path = tmp_path / "receipts.jsonl" + metadata_path.write_text( + json.dumps( + { + "provider": "openai", + "provider_batch_id": "batch_plain", + "custom_id_to_source_index": {"img_1": 0}, + } + ) + + "\n", + encoding="utf-8", + ) + + output_path = tmp_path / "merged_plain.jsonl" + + class FakeAdapter: + def retrieve_results(self, provider_batch_id, config): + return [ + { + "custom_id": "img_1", + "response": { + "body": { + "choices": [ + { + "message": { + "content": "A black cat sitting on a sofa." + } + } + ] + } + }, + } + ] + + monkeypatch.setattr( + "mmirage.core.process.batch.collector.BatchAdapterFactory.from_config", + lambda config: FakeAdapter(), + ) + + rows = collect_and_merge( + metadata_output_path=str(metadata_path), + provider_configs={"openai": OpenAIBatchConfig(credentials={"api_key": "k"})}, + output_path=str(output_path), + ) + + assert rows == [ + { + "source_index": 0, + "custom_id": "img_1", + "caption": "A black cat sitting on a sofa.", + } + ] diff --git a/tests/test_integration_receiver.py b/tests/test_integration_receiver.py index 0deb13f..2b4d0b0 100644 --- a/tests/test_integration_receiver.py +++ b/tests/test_integration_receiver.py @@ -93,9 +93,9 @@ def retrieve_results(self, provider_batch_id, config): ) assert [r["source_index"] for r in rows] == [0, 1, 2] - assert [r["custom_id"] for r in rows] == ["id_b", "id_a", "id_c"] - assert [r["conversations"][1]["content"] for r in rows] == ["zero", "one", "two"] + assert [r["custom_id"] for r in rows] == ["id_a", "id_b", "id_c"] + assert [r["conversations"][1]["content"] for r in rows] == ["one", "zero", "two"] written = [json.loads(line) for line in output_path.read_text(encoding="utf-8").splitlines()] - assert [r["custom_id"] for r in written] == ["id_b", "id_a", "id_c"] - assert [r["conversations"][1]["content"] for r in written] == ["zero", "one", "two"] + assert [r["custom_id"] for r in written] == ["id_a", "id_b", "id_c"] + assert [r["conversations"][1]["content"] for r in written] == ["one", "zero", "two"] diff --git a/tests/test_openai_batch_adapter.py b/tests/test_openai_batch_adapter.py index bd8efcd..a11739d 100644 --- a/tests/test_openai_batch_adapter.py +++ b/tests/test_openai_batch_adapter.py @@ -1,4 +1,5 @@ import json +import base64 import pytest @@ -61,6 +62,36 @@ def test_openai_build_request_injects_structured_output_format(): } +def test_openai_build_request_converts_local_image_path_to_data_uri(tmp_path): + from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter + + image_bytes = b"\xff\xd8\xff\xe0testjpeg" + image_path = tmp_path / "sample.jpg" + image_path.write_bytes(image_bytes) + + config = OpenAIBatchConfig(model="gpt-4o-mini") + adapter = OpenAIBatchAdapter() + payload = { + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "describe image"}, + {"type": "image_url", "image_url": {"url": str(image_path)}}, + ], + } + ] + } + + request = adapter.build_request(custom_id="vision-1", payload=payload, config=config) + + url = request["body"]["messages"][0]["content"][1]["image_url"]["url"] + assert url.startswith("data:image/jpeg;base64,") + + encoded = url.split(",", 1)[1] + assert base64.b64decode(encoded) == image_bytes + + def test_openai_estimate_request_bytes_matches_utf8_json_size(): from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter From c55efaf7aa5602dc9740c83cdd82dda286f99600 Mon Sep 17 00:00:00 2001 From: legstar67 Date: Thu, 26 Mar 2026 17:11:47 +0100 Subject: [PATCH 15/45] correct the problem Copilot spotted about a global index which is not global --- .../core/process/processors/llm/llm_processor.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/mmirage/core/process/processors/llm/llm_processor.py b/src/mmirage/core/process/processors/llm/llm_processor.py index 59c2687..4559a46 100644 --- a/src/mmirage/core/process/processors/llm/llm_processor.py +++ b/src/mmirage/core/process/processors/llm/llm_processor.py @@ -132,9 +132,9 @@ def _with_metadata_suffix(path: str, suffix: str) -> str: def batch_mode_enabled(self) -> bool: return self._text_orchestrator is not None and self._multimodal_orchestrator is not None - def _next_custom_id(self, output_name: str, global_index: int, modality: str) -> str: + def _next_custom_id(self, output_name: str, modality: str) -> str: self._batch_request_counter += 1 - return f"{output_name}:{modality}:{self._batch_request_counter}:{global_index}" + return f"{output_name}:{modality}:{self._batch_request_counter}" def build_prompt( self, prompt_template: str, vars_samples: List[VariableEnvironment] @@ -379,14 +379,14 @@ def _batch_process_sample( } if output_var.output_type == "JSON" and output_var.output_schema: payload["expected_schema"] = list(output_var.output_schema) - custom_id = self._next_custom_id(output_var.name, global_i, "text") + custom_id = self._next_custom_id(output_var.name, "text") request = self._batch_adapter.build_request( custom_id=custom_id, payload=payload, config=self._batch_provider_config, ) requests.append(dict(request)) - source_indices.append(global_i) + source_indices.append(self._batch_request_counter) self._text_orchestrator.add_requests( requests=requests, @@ -428,14 +428,14 @@ def _batch_process_sample( } if output_var.output_type == "JSON" and output_var.output_schema: payload["expected_schema"] = list(output_var.output_schema) - custom_id = self._next_custom_id(output_var.name, global_i, "multimodal") + custom_id = self._next_custom_id(output_var.name, "multimodal") request = self._batch_adapter.build_request( custom_id=custom_id, payload=payload, config=self._batch_provider_config, ) requests.append(dict(request)) - source_indices.append(global_i) + source_indices.append(self._batch_request_counter) self._multimodal_orchestrator.add_requests( requests=requests, From fead12e5d1865b8a4948ad3e2f7fdd856903095b Mon Sep 17 00:00:00 2001 From: legstar67 Date: Thu, 26 Mar 2026 17:54:07 +0100 Subject: [PATCH 16/45] correction of problem of sample id spotted by copilot --- src/mmirage/core/process/processors/llm/llm_processor.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/mmirage/core/process/processors/llm/llm_processor.py b/src/mmirage/core/process/processors/llm/llm_processor.py index 4559a46..fe5591a 100644 --- a/src/mmirage/core/process/processors/llm/llm_processor.py +++ b/src/mmirage/core/process/processors/llm/llm_processor.py @@ -357,6 +357,7 @@ def _batch_process_sample( nb_samples = len(batch) text_only_indices: List[int] = [] multimodal_indices: List[int] = [] + index_to_custom_id: Dict[int, str] = {} for i in range(nb_samples): if batch[i].has_images(): multimodal_indices.append(i) @@ -380,6 +381,7 @@ def _batch_process_sample( if output_var.output_type == "JSON" and output_var.output_schema: payload["expected_schema"] = list(output_var.output_schema) custom_id = self._next_custom_id(output_var.name, "text") + index_to_custom_id[global_i] = custom_id request = self._batch_adapter.build_request( custom_id=custom_id, payload=payload, @@ -429,6 +431,7 @@ def _batch_process_sample( if output_var.output_type == "JSON" and output_var.output_schema: payload["expected_schema"] = list(output_var.output_schema) custom_id = self._next_custom_id(output_var.name, "multimodal") + index_to_custom_id[global_i] = custom_id request = self._batch_adapter.build_request( custom_id=custom_id, payload=payload, @@ -449,7 +452,8 @@ def _batch_process_sample( placeholders: List[VariableEnvironment] = [] for i in range(nb_samples): - placeholder = f"__BATCH_SUBMITTED__:{output_var.name}:{i}" + unique_id = index_to_custom_id.get(i, f"unknown:{i}") + placeholder = f"__BATCH_SUBMITTED__:{unique_id}" placeholders.append(batch[i].with_variable(output_var.name, placeholder)) return placeholders From cead89b0bac0c48c0d9771d46fab3db34dc42a4a Mon Sep 17 00:00:00 2001 From: legstar67 Date: Thu, 26 Mar 2026 18:18:17 +0100 Subject: [PATCH 17/45] correction of the file naming, spotted by copilot, to allow sharding --- .../core/process/processors/llm/llm_processor.py | 12 +++++++----- tests/test_integration_batch_pipeline.py | 5 +++-- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/mmirage/core/process/processors/llm/llm_processor.py b/src/mmirage/core/process/processors/llm/llm_processor.py index fe5591a..5996e1e 100644 --- a/src/mmirage/core/process/processors/llm/llm_processor.py +++ b/src/mmirage/core/process/processors/llm/llm_processor.py @@ -6,6 +6,7 @@ import json import logging from typing import Any, Dict, List, Optional, Tuple +import uuid import jinja2 import sglang as sgl @@ -100,13 +101,14 @@ def _setup_batch_runtime(self) -> None: openai_cfg = OpenAIBatchConfig(**provider_cfg_raw) self._batch_provider_config = openai_cfg self._batch_adapter = BatchAdapterFactory.from_config(openai_cfg) + run_id = uuid.uuid4().hex[:6] self._text_orchestrator = BatchSubmissionOrchestrator( adapter=self._batch_adapter, config=replace( openai_cfg, metadata_output_path=self._with_metadata_suffix( - openai_cfg.metadata_output_path, "text" + openai_cfg.metadata_output_path, "text", run_id ), ), ) @@ -115,18 +117,18 @@ def _setup_batch_runtime(self) -> None: config=replace( openai_cfg, metadata_output_path=self._with_metadata_suffix( - openai_cfg.metadata_output_path, "multimodal" + openai_cfg.metadata_output_path, "multimodal", run_id ), ), ) @staticmethod - def _with_metadata_suffix(path: str, suffix: str) -> str: + def _with_metadata_suffix(path: str, suffix: str, run_id: str) -> str: if not path: return "" if path.endswith(".jsonl"): - return path[:-6] + f".{suffix}.jsonl" - return f"{path}.{suffix}.jsonl" + return path[:-6] + f".{suffix}.{run_id}.jsonl" + return f"{path}.{suffix}.{run_id}.jsonl" @property def batch_mode_enabled(self) -> bool: diff --git a/tests/test_integration_batch_pipeline.py b/tests/test_integration_batch_pipeline.py index 86351d1..77c73a0 100644 --- a/tests/test_integration_batch_pipeline.py +++ b/tests/test_integration_batch_pipeline.py @@ -138,8 +138,9 @@ def rewrite_batch(batch, mapper, renderer): assert all(isinstance(v, str) and v.startswith("__BATCH_SUBMITTED__:answer:") for v in answers) # 3) Metadata receipts are written and include both full_chunk and finalize flush reasons. - metadata_text_path = tmp_path / "batch_receipts.text.jsonl" - assert metadata_text_path.exists() + metadata_text_matches = sorted(tmp_path.glob("batch_receipts.text.*.jsonl")) + assert len(metadata_text_matches) == 1 + metadata_text_path = metadata_text_matches[0] records = [ json.loads(line) From 03eed23fba35fe2bacc0c192f23a00af4785c5cf Mon Sep 17 00:00:00 2001 From: legstar67 Date: Thu, 26 Mar 2026 18:24:36 +0100 Subject: [PATCH 18/45] add an import --- src/mmirage/shard_process.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/mmirage/shard_process.py b/src/mmirage/shard_process.py index a656e79..abe4bad 100644 --- a/src/mmirage/shard_process.py +++ b/src/mmirage/shard_process.py @@ -12,6 +12,7 @@ from mmirage.core.loader.base import BaseDataLoaderConfig, DatasetLike from mmirage.core.process.mapper import MMIRAGEMapper +from mmirage.core.process.processors.llm.llm_processor import LLMProcessor # noqa: F401 from mmirage.config.utils import load_mmirage_config from mmirage.core.writer.renderer import TemplateRenderer From d60e3370ba5d719700eb3386e91047773dec69a8 Mon Sep 17 00:00:00 2001 From: legstar67 <86873782+legstar67@users.noreply.github.com> Date: Thu, 26 Mar 2026 18:30:39 +0100 Subject: [PATCH 19/45] Update src/mmirage/core/process/batch/collector.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/mmirage/core/process/batch/collector.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/mmirage/core/process/batch/collector.py b/src/mmirage/core/process/batch/collector.py index 271e50b..615bf59 100644 --- a/src/mmirage/core/process/batch/collector.py +++ b/src/mmirage/core/process/batch/collector.py @@ -89,14 +89,12 @@ def collect_and_merge( continue row_payload = _build_output_payload(result_row) indexed_rows[custom_id] = { - "source_index": 0, + "source_index": mapping[custom_id], "custom_id": custom_id, **row_payload, } - ordered_rows = list(indexed_rows.values()) - for idx, row in enumerate(ordered_rows): - row["source_index"] = idx + ordered_rows = sorted(indexed_rows.values(), key=lambda row: row["source_index"]) os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) with open(output_path, "w", encoding="utf-8") as f: From c64274485980e63840705d8f6835c02d58e78da6 Mon Sep 17 00:00:00 2001 From: legstar67 Date: Thu, 26 Mar 2026 18:46:42 +0100 Subject: [PATCH 20/45] copilot suggestion --- src/mmirage/core/process/batch/collector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mmirage/core/process/batch/collector.py b/src/mmirage/core/process/batch/collector.py index 615bf59..40accdc 100644 --- a/src/mmirage/core/process/batch/collector.py +++ b/src/mmirage/core/process/batch/collector.py @@ -89,12 +89,12 @@ def collect_and_merge( continue row_payload = _build_output_payload(result_row) indexed_rows[custom_id] = { - "source_index": mapping[custom_id], + "source_index": int(mapping.get(custom_id, 0)), "custom_id": custom_id, **row_payload, } - ordered_rows = sorted(indexed_rows.values(), key=lambda row: row["source_index"]) + ordered_rows = sorted(indexed_rows.values(), key=lambda row: row.get("source_index", 0)) os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) with open(output_path, "w", encoding="utf-8") as f: From c7d4bac2c55ae03953184d408f744a2146c299ed Mon Sep 17 00:00:00 2001 From: legstar67 Date: Fri, 27 Mar 2026 01:00:07 +0100 Subject: [PATCH 21/45] template application changed + small correction --- src/mmirage/core/process/processors/llm/llm_processor.py | 6 +++--- tests/test_batch_collector.py | 8 ++++---- tests/test_integration_receiver.py | 8 ++++---- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/mmirage/core/process/processors/llm/llm_processor.py b/src/mmirage/core/process/processors/llm/llm_processor.py index 5996e1e..895f7f8 100644 --- a/src/mmirage/core/process/processors/llm/llm_processor.py +++ b/src/mmirage/core/process/processors/llm/llm_processor.py @@ -367,16 +367,16 @@ def _batch_process_sample( text_only_indices.append(i) if text_only_indices: - text_only_envs = [batch[i] for i in text_only_indices] - prompts = self.build_prompt(output_var.prompt, text_only_envs) + jinja_template = jinja2.Template(output_var.prompt) requests: List[Dict[str, Any]] = [] source_indices: List[int] = [] for local_i, global_i in enumerate(text_only_indices): + base_prompt = jinja_template.render(**batch[global_i].to_dict()) payload = { "messages": [ { "role": "user", - "content": prompts[local_i], + "content": base_prompt, } ] } diff --git a/tests/test_batch_collector.py b/tests/test_batch_collector.py index d1ed616..30b2f1c 100644 --- a/tests/test_batch_collector.py +++ b/tests/test_batch_collector.py @@ -107,14 +107,14 @@ def retrieve_results(self, provider_batch_id, config): ) assert [r["source_index"] for r in rows] == [0, 1, 2] - assert [r["custom_id"] for r in rows] == ["c1", "c2", "c3"] - assert [r["conversations"][0]["content"] for r in rows] == ["q2", "q0", "q1"] - assert [r["conversations"][1]["content"] for r in rows] == ["a2", "a0", "a1"] + assert [r["custom_id"] for r in rows] == ["c2", "c3", "c1"] + assert [r["conversations"][0]["content"] for r in rows] == ["q0", "q1", "q2"] + assert [r["conversations"][1]["content"] for r in rows] == ["a0", "a1", "a2"] assert fake_adapter.calls == [("batch_1", "openai"), ("batch_2", "openai")] written = [json.loads(line) for line in output_path.read_text(encoding="utf-8").splitlines()] assert [r["source_index"] for r in written] == [0, 1, 2] - assert [r["conversations"][0]["content"] for r in written] == ["q2", "q0", "q1"] + assert [r["conversations"][0]["content"] for r in written] == ["q0", "q1", "q2"] def test_collect_and_merge_raises_for_missing_provider_config(tmp_path): diff --git a/tests/test_integration_receiver.py b/tests/test_integration_receiver.py index 2b4d0b0..0deb13f 100644 --- a/tests/test_integration_receiver.py +++ b/tests/test_integration_receiver.py @@ -93,9 +93,9 @@ def retrieve_results(self, provider_batch_id, config): ) assert [r["source_index"] for r in rows] == [0, 1, 2] - assert [r["custom_id"] for r in rows] == ["id_a", "id_b", "id_c"] - assert [r["conversations"][1]["content"] for r in rows] == ["one", "zero", "two"] + assert [r["custom_id"] for r in rows] == ["id_b", "id_a", "id_c"] + assert [r["conversations"][1]["content"] for r in rows] == ["zero", "one", "two"] written = [json.loads(line) for line in output_path.read_text(encoding="utf-8").splitlines()] - assert [r["custom_id"] for r in written] == ["id_a", "id_b", "id_c"] - assert [r["conversations"][1]["content"] for r in written] == ["one", "zero", "two"] + assert [r["custom_id"] for r in written] == ["id_b", "id_a", "id_c"] + assert [r["conversations"][1]["content"] for r in written] == ["zero", "one", "two"] From c5e75a202c698b13fe413cc5f667a16d1b5f9420 Mon Sep 17 00:00:00 2001 From: legstar67 Date: Fri, 27 Mar 2026 02:33:50 +0100 Subject: [PATCH 22/45] suggestion by Copilot, mostly changing the CLI arguement to accept list of files --- src/mmirage/core/process/batch/collector.py | 36 ++++++----- .../core/process/batch/status_checker.py | 62 ++++++++++++------- .../core/process/processors/llm/config.py | 5 +- .../process/processors/llm/llm_processor.py | 11 ++-- 4 files changed, 69 insertions(+), 45 deletions(-) diff --git a/src/mmirage/core/process/batch/collector.py b/src/mmirage/core/process/batch/collector.py index 40accdc..6ffd19c 100644 --- a/src/mmirage/core/process/batch/collector.py +++ b/src/mmirage/core/process/batch/collector.py @@ -13,19 +13,26 @@ from mmirage.core.process.batch.registry import BatchAdapterFactory -def _read_metadata_records(metadata_output_path: str) -> List[Dict[str, Any]]: +def _normalize_metadata_paths(metadata_paths: str | Sequence[str]) -> List[str]: + if isinstance(metadata_paths, str): + return [metadata_paths] + return [str(path) for path in metadata_paths] + + +def _read_metadata_records(metadata_output_paths: str | Sequence[str]) -> List[Dict[str, Any]]: records: List[Dict[str, Any]] = [] - with open(metadata_output_path, "r", encoding="utf-8") as f: - for line in f: - raw = line.strip() - if not raw: - continue - try: - parsed = json.loads(raw) - except json.JSONDecodeError: - continue - if isinstance(parsed, dict): - records.append(parsed) + for metadata_output_path in _normalize_metadata_paths(metadata_output_paths): + with open(metadata_output_path, "r", encoding="utf-8") as f: + for line in f: + raw = line.strip() + if not raw: + continue + try: + parsed = json.loads(raw) + except json.JSONDecodeError: + continue + if isinstance(parsed, dict): + records.append(parsed) return records @@ -56,7 +63,7 @@ def _aggregate_batch_mappings( def collect_and_merge( - metadata_output_path: str, + metadata_output_path: str | Sequence[str], provider_configs: Mapping[str, BatchProviderConfig], output_path: str, ) -> List[Dict[str, Any]]: @@ -161,8 +168,9 @@ def _build_arg_parser() -> argparse.ArgumentParser: ) parser.add_argument( "--metadata-path", + nargs="+", required=True, - help="Path to metadata JSONL receipt file.", + help="Path(s) to metadata JSONL receipt file(s). Supports multiple files.", ) parser.add_argument( "--output-path", diff --git a/src/mmirage/core/process/batch/status_checker.py b/src/mmirage/core/process/batch/status_checker.py index ff08e07..b54f679 100644 --- a/src/mmirage/core/process/batch/status_checker.py +++ b/src/mmirage/core/process/batch/status_checker.py @@ -14,7 +14,30 @@ from mmirage.core.process.batch.registry import BatchAdapterFactory -def extract_unique_provider_batches(metadata_output_path: str) -> List[Tuple[str, str]]: +def _normalize_metadata_paths(metadata_paths: str | Sequence[str]) -> List[str]: + if isinstance(metadata_paths, str): + return [metadata_paths] + return [str(path) for path in metadata_paths] + + +def _read_metadata_records(metadata_output_paths: str | Sequence[str]) -> List[Dict[str, str]]: + records: List[Dict[str, str]] = [] + for metadata_output_path in _normalize_metadata_paths(metadata_output_paths): + with open(metadata_output_path, "r", encoding="utf-8") as f: + for line in f: + raw = line.strip() + if not raw: + continue + try: + record = json.loads(raw) + except json.JSONDecodeError: + continue + if isinstance(record, dict): + records.append(record) + return records + + +def extract_unique_provider_batches(metadata_output_path: str | Sequence[str]) -> List[Tuple[str, str]]: """Parse metadata JSONL and return unique ``(provider, provider_batch_id)`` pairs. Malformed lines and records missing required keys are skipped safely. @@ -22,34 +45,24 @@ def extract_unique_provider_batches(metadata_output_path: str) -> List[Tuple[str unique_pairs: List[Tuple[str, str]] = [] seen = set() - with open(metadata_output_path, "r", encoding="utf-8") as f: - for line in f: - raw = line.strip() - if not raw: - continue + for record in _read_metadata_records(metadata_output_path): + provider = str(record.get("provider", "")).strip().lower() + provider_batch_id = str(record.get("provider_batch_id", "")).strip() - try: - record = json.loads(raw) - except json.JSONDecodeError: - continue - - provider = str(record.get("provider", "")).strip().lower() - provider_batch_id = str(record.get("provider_batch_id", "")).strip() - - if not provider or not provider_batch_id: - continue + if not provider or not provider_batch_id: + continue - pair = (provider, provider_batch_id) - if pair in seen: - continue - seen.add(pair) - unique_pairs.append(pair) + pair = (provider, provider_batch_id) + if pair in seen: + continue + seen.add(pair) + unique_pairs.append(pair) return unique_pairs def run_status_checker( - metadata_output_path: str, + metadata_output_path: str | Sequence[str], provider_configs: Mapping[str, BatchProviderConfig], output: TextIO = sys.stdout, ) -> List[BatchSubmissionResult]: @@ -84,7 +97,7 @@ def run_status_checker( def _build_provider_configs_from_metadata( - metadata_output_path: str, + metadata_output_path: str | Sequence[str], ) -> Dict[str, BatchProviderConfig]: provider_names = {provider for provider, _ in extract_unique_provider_batches(metadata_output_path)} configs: Dict[str, BatchProviderConfig] = {} @@ -104,8 +117,9 @@ def _build_arg_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description="Check provider batch statuses from metadata receipts.") parser.add_argument( "--metadata-path", + nargs="+", required=True, - help="Path to metadata JSONL receipt file.", + help="Path(s) to metadata JSONL receipt file(s). Supports multiple files.", ) return parser diff --git a/src/mmirage/core/process/processors/llm/config.py b/src/mmirage/core/process/processors/llm/config.py index 762d095..3a1efbf 100644 --- a/src/mmirage/core/process/processors/llm/config.py +++ b/src/mmirage/core/process/processors/llm/config.py @@ -18,20 +18,19 @@ def _parse_tp_size_from_env() -> int: """Parse tensor parallelism size from SLURM_GPUS_ON_NODE environment variable. - + Defensively parses the environment variable, handling invalid values: - Returns 1 if the variable is None or empty - Strips whitespace before parsing - Returns 1 for non-integer values - Returns 1 for values <= 0 - + Returns: Tensor parallelism size (>= 1), defaults to 1 on any parsing error. """ env_value = os.environ.get("SLURM_GPUS_ON_NODE") if not env_value: return 1 - try: tp_size = int(env_value.strip()) # Ensure tp_size is positive (must be >= 1) diff --git a/src/mmirage/core/process/processors/llm/llm_processor.py b/src/mmirage/core/process/processors/llm/llm_processor.py index 895f7f8..d18911d 100644 --- a/src/mmirage/core/process/processors/llm/llm_processor.py +++ b/src/mmirage/core/process/processors/llm/llm_processor.py @@ -63,7 +63,7 @@ def __init__(self, engine_args: SGLangLLMConfig, **kwargs) -> None: """ super().__init__(engine_args, **kwargs) provider_cfg_raw = dict(getattr(engine_args, "batch_provider", {}) or {}) - batch_mode_requested = bool(provider_cfg_raw.get("enabled", False)) + batch_mode_requested = bool(provider_cfg_raw.get("enabled", True)) # In provider-batch mode we only build payloads/metadata and should not # initialize GPU-backed SGLang runtime. @@ -82,6 +82,7 @@ def __init__(self, engine_args: SGLangLLMConfig, **kwargs) -> None: self._text_orchestrator: Optional[BatchSubmissionOrchestrator] = None self._multimodal_orchestrator: Optional[BatchSubmissionOrchestrator] = None self._batch_request_counter = 0 + self._global_row_offset = 0 self._setup_batch_runtime() def _setup_batch_runtime(self) -> None: @@ -89,7 +90,7 @@ def _setup_batch_runtime(self) -> None: if not provider_cfg_raw: return - if not provider_cfg_raw.get("enabled", False): + if not provider_cfg_raw.get("enabled", True): return provider = str(provider_cfg_raw.get("provider", "openai")).strip().lower() @@ -390,7 +391,7 @@ def _batch_process_sample( config=self._batch_provider_config, ) requests.append(dict(request)) - source_indices.append(self._batch_request_counter) + source_indices.append(self._global_row_offset + global_i) self._text_orchestrator.add_requests( requests=requests, @@ -440,7 +441,7 @@ def _batch_process_sample( config=self._batch_provider_config, ) requests.append(dict(request)) - source_indices.append(self._batch_request_counter) + source_indices.append(self._global_row_offset + global_i) self._multimodal_orchestrator.add_requests( requests=requests, @@ -458,6 +459,8 @@ def _batch_process_sample( placeholder = f"__BATCH_SUBMITTED__:{unique_id}" placeholders.append(batch[i].with_variable(output_var.name, placeholder)) + self._global_row_offset += nb_samples + return placeholders def finalize(self) -> None: From 1ce4e2c4ef56d0137241b59bc9063d6185c70ecc Mon Sep 17 00:00:00 2001 From: legstar67 Date: Fri, 27 Mar 2026 21:00:52 +0100 Subject: [PATCH 23/45] small correction, to respect backward compatibility --- .../process/processors/llm/llm_processor.py | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/src/mmirage/core/process/processors/llm/llm_processor.py b/src/mmirage/core/process/processors/llm/llm_processor.py index d18911d..1a9910c 100644 --- a/src/mmirage/core/process/processors/llm/llm_processor.py +++ b/src/mmirage/core/process/processors/llm/llm_processor.py @@ -62,19 +62,26 @@ def __init__(self, engine_args: SGLangLLMConfig, **kwargs) -> None: **kwargs: Additional arguments passed to base class. """ super().__init__(engine_args, **kwargs) - provider_cfg_raw = dict(getattr(engine_args, "batch_provider", {}) or {}) - batch_mode_requested = bool(provider_cfg_raw.get("enabled", True)) + + batch_provider_attrs = getattr(engine_args, "batch_provider", None) + if batch_provider_attrs is None: + is_provider_batch_enabled = False + else: + provider_cfg_raw = dict(batch_provider_attrs) + is_provider_batch_enabled = bool(provider_cfg_raw.get("enabled", True)) # In provider-batch mode we only build payloads/metadata and should not # initialize GPU-backed SGLang runtime. self.llm = None - if not batch_mode_requested: + self.tokenizer = None + if not is_provider_batch_enabled: self.llm = sgl.Engine(**asdict(engine_args.server_args)) + self.tokenizer = AutoTokenizer.from_pretrained( + engine_args.server_args.model_path, + trust_remote_code=getattr(engine_args.server_args, "trust_remote_code", False), + ) + - self.tokenizer = AutoTokenizer.from_pretrained( - engine_args.server_args.model_path, - trust_remote_code=getattr(engine_args.server_args, "trust_remote_code", False), - ) self.sampling_params = engine_args.default_sampling_params self.chat_template = engine_args.chat_template self._batch_adapter = None From 30b52b8f89ddbd9996ebc2542bf7df7130a8193e Mon Sep 17 00:00:00 2001 From: legstar67 Date: Sat, 28 Mar 2026 00:45:43 +0100 Subject: [PATCH 24/45] small correction on the config for vision --- configs/config_mock_openai_batch_vision.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/config_mock_openai_batch_vision.yaml b/configs/config_mock_openai_batch_vision.yaml index e89b571..49f35f1 100644 --- a/configs/config_mock_openai_batch_vision.yaml +++ b/configs/config_mock_openai_batch_vision.yaml @@ -13,7 +13,7 @@ processors: enabled: true provider: openai model: gpt-4o-mini - metadata_output_path: tests/output/batch_metadata_vision.jsonl + metadata_output_path: tests/output/batch_metadata.jsonl credentials: api_key: "" From 71f9d6f355b4fc21d2a6b251432ad2bc2f83860f Mon Sep 17 00:00:00 2001 From: legstar67 Date: Sat, 2 May 2026 01:03:35 +0200 Subject: [PATCH 25/45] big improvement : making the 3 steps of the pipeline really provider agnostic , hardcoded openai logic removed, + unitary tests + partial docs --- src/mmirage/core/process/batch/collector.py | 100 ++++--- .../core/process/batch/provider_resolution.py | 195 +++++++++++++ .../core/process/batch/status_checker.py | 74 +++-- .../process/processors/llm/llm_processor.py | 24 +- tests/test_batch_adapter_contracts.py | 49 ++++ tests/test_batch_collector.py | 272 ++++++++++++++---- tests/test_batch_orchestrator.py | 89 ++++++ tests/test_batch_status_checker.py | 120 +++++++- tests/test_integration_receiver.py | 41 +-- 9 files changed, 795 insertions(+), 169 deletions(-) create mode 100644 src/mmirage/core/process/batch/provider_resolution.py diff --git a/src/mmirage/core/process/batch/collector.py b/src/mmirage/core/process/batch/collector.py index 6ffd19c..0f42303 100644 --- a/src/mmirage/core/process/batch/collector.py +++ b/src/mmirage/core/process/batch/collector.py @@ -1,4 +1,9 @@ -"""Receiver-side utility for collecting provider results and merging by source row index.""" +"""Collect provider batch receipts and merge completed rows by source index. + +The receiver consumes one or more metadata receipt files, resolves the provider +configuration for each recorded batch, fetches the provider results, and writes a +single JSONL file ordered by the original source row index. +""" from __future__ import annotations @@ -9,17 +14,28 @@ from typing import Any, Dict, List, Mapping, MutableMapping, Sequence, Tuple from mmirage.config.batch_provider import BatchProviderConfig -from mmirage.config.openai_batch import OpenAIBatchConfig +from mmirage.core.process.batch.provider_resolution import resolve_provider_configs from mmirage.core.process.batch.registry import BatchAdapterFactory def _normalize_metadata_paths(metadata_paths: str | Sequence[str]) -> List[str]: + """Return metadata paths as a concrete list. + + Accepts either a single string or a sequence so the CLI and internal callers + can share the same entry point without special-casing file counts. + """ if isinstance(metadata_paths, str): return [metadata_paths] return [str(path) for path in metadata_paths] def _read_metadata_records(metadata_output_paths: str | Sequence[str]) -> List[Dict[str, Any]]: + """Load valid JSON objects from one or more receipt files. + + Malformed lines are skipped so partially written or noisy receipt files do + not stop collection. Only JSON objects are retained because downstream + resolution depends on keyed metadata. + """ records: List[Dict[str, Any]] = [] for metadata_output_path in _normalize_metadata_paths(metadata_output_paths): with open(metadata_output_path, "r", encoding="utf-8") as f: @@ -39,6 +55,11 @@ def _read_metadata_records(metadata_output_paths: str | Sequence[str]) -> List[D def _aggregate_batch_mappings( records: Sequence[Mapping[str, Any]], ) -> Dict[Tuple[str, str], Dict[str, int]]: + """Group source-index mappings by provider and provider batch ID. + + Later receipts for the same provider batch overwrite earlier entries for the + same custom ID, which keeps the latest parsed mapping authoritative. + """ aggregated: Dict[Tuple[str, str], Dict[str, int]] = {} for record in records: @@ -63,12 +84,24 @@ def _aggregate_batch_mappings( def collect_and_merge( - metadata_output_path: str | Sequence[str], + records: Sequence[Mapping[str, Any]], provider_configs: Mapping[str, BatchProviderConfig], output_path: str, ) -> List[Dict[str, Any]]: - """Collect completed results and reconstruct rows in source index order.""" - records = _read_metadata_records(metadata_output_path) + """Fetch provider outputs and write merged rows in source index order. + + Args: + records: Parsed receipt metadata containing provider batch references. + provider_configs: Provider-specific configuration keyed by normalized + provider name. + output_path: Destination JSONL path for the merged output. + + Returns: + The ordered rows that were written to disk. + + Raises: + ValueError: If a receipt references a provider that cannot be resolved. + """ pair_to_mapping = _aggregate_batch_mappings(records) adapters: Dict[str, Any] = {} @@ -112,6 +145,12 @@ def collect_and_merge( def _build_output_payload(result_row: Mapping[str, Any]) -> Dict[str, Any]: + """Convert provider content into the receiver's output schema. + + The collector preserves raw text for opaque generations, but maps structured + question/answer JSON into a conversation format expected by downstream + consumers. + """ raw_content = _extract_content_string(result_row) if not raw_content: return {"caption": ""} @@ -139,27 +178,12 @@ def _build_output_payload(result_row: Mapping[str, Any]) -> Dict[str, Any]: def _extract_content_string(result_row: Mapping[str, Any]) -> str: - # Preferred OpenAI envelope path for Structured Outputs / plain responses. - response = result_row.get("response") - if isinstance(response, Mapping): - body = response.get("body") - if isinstance(body, Mapping): - choices = body.get("choices") - if isinstance(choices, list) and choices: - first_choice = choices[0] - if isinstance(first_choice, Mapping): - message = first_choice.get("message") - if isinstance(message, Mapping): - content = message.get("content") - if isinstance(content, str): - return content - - # Fallback for normalized adapter payloads carrying generated_text directly. - generated_text = result_row.get("generated_text") - if isinstance(generated_text, str): - return generated_text - - return "" + """Return the generated text payload as a string. + + The collector treats missing content as empty output rather than a hard + failure so incomplete provider responses do not block the merge. + """ + return str(result_row.get("generated_text", "")) def _build_arg_parser() -> argparse.ArgumentParser: @@ -177,21 +201,29 @@ def _build_arg_parser() -> argparse.ArgumentParser: required=True, help="Path to write merged JSONL output.", ) + parser.add_argument( + "--config", + required=True, + help="Path to the YAML configuration file", + ) return parser def main(argv: Sequence[str] | None = None) -> int: - args = _build_arg_parser().parse_args(argv) + """Run the collector CLI. - api_key = os.environ.get("OPENAI_API_KEY", "").strip() - if not api_key: - raise ValueError("OPENAI_API_KEY is required for collector execution.") + Reads receipt metadata, resolves provider configs from the supplied MMIRAGE + configuration, and writes the merged JSONL output path passed on the command + line. + """ + args = _build_arg_parser().parse_args(argv) + from mmirage.config.utils import load_mmirage_config - provider_configs: Dict[str, BatchProviderConfig] = { - "openai": OpenAIBatchConfig(credentials={"api_key": api_key}) - } + records = _read_metadata_records(args.metadata_path) + cfg = load_mmirage_config(args.config) + provider_configs = resolve_provider_configs(records, cfg) - rows = collect_and_merge(args.metadata_path, provider_configs, args.output_path) + rows = collect_and_merge(records, provider_configs, args.output_path) print(f"Merged {len(rows)} rows and saved to {args.output_path}") return 0 diff --git a/src/mmirage/core/process/batch/provider_resolution.py b/src/mmirage/core/process/batch/provider_resolution.py new file mode 100644 index 0000000..ed66886 --- /dev/null +++ b/src/mmirage/core/process/batch/provider_resolution.py @@ -0,0 +1,195 @@ +"""Resolve batch provider configs from YAML and metadata inputs. + +These helpers decouple config parsing from metadata inspection so the same +provider configuration logic can be reused by receiver and submission flows. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Sequence, Type + +from mmirage.config.batch_provider import BatchProviderConfig + +if TYPE_CHECKING: + from mmirage.config.config import MMirageConfig + + +class BatchProviderConfigRegistry: + """Registry for provider-specific batch config classes. + + Built-ins are registered lazily to avoid import cycles and keep config + resolution available in lightweight contexts. + """ + + _registry: Dict[str, Type[BatchProviderConfig]] = {} + _bootstrapped: bool = False + + @classmethod + def _bootstrap_builtin_configs(cls) -> None: + if cls._bootstrapped: + return + + from mmirage.config.openai_batch import OpenAIBatchConfig + + cls.register("openai", OpenAIBatchConfig) + cls._bootstrapped = True + + @classmethod + def register(cls, provider: str, config_cls: Type[BatchProviderConfig]) -> None: + provider_key = provider.strip().lower() + if not provider_key: + raise ValueError("provider must be a non-empty string") + cls._registry[provider_key] = config_cls + + @classmethod + def clear(cls) -> None: + cls._registry.clear() + cls._bootstrapped = False + + @classmethod + def get_config_cls( + cls, + provider: str, + default: Type[BatchProviderConfig] | None = None, + ) -> Type[BatchProviderConfig]: + cls._bootstrap_builtin_configs() + provider_key = provider.strip().lower() + if not provider_key: + raise ValueError("provider must be a non-empty string") + if provider_key in cls._registry: + return cls._registry[provider_key] + if default is not None: + return default + raise ValueError( + f"Unknown batch provider '{provider}'. Available providers: {list(cls._registry.keys())}" + ) + + +def _discover_required_providers(metadata_records: Sequence[Mapping[str, Any]]) -> List[str]: + providers: List[str] = [] + seen = set() + for record in metadata_records: + provider = str(record.get("provider", "")).strip().lower() + if not provider or provider in seen: + continue + seen.add(provider) + providers.append(provider) + return providers + + +def _extract_batch_provider_blocks(cfg: MMirageConfig) -> Dict[str, Dict[str, Any]]: + """Collect raw batch_provider blocks keyed by provider. + + Raises ValueError on duplicate provider definitions to avoid ambiguous + config resolution. + """ + provider_blocks: Dict[str, Dict[str, Any]] = {} + for processor_cfg in cfg.processors: + raw_block = dict(getattr(processor_cfg, "batch_provider", {}) or {}) + if not raw_block: + continue + + provider = str(raw_block.get("provider", "openai")).strip().lower() + if not provider: + continue + + if provider in provider_blocks: + raise ValueError( + f"Duplicate batch_provider blocks found for provider '{provider}' in config processors." + ) + + provider_blocks[provider] = raw_block + + return provider_blocks + + +def _instantiate_provider_config(provider: str, raw_block: Mapping[str, Any]) -> BatchProviderConfig: + """Instantiate the provider config, falling back to the shared base config.""" + payload = dict(raw_block) + payload.setdefault("provider", provider) + + config_cls = BatchProviderConfigRegistry.get_config_cls( + provider, + default=BatchProviderConfig, + ) + return config_cls(**payload) + + +def resolve_single_provider_config(raw_block: Mapping[str, Any]) -> BatchProviderConfig: + """Resolve a single provider config from a raw batch_provider block. + + Defaults to the OpenAI provider for backward compatibility and raises + ValueError for unknown providers or invalid config payloads. + """ + payload = dict(raw_block or {}) + provider = str(payload.get("provider", "openai")).strip().lower() + if not provider: + provider = "openai" + payload["provider"] = provider + + try: + BatchProviderConfigRegistry.get_config_cls(provider) + except ValueError as exc: + raise ValueError(str(exc)) from exc + + try: + return _instantiate_provider_config(provider, payload) + except Exception as exc: + raise ValueError( + f"Failed to instantiate batch provider config for '{provider}': {exc}" + ) from exc + + +def build_all_provider_configs(cfg: "MMirageConfig") -> Dict[str, BatchProviderConfig]: + """Build provider configs for every batch_provider block in the YAML. + + Raises ValueError when any provider config fails to instantiate. + """ + provider_blocks = _extract_batch_provider_blocks(cfg) + if not provider_blocks: + return {} + + resolved: Dict[str, BatchProviderConfig] = {} + for provider, raw_block in provider_blocks.items(): + try: + resolved[provider] = _instantiate_provider_config(provider, raw_block) + except Exception as exc: + raise ValueError( + f"Failed to instantiate batch provider config for '{provider}': {exc}" + ) from exc + + return resolved + + +def resolve_provider_configs( + metadata_records: Sequence[Mapping[str, Any]], + cfg: "MMirageConfig", +) -> Dict[str, BatchProviderConfig]: + """Resolve provider configs required by receiver metadata. + + Args: + metadata_records: Parsed metadata JSONL records used to discover which + providers are required by the receiver command. + cfg: Loaded YAML config object from ``load_mmirage_config``. + + Returns: + Mapping from normalized provider name to instantiated provider config. + + Raises: + ValueError: If metadata references a provider missing from config + processor ``batch_provider`` blocks or if provider config + instantiation fails. + """ + available_configs = build_all_provider_configs(cfg) + required_providers = _discover_required_providers(metadata_records) + if not required_providers: + return {} + + missing = [provider for provider in required_providers if provider not in available_configs] + if missing: + raise ValueError( + "Metadata references provider(s) missing from YAML batch_provider config: " + f"{missing}. Check cfg.processors[*].batch_provider." + ) + + return {provider: available_configs[provider] for provider in required_providers} diff --git a/src/mmirage/core/process/batch/status_checker.py b/src/mmirage/core/process/batch/status_checker.py index b54f679..6b90760 100644 --- a/src/mmirage/core/process/batch/status_checker.py +++ b/src/mmirage/core/process/batch/status_checker.py @@ -1,16 +1,19 @@ -"""Receiver-side utility for polling provider batch statuses from metadata receipts.""" +"""Receiver-side helper to check provider batch status from metadata receipts. + +Designed for CLI use against JSONL receipt files. Skips malformed lines and +missing keys to keep status checks resilient to partial metadata corruption. +""" from __future__ import annotations import argparse import json -import os import sys -from typing import Dict, List, Mapping, Sequence, TextIO, Tuple +from typing import Any, Dict, List, Mapping, Sequence, TextIO, Tuple from mmirage.config.batch_provider import BatchProviderConfig -from mmirage.config.openai_batch import OpenAIBatchConfig from mmirage.core.process.batch.adapter import BatchSubmissionResult +from mmirage.core.process.batch.provider_resolution import resolve_provider_configs from mmirage.core.process.batch.registry import BatchAdapterFactory @@ -21,6 +24,11 @@ def _normalize_metadata_paths(metadata_paths: str | Sequence[str]) -> List[str]: def _read_metadata_records(metadata_output_paths: str | Sequence[str]) -> List[Dict[str, str]]: + """Load JSONL metadata records from one or more files. + + Lines that are empty or invalid JSON are ignored to allow best-effort + status checks when receipt files are partially corrupted. + """ records: List[Dict[str, str]] = [] for metadata_output_path in _normalize_metadata_paths(metadata_output_paths): with open(metadata_output_path, "r", encoding="utf-8") as f: @@ -37,15 +45,16 @@ def _read_metadata_records(metadata_output_paths: str | Sequence[str]) -> List[D return records -def extract_unique_provider_batches(metadata_output_path: str | Sequence[str]) -> List[Tuple[str, str]]: - """Parse metadata JSONL and return unique ``(provider, provider_batch_id)`` pairs. +def extract_unique_provider_batches(metadata_records: Sequence[Mapping[str, Any]]) -> List[Tuple[str, str]]: + """Return unique ``(provider, provider_batch_id)`` pairs. - Malformed lines and records missing required keys are skipped safely. + Normalizes provider names to lowercase and ignores records that do not + provide both keys, preventing accidental calls with incomplete metadata. """ unique_pairs: List[Tuple[str, str]] = [] seen = set() - for record in _read_metadata_records(metadata_output_path): + for record in metadata_records: provider = str(record.get("provider", "")).strip().lower() provider_batch_id = str(record.get("provider_batch_id", "")).strip() @@ -62,15 +71,20 @@ def extract_unique_provider_batches(metadata_output_path: str | Sequence[str]) - def run_status_checker( - metadata_output_path: str | Sequence[str], + metadata_records: Sequence[Mapping[str, Any]], provider_configs: Mapping[str, BatchProviderConfig], output: TextIO = sys.stdout, ) -> List[BatchSubmissionResult]: - """Check and print statuses for batches referenced in a metadata receipt file.""" + """Check batch status for each referenced provider batch. + + Prints a per-batch line and a per-provider summary. Providers missing + from ``provider_configs`` are skipped rather than failing the run so + partial configurations still yield useful status output. + """ results: List[BatchSubmissionResult] = [] counter: Dict[str, Dict[str, int]] = {} - for provider, provider_batch_id in extract_unique_provider_batches(metadata_output_path): + for provider, provider_batch_id in extract_unique_provider_batches(metadata_records): if provider not in provider_configs: print(f"Skipping batch {provider_batch_id}: no config for provider '{provider}'.", file=output) provider_counts = counter.setdefault(provider, {}) @@ -96,24 +110,8 @@ def run_status_checker( return results -def _build_provider_configs_from_metadata( - metadata_output_path: str | Sequence[str], -) -> Dict[str, BatchProviderConfig]: - provider_names = {provider for provider, _ in extract_unique_provider_batches(metadata_output_path)} - configs: Dict[str, BatchProviderConfig] = {} - - if "openai" in provider_names: - api_key = os.environ.get("OPENAI_API_KEY", "").strip() - if not api_key: - raise ValueError( - "OPENAI_API_KEY is required to check statuses for provider 'openai'." - ) - configs["openai"] = OpenAIBatchConfig(credentials={"api_key": api_key}) - - return configs - - def _build_arg_parser() -> argparse.ArgumentParser: + """Build the CLI parser for the status-check entry point.""" parser = argparse.ArgumentParser(description="Check provider batch statuses from metadata receipts.") parser.add_argument( "--metadata-path", @@ -121,23 +119,37 @@ def _build_arg_parser() -> argparse.ArgumentParser: required=True, help="Path(s) to metadata JSONL receipt file(s). Supports multiple files.", ) + parser.add_argument( + "--config", + required=True, + help="Path to the YAML configuration file", + ) return parser def main(argv: Sequence[str] | None = None) -> int: + """CLI entry point that returns a process-style status code. + + Returns 0 on success or no batches found, and 1 on configuration or + provider resolution failures. + """ args = _build_arg_parser().parse_args(argv) - pairs = extract_unique_provider_batches(args.metadata_path) + from mmirage.config.utils import load_mmirage_config + + records = _read_metadata_records(args.metadata_path) + pairs = extract_unique_provider_batches(records) if not pairs: print(f"No provider batch IDs found in metadata file: {args.metadata_path}") return 0 try: - provider_configs = _build_provider_configs_from_metadata(args.metadata_path) + cfg = load_mmirage_config(args.config) + provider_configs = resolve_provider_configs(records, cfg) if not provider_configs: print("No supported provider configurations could be built from metadata.") return 1 run_status_checker( - metadata_output_path=args.metadata_path, + metadata_records=records, provider_configs=provider_configs, ) except Exception as exc: diff --git a/src/mmirage/core/process/processors/llm/llm_processor.py b/src/mmirage/core/process/processors/llm/llm_processor.py index 1a9910c..0a29600 100644 --- a/src/mmirage/core/process/processors/llm/llm_processor.py +++ b/src/mmirage/core/process/processors/llm/llm_processor.py @@ -14,10 +14,10 @@ from mmirage.core.process.base import BaseProcessor, ProcessorRegistry from mmirage.core.process.batch.orchestrator import BatchSubmissionOrchestrator +from mmirage.core.process.batch.provider_resolution import resolve_single_provider_config from mmirage.core.process.batch.registry import BatchAdapterFactory from mmirage.core.process.processors.llm.config import LLMOutputVar, SGLangLLMConfig from mmirage.core.process.variables import VariableEnvironment -from mmirage.config.openai_batch import OpenAIBatchConfig try: from typing import override # Python 3.12+ @@ -100,32 +100,26 @@ def _setup_batch_runtime(self) -> None: if not provider_cfg_raw.get("enabled", True): return - provider = str(provider_cfg_raw.get("provider", "openai")).strip().lower() - if provider != "openai": - raise ValueError( - f"Only provider='openai' is currently supported, got '{provider}'." - ) - - openai_cfg = OpenAIBatchConfig(**provider_cfg_raw) - self._batch_provider_config = openai_cfg - self._batch_adapter = BatchAdapterFactory.from_config(openai_cfg) + provider_cfg = resolve_single_provider_config(provider_cfg_raw) + self._batch_provider_config = provider_cfg + self._batch_adapter = BatchAdapterFactory.from_config(provider_cfg) run_id = uuid.uuid4().hex[:6] self._text_orchestrator = BatchSubmissionOrchestrator( adapter=self._batch_adapter, config=replace( - openai_cfg, + provider_cfg, metadata_output_path=self._with_metadata_suffix( - openai_cfg.metadata_output_path, "text", run_id + provider_cfg.metadata_output_path, "text", run_id ), ), ) self._multimodal_orchestrator = BatchSubmissionOrchestrator( adapter=self._batch_adapter, config=replace( - openai_cfg, + provider_cfg, metadata_output_path=self._with_metadata_suffix( - openai_cfg.metadata_output_path, "multimodal", run_id + provider_cfg.metadata_output_path, "multimodal", run_id ), ), ) @@ -378,7 +372,7 @@ def _batch_process_sample( jinja_template = jinja2.Template(output_var.prompt) requests: List[Dict[str, Any]] = [] source_indices: List[int] = [] - for local_i, global_i in enumerate(text_only_indices): + for global_i in text_only_indices: base_prompt = jinja_template.render(**batch[global_i].to_dict()) payload = { "messages": [ diff --git a/tests/test_batch_adapter_contracts.py b/tests/test_batch_adapter_contracts.py index 7935df9..deae995 100644 --- a/tests/test_batch_adapter_contracts.py +++ b/tests/test_batch_adapter_contracts.py @@ -1,7 +1,14 @@ +from dataclasses import dataclass + import pytest from mmirage.config.batch_provider import BatchProviderConfig +from mmirage.config.openai_batch import OpenAIBatchConfig from mmirage.core.process.batch.adapter import BatchSubmissionAdapter, BatchSubmissionResult +from mmirage.core.process.batch.provider_resolution import ( + BatchProviderConfigRegistry, + resolve_single_provider_config, +) from mmirage.core.process.batch.registry import BatchAdapterFactory, BatchAdapterRegistry @@ -80,6 +87,13 @@ def clear_batch_adapter_registry(): BatchAdapterRegistry.clear() +@pytest.fixture(autouse=True) +def clear_batch_provider_registry(): + BatchProviderConfigRegistry.clear() + yield + BatchProviderConfigRegistry.clear() + + def test_adapter_interface_is_abstract(): with pytest.raises(TypeError): BatchSubmissionAdapter() @@ -150,3 +164,38 @@ def test_factory_resolves_missing_credentials_from_environment(monkeypatch): assert isinstance(adapter, CredentialedTestAdapter) assert config.credentials["api_key"] == "from-env" + + +@dataclass +class UnitBatchConfig(BatchProviderConfig): + provider: str = "unit" + unit_setting: str = "default" + + def __post_init__(self) -> None: + super().__post_init__() + if not self.unit_setting.strip(): + raise ValueError("unit_setting must be a non-empty string") + + +def test_resolve_single_provider_config_defaults_to_openai(): + config = resolve_single_provider_config({}) + + assert isinstance(config, OpenAIBatchConfig) + assert config.provider == "openai" + + +def test_resolve_single_provider_config_resolves_custom_provider(): + BatchProviderConfigRegistry.register("unit", UnitBatchConfig) + + config = resolve_single_provider_config( + {"provider": "unit", "unit_setting": "custom"} + ) + + assert isinstance(config, UnitBatchConfig) + assert config.provider == "unit" + assert config.unit_setting == "custom" + + +def test_resolve_single_provider_config_raises_for_unknown_provider(): + with pytest.raises(ValueError, match="Unknown batch provider"): + resolve_single_provider_config({"provider": "not-registered"}) diff --git a/tests/test_batch_collector.py b/tests/test_batch_collector.py index 30b2f1c..d96e9da 100644 --- a/tests/test_batch_collector.py +++ b/tests/test_batch_collector.py @@ -1,10 +1,12 @@ import json +from types import SimpleNamespace +import pytest from mmirage.config.openai_batch import OpenAIBatchConfig def test_collect_and_merge_reconstructs_rows_deterministically(tmp_path, monkeypatch): - from mmirage.core.process.batch.collector import collect_and_merge + from mmirage.core.process.batch.collector import _read_metadata_records, collect_and_merge metadata_path = tmp_path / "receipts.jsonl" metadata_path.write_text( @@ -49,47 +51,17 @@ def retrieve_results(self, provider_batch_id, config): return [ { "custom_id": "c1", - "response": { - "body": { - "choices": [ - { - "message": { - "content": '{"question":"q2","answer":"a2"}' - } - } - ] - } - }, + "generated_text": '{"question":"q2","answer":"a2"}', }, { "custom_id": "c2", - "response": { - "body": { - "choices": [ - { - "message": { - "content": '{"question":"q0","answer":"a0"}' - } - } - ] - } - }, + "generated_text": '{"question":"q0","answer":"a0"}', }, ] return [ { "custom_id": "c3", - "response": { - "body": { - "choices": [ - { - "message": { - "content": '{"question":"q1","answer":"a1"}' - } - } - ] - } - }, + "generated_text": '{"question":"q1","answer":"a1"}', } ] @@ -100,8 +72,9 @@ def retrieve_results(self, provider_batch_id, config): ) provider_configs = {"openai": OpenAIBatchConfig(credentials={"api_key": "k"})} + records = _read_metadata_records(str(metadata_path)) rows = collect_and_merge( - metadata_output_path=str(metadata_path), + records=records, provider_configs=provider_configs, output_path=str(output_path), ) @@ -118,7 +91,7 @@ def retrieve_results(self, provider_batch_id, config): def test_collect_and_merge_raises_for_missing_provider_config(tmp_path): - from mmirage.core.process.batch.collector import collect_and_merge + from mmirage.core.process.batch.collector import _read_metadata_records, collect_and_merge metadata_path = tmp_path / "receipts.jsonl" metadata_path.write_text( @@ -134,8 +107,9 @@ def test_collect_and_merge_raises_for_missing_provider_config(tmp_path): ) try: + records = _read_metadata_records(str(metadata_path)) collect_and_merge( - metadata_output_path=str(metadata_path), + records=records, provider_configs={}, output_path=str(tmp_path / "out.jsonl"), ) @@ -145,7 +119,7 @@ def test_collect_and_merge_raises_for_missing_provider_config(tmp_path): def test_collect_and_merge_outputs_caption_for_plain_text_content(tmp_path, monkeypatch): - from mmirage.core.process.batch.collector import collect_and_merge + from mmirage.core.process.batch.collector import _read_metadata_records, collect_and_merge metadata_path = tmp_path / "receipts.jsonl" metadata_path.write_text( @@ -167,17 +141,7 @@ def retrieve_results(self, provider_batch_id, config): return [ { "custom_id": "img_1", - "response": { - "body": { - "choices": [ - { - "message": { - "content": "A black cat sitting on a sofa." - } - } - ] - } - }, + "generated_text": "A black cat sitting on a sofa.", } ] @@ -186,8 +150,9 @@ def retrieve_results(self, provider_batch_id, config): lambda config: FakeAdapter(), ) + records = _read_metadata_records(str(metadata_path)) rows = collect_and_merge( - metadata_output_path=str(metadata_path), + records=records, provider_configs={"openai": OpenAIBatchConfig(credentials={"api_key": "k"})}, output_path=str(output_path), ) @@ -199,3 +164,210 @@ def retrieve_results(self, provider_batch_id, config): "caption": "A black cat sitting on a sofa.", } ] + + +def test_collector_main_uses_config_and_records(tmp_path, monkeypatch): + from mmirage.core.process.batch import collector + + metadata_path = tmp_path / "receipts.jsonl" + metadata_path.write_text( + json.dumps( + { + "provider": "openai", + "provider_batch_id": "batch_main", + "custom_id_to_source_index": {"c1": 0}, + } + ) + + "\n", + encoding="utf-8", + ) + output_path = tmp_path / "out.jsonl" + config_path = tmp_path / "dummy.yaml" + config_path.write_text("processors: []\n", encoding="utf-8") + + cfg = SimpleNamespace(processors=[SimpleNamespace(batch_provider={"provider": "openai"})]) + captured = {} + + monkeypatch.setattr("mmirage.config.utils.load_mmirage_config", lambda path: cfg) + + def _fake_collect_and_merge(records, provider_configs, output_path_arg): + captured["records"] = records + captured["provider_configs"] = provider_configs + captured["output_path"] = output_path_arg + return [{"source_index": 0, "custom_id": "c1", "caption": "ok"}] + + monkeypatch.setattr( + "mmirage.core.process.batch.collector.collect_and_merge", + _fake_collect_and_merge, + ) + + rc = collector.main( + [ + "--metadata-path", + str(metadata_path), + "--output-path", + str(output_path), + "--config", + str(config_path), + ] + ) + + assert rc == 0 + assert len(captured["records"]) == 1 + assert captured["records"][0]["provider"] == "openai" + assert "openai" in captured["provider_configs"] + assert captured["output_path"] == str(output_path) + + +def test_collector_main_raises_when_metadata_provider_missing_in_config(tmp_path, monkeypatch): + from mmirage.core.process.batch import collector + + metadata_path = tmp_path / "receipts.jsonl" + metadata_path.write_text( + json.dumps( + { + "provider": "mistral", + "provider_batch_id": "batch_mistral", + "custom_id_to_source_index": {"m1": 0}, + } + ) + + "\n", + encoding="utf-8", + ) + output_path = tmp_path / "out.jsonl" + config_path = tmp_path / "dummy.yaml" + config_path.write_text("processors: []\n", encoding="utf-8") + + # Config intentionally only defines openai, not mistral. + cfg = SimpleNamespace(processors=[SimpleNamespace(batch_provider={"provider": "openai"})]) + monkeypatch.setattr("mmirage.config.utils.load_mmirage_config", lambda path: cfg) + + with pytest.raises(ValueError, match="missing from YAML batch_provider config"): + collector.main( + [ + "--metadata-path", + str(metadata_path), + "--output-path", + str(output_path), + "--config", + str(config_path), + ] + ) + + +def test_collect_and_merge_routes_multiple_providers(tmp_path, monkeypatch): + from mmirage.core.process.batch.collector import _read_metadata_records, collect_and_merge + from mmirage.core.process.batch.provider_resolution import resolve_provider_configs + + metadata_path = tmp_path / "receipts.jsonl" + metadata_path.write_text( + "\n".join( + [ + json.dumps( + { + "provider": "openai", + "provider_batch_id": "batch_openai", + "custom_id_to_source_index": {"o1": 1}, + } + ), + json.dumps( + { + "provider": "unit", + "provider_batch_id": "batch_unit", + "custom_id_to_source_index": {"u1": 0}, + } + ), + ] + ) + + "\n", + encoding="utf-8", + ) + + cfg = SimpleNamespace( + processors=[ + SimpleNamespace( + batch_provider={"provider": "openai", "credentials": {"api_key": "k"}} + ), + SimpleNamespace(batch_provider={"provider": "unit"}), + ] + ) + + records = _read_metadata_records(str(metadata_path)) + provider_configs = resolve_provider_configs(records, cfg) + + class OpenAIAdapter: + def __init__(self): + self.calls = [] + + def retrieve_results(self, provider_batch_id, config): + self.calls.append((provider_batch_id, config.provider)) + return [{"custom_id": "o1", "generated_text": "openai"}] + + class UnitAdapter: + def __init__(self): + self.calls = [] + + def retrieve_results(self, provider_batch_id, config): + self.calls.append((provider_batch_id, config.provider)) + return [{"custom_id": "u1", "generated_text": "unit"}] + + adapters = { + "openai": OpenAIAdapter(), + "unit": UnitAdapter(), + } + + monkeypatch.setattr( + "mmirage.core.process.batch.collector.BatchAdapterFactory.from_config", + lambda config: adapters[config.provider], + ) + + output_path = tmp_path / "merged.jsonl" + rows = collect_and_merge( + records=records, + provider_configs=provider_configs, + output_path=str(output_path), + ) + + assert [row["custom_id"] for row in rows] == ["u1", "o1"] + assert [row["caption"] for row in rows] == ["unit", "openai"] + assert ("batch_openai", "openai") in adapters["openai"].calls + assert ("batch_unit", "unit") in adapters["unit"].calls + + +def test_collector_main_raises_for_invalid_batch_provider_config(tmp_path, monkeypatch): + from mmirage.core.process.batch import collector + + metadata_path = tmp_path / "receipts.jsonl" + metadata_path.write_text( + json.dumps( + { + "provider": "openai", + "provider_batch_id": "batch_1", + "custom_id_to_source_index": {"c1": 0}, + } + ) + + "\n", + encoding="utf-8", + ) + output_path = tmp_path / "out.jsonl" + config_path = tmp_path / "dummy.yaml" + config_path.write_text("processors: []\n", encoding="utf-8") + + cfg = SimpleNamespace( + processors=[ + SimpleNamespace(batch_provider={"provider": "openai", "batch_endpoint": "v1"}) + ] + ) + monkeypatch.setattr("mmirage.config.utils.load_mmirage_config", lambda path: cfg) + + with pytest.raises(ValueError, match="batch_endpoint must start with '/'"): + collector.main( + [ + "--metadata-path", + str(metadata_path), + "--output-path", + str(output_path), + "--config", + str(config_path), + ] + ) diff --git a/tests/test_batch_orchestrator.py b/tests/test_batch_orchestrator.py index f52afdf..5e9013a 100644 --- a/tests/test_batch_orchestrator.py +++ b/tests/test_batch_orchestrator.py @@ -1,7 +1,14 @@ +from dataclasses import dataclass import json +import pytest + from mmirage.config.batch_provider import BatchProviderConfig from mmirage.core.process.batch.adapter import BatchSubmissionAdapter, BatchSubmissionResult +from mmirage.core.process.batch.provider_resolution import BatchProviderConfigRegistry +from mmirage.core.process.batch.registry import BatchAdapterRegistry +from mmirage.core.process.base import ProcessorRegistry +from mmirage.core.process.processors.llm.config import SGLangLLMConfig, SGLangServerArgs class RecordingAdapter(BatchSubmissionAdapter): @@ -51,6 +58,15 @@ def retrieve_results(self, provider_batch_id, config): return [] +@pytest.fixture(autouse=True) +def clear_batch_registries(): + BatchProviderConfigRegistry.clear() + BatchAdapterRegistry.clear() + yield + BatchProviderConfigRegistry.clear() + BatchAdapterRegistry.clear() + + def test_orchestrator_buffers_across_iterations_and_avoids_tiny_midstream_flush(tmp_path): from mmirage.core.process.batch.orchestrator import BatchSubmissionOrchestrator @@ -121,3 +137,76 @@ def test_orchestrator_writes_provider_neutral_metadata_with_flush_reason(tmp_pat assert second["flush_reason"] == "finalize" assert second["custom_id_to_source_index"] == {"x2": 1} assert second["provider_batch_id"].startswith("batch-chunk-") + + +@dataclass +class UnitBatchConfig(BatchProviderConfig): + provider: str = "unit" + unit_setting: str = "default" + + def __post_init__(self) -> None: + super().__post_init__() + if not self.unit_setting.strip(): + raise ValueError("unit_setting must be a non-empty string") + + +def test_llm_processor_initializes_with_custom_provider(tmp_path): + BatchProviderConfigRegistry.register("unit", UnitBatchConfig) + BatchAdapterRegistry.register("unit", RecordingAdapter) + + config = SGLangLLMConfig( + type="llm", + server_args=SGLangServerArgs(model_path="dummy-model"), + batch_provider={ + "provider": "unit", + "unit_setting": "custom", + "metadata_output_path": str(tmp_path / "metadata.jsonl"), + }, + ) + + processor_cls = ProcessorRegistry.get_processor("llm") + processor = processor_cls(config) + + assert processor.batch_mode_enabled is True + assert isinstance(processor._batch_provider_config, UnitBatchConfig) + assert processor._batch_provider_config.provider == "unit" + assert processor._batch_provider_config.unit_setting == "custom" + assert isinstance(processor._batch_adapter, RecordingAdapter) + + +def test_llm_processor_skips_batch_setup_when_disabled(monkeypatch): + class FakeEngine: + def __init__(self, **_kwargs): + return None + + def generate(self, **_kwargs): + raise AssertionError("Synchronous generation should not run in this test") + + def shutdown(self): + return None + + class FakeTokenizer: + def apply_chat_template(self, *args, **kwargs): + return "" + + monkeypatch.setattr( + "mmirage.core.process.processors.llm.llm_processor.sgl.Engine", + FakeEngine, + ) + monkeypatch.setattr( + "mmirage.core.process.processors.llm.llm_processor.AutoTokenizer.from_pretrained", + lambda *args, **kwargs: FakeTokenizer(), + ) + + config = SGLangLLMConfig( + type="llm", + server_args=SGLangServerArgs(model_path="dummy-model"), + batch_provider={"enabled": False}, + ) + + processor_cls = ProcessorRegistry.get_processor("llm") + processor = processor_cls(config) + + assert processor.batch_mode_enabled is False + assert processor._batch_adapter is None + assert processor._batch_provider_config is None diff --git a/tests/test_batch_status_checker.py b/tests/test_batch_status_checker.py index 80ffbbd..c0ce83b 100644 --- a/tests/test_batch_status_checker.py +++ b/tests/test_batch_status_checker.py @@ -1,11 +1,15 @@ from io import StringIO +from types import SimpleNamespace from mmirage.config.openai_batch import OpenAIBatchConfig from mmirage.core.process.batch.adapter import BatchSubmissionResult def test_extract_unique_provider_batches_handles_malformed_and_duplicates(tmp_path): - from mmirage.core.process.batch.status_checker import extract_unique_provider_batches + from mmirage.core.process.batch.status_checker import ( + _read_metadata_records, + extract_unique_provider_batches, + ) metadata_path = tmp_path / "receipts.jsonl" metadata_path.write_text( @@ -22,13 +26,13 @@ def test_extract_unique_provider_batches_handles_malformed_and_duplicates(tmp_pa encoding="utf-8", ) - pairs = extract_unique_provider_batches(str(metadata_path)) + pairs = extract_unique_provider_batches(_read_metadata_records(str(metadata_path))) assert pairs == [("openai", "batch_1"), ("openai", "batch_2")] def test_run_status_checker_prints_summary_with_factory_dispatch(tmp_path, monkeypatch): - from mmirage.core.process.batch.status_checker import run_status_checker + from mmirage.core.process.batch.status_checker import _read_metadata_records, run_status_checker metadata_path = tmp_path / "receipts.jsonl" metadata_path.write_text( @@ -67,9 +71,10 @@ def check_batch_status(self, provider_batch_id, config): config_map = { "openai": OpenAIBatchConfig(credentials={"api_key": "k"}), } + records = _read_metadata_records(str(metadata_path)) results = run_status_checker( - metadata_output_path=str(metadata_path), + metadata_records=records, provider_configs=config_map, output=output, ) @@ -86,3 +91,110 @@ def check_batch_status(self, provider_batch_id, config): printed = output.getvalue() assert "Batch batch_1 (openai): completed" in printed assert "Batch batch_2 (openai): in_progress" in printed + + +def test_status_checker_main_uses_config_and_runs(tmp_path, monkeypatch): + from mmirage.core.process.batch import status_checker + + metadata_path = tmp_path / "receipts.jsonl" + metadata_path.write_text( + '{"provider":"openai","provider_batch_id":"batch_1"}\n', + encoding="utf-8", + ) + config_path = tmp_path / "dummy.yaml" + config_path.write_text("processors: []\n", encoding="utf-8") + + cfg = SimpleNamespace(processors=[SimpleNamespace(batch_provider={"provider": "openai"})]) + monkeypatch.setattr("mmirage.config.utils.load_mmirage_config", lambda path: cfg) + + called = {} + + def _fake_run_status_checker(metadata_records, provider_configs, output=None): + called["metadata_records"] = metadata_records + called["provider_configs"] = provider_configs + return [] + + monkeypatch.setattr( + "mmirage.core.process.batch.status_checker.run_status_checker", + _fake_run_status_checker, + ) + + rc = status_checker.main( + [ + "--metadata-path", + str(metadata_path), + "--config", + str(config_path), + ] + ) + + assert rc == 0 + assert len(called["metadata_records"]) == 1 + assert called["metadata_records"][0]["provider"] == "openai" + assert "openai" in called["provider_configs"] + + +def test_status_checker_main_returns_error_when_metadata_provider_missing_in_config( + tmp_path, monkeypatch, capsys +): + from mmirage.core.process.batch import status_checker + + metadata_path = tmp_path / "receipts.jsonl" + metadata_path.write_text( + '{"provider":"mistral","provider_batch_id":"batch_m1"}\n', + encoding="utf-8", + ) + config_path = tmp_path / "dummy.yaml" + config_path.write_text("processors: []\n", encoding="utf-8") + + # Config intentionally only defines openai, not mistral. + cfg = SimpleNamespace(processors=[SimpleNamespace(batch_provider={"provider": "openai"})]) + monkeypatch.setattr("mmirage.config.utils.load_mmirage_config", lambda path: cfg) + + rc = status_checker.main( + [ + "--metadata-path", + str(metadata_path), + "--config", + str(config_path), + ] + ) + + assert rc == 1 + stderr = capsys.readouterr().err + assert "Status checker failed:" in stderr + assert "missing from YAML batch_provider config" in stderr + + +def test_status_checker_main_returns_error_when_credentials_missing( + tmp_path, monkeypatch, capsys +): + from mmirage.core.process.batch import status_checker + + metadata_path = tmp_path / "receipts.jsonl" + metadata_path.write_text( + '{"provider":"openai","provider_batch_id":"batch_1"}\n', + encoding="utf-8", + ) + config_path = tmp_path / "dummy.yaml" + config_path.write_text("processors: []\n", encoding="utf-8") + + cfg = SimpleNamespace( + processors=[SimpleNamespace(batch_provider={"provider": "openai", "credentials": {}})] + ) + monkeypatch.setattr("mmirage.config.utils.load_mmirage_config", lambda path: cfg) + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + + rc = status_checker.main( + [ + "--metadata-path", + str(metadata_path), + "--config", + str(config_path), + ] + ) + + assert rc == 1 + stderr = capsys.readouterr().err + assert "Status checker failed:" in stderr + assert "Missing credentials for provider 'openai'" in stderr diff --git a/tests/test_integration_receiver.py b/tests/test_integration_receiver.py index 0deb13f..7e44692 100644 --- a/tests/test_integration_receiver.py +++ b/tests/test_integration_receiver.py @@ -4,7 +4,7 @@ def test_integration_receiver_reads_receipt_and_writes_merged_output(tmp_path, monkeypatch): - from mmirage.core.process.batch.collector import collect_and_merge + from mmirage.core.process.batch.collector import _read_metadata_records, collect_and_merge metadata_path = tmp_path / "receipt.text.jsonl" metadata_path.write_text( @@ -37,47 +37,17 @@ def retrieve_results(self, provider_batch_id, config): return [ { "custom_id": "id_a", - "response": { - "body": { - "choices": [ - { - "message": { - "content": '{"question":"What is id_a?","answer":"one"}' - } - } - ] - } - }, + "generated_text": '{"question":"What is id_a?","answer":"one"}', }, { "custom_id": "id_b", - "response": { - "body": { - "choices": [ - { - "message": { - "content": '{"question":"What is id_b?","answer":"zero"}' - } - } - ] - } - }, + "generated_text": '{"question":"What is id_b?","answer":"zero"}', }, ] return [ { "custom_id": "id_c", - "response": { - "body": { - "choices": [ - { - "message": { - "content": '{"question":"What is id_c?","answer":"two"}' - } - } - ] - } - }, + "generated_text": '{"question":"What is id_c?","answer":"two"}', } ] @@ -86,8 +56,9 @@ def retrieve_results(self, provider_batch_id, config): lambda config: FakeAdapter(), ) + records = _read_metadata_records(str(metadata_path)) rows = collect_and_merge( - metadata_output_path=str(metadata_path), + records=records, provider_configs={"openai": OpenAIBatchConfig(credentials={"api_key": "test"})}, output_path=str(output_path), ) From 7ff6e1088be3684029e09d406dc3df8b13776cf0 Mon Sep 17 00:00:00 2001 From: legstar67 Date: Sat, 2 May 2026 15:07:53 +0200 Subject: [PATCH 26/45] abstract method added in BaseProcessor --- src/mmirage/core/process/base.py | 4 ++++ src/mmirage/core/process/mapper.py | 4 +--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/mmirage/core/process/base.py b/src/mmirage/core/process/base.py index f374e12..95a26fa 100644 --- a/src/mmirage/core/process/base.py +++ b/src/mmirage/core/process/base.py @@ -62,6 +62,10 @@ def batch_process_sample( NotImplementedError: If not implemented by subclass. """ raise NotImplementedError() + + def finalize(self) -> None: + """Optional lifecycle hook; override when a processor buffers state.""" + pass class ProcessorRegistry: diff --git a/src/mmirage/core/process/mapper.py b/src/mmirage/core/process/mapper.py index 69573a3..cf81455 100644 --- a/src/mmirage/core/process/mapper.py +++ b/src/mmirage/core/process/mapper.py @@ -107,6 +107,4 @@ def rewrite_batch( def finalize_processors(self) -> None: """Finalize processors that expose a finalize lifecycle hook.""" for processor in self.processors.values(): - finalize = getattr(processor, "finalize", None) - if callable(finalize): - finalize() + processor.finalize() From c9c43e95956e056341eefccf69dc89903d3e1230 Mon Sep 17 00:00:00 2001 From: legstar67 Date: Sat, 2 May 2026 17:17:45 +0200 Subject: [PATCH 27/45] small update on the completion_window of openai --- src/mmirage/config/openai_batch.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/mmirage/config/openai_batch.py b/src/mmirage/config/openai_batch.py index 65166af..8f233e1 100644 --- a/src/mmirage/config/openai_batch.py +++ b/src/mmirage/config/openai_batch.py @@ -22,12 +22,15 @@ class OpenAIBatchConfig(BatchProviderConfig): provider: str = "openai" model: str = "gpt-4.1-mini" batch_endpoint: str = "/v1/chat/completions" - completion_window: Literal["24h"] = "24h" + completion_window: str = "24h" base_url: Optional[str] = None metadata: Dict[str, str] = field(default_factory=dict) def __post_init__(self) -> None: super().__post_init__() + allowed_windows = {"24h"} + if self.completion_window not in allowed_windows: + raise ValueError(f"completion_window must be one of {allowed_windows}") if not self.model.strip(): raise ValueError("model must be a non-empty string") From 69beb2fa7fb4a322d6e9c42028edc98a25b92690 Mon Sep 17 00:00:00 2001 From: legstar67 Date: Sat, 2 May 2026 23:01:55 +0200 Subject: [PATCH 28/45] adjust adapter.py thanks to the comments of @fabnemEPFL --- src/mmirage/core/process/batch/adapter.py | 24 +------------------ .../core/process/batch/openai_adapter.py | 11 +-------- .../core/process/batch/orchestrator.py | 1 - tests/mock_data_vision/data.jsonl | 3 +-- tests/test_batch_adapter_contracts.py | 3 --- tests/test_batch_chunking.py | 2 -- tests/test_batch_orchestrator.py | 11 --------- tests/test_batch_status_checker.py | 1 - tests/test_openai_batch_adapter.py | 2 -- 9 files changed, 3 insertions(+), 55 deletions(-) diff --git a/src/mmirage/core/process/batch/adapter.py b/src/mmirage/core/process/batch/adapter.py index 1a71805..ff0126e 100644 --- a/src/mmirage/core/process/batch/adapter.py +++ b/src/mmirage/core/process/batch/adapter.py @@ -18,13 +18,11 @@ class BatchSubmissionResult: Attributes: provider_batch_id: Provider-side identifier for the submitted job/batch. status: Provider submission status normalized to a short string. - submitted_request_count: Number of requests accepted in this submission. raw_response: Original provider response payload for traceability. """ provider_batch_id: str status: str - submitted_request_count: int raw_response: Mapping[str, Any] = field(default_factory=dict) @@ -37,31 +35,11 @@ class BatchSubmissionAdapter(abc.ABC): required_credentials: Tuple[str, ...] = tuple() - @property - @abc.abstractmethod - def adapter_name(self) -> str: - """Return a stable adapter identity string. - - The identity should remain stable across code changes that preserve - behavior and should change only when semantics diverge. - """ - raise NotImplementedError() - - @property - @abc.abstractmethod - def adapter_version(self) -> str: - """Return the adapter implementation version. - - This value is persisted in metadata artifacts to support auditing and - replay diagnostics across code revisions. - """ - raise NotImplementedError() - @abc.abstractmethod def build_request( self, custom_id: str, - payload: Mapping[str, Any], + payload: Dict[str, Any], config: BatchProviderConfig, ) -> Mapping[str, Any]: """Build a single provider-ready request object. diff --git a/src/mmirage/core/process/batch/openai_adapter.py b/src/mmirage/core/process/batch/openai_adapter.py index 8ce32c9..3bfec7b 100644 --- a/src/mmirage/core/process/batch/openai_adapter.py +++ b/src/mmirage/core/process/batch/openai_adapter.py @@ -20,18 +20,10 @@ class OpenAIBatchAdapter(BatchSubmissionAdapter): required_credentials = ("api_key",) - @property - def adapter_name(self) -> str: - return "openai-batch-adapter" - - @property - def adapter_version(self) -> str: - return "1.0.0" - def build_request( self, custom_id: str, - payload: Mapping[str, Any], + payload: Dict[str, Any], config: BatchProviderConfig, ) -> Mapping[str, Any]: openai_config = self._as_openai_config(config) @@ -206,7 +198,6 @@ def parse_submission_result( return BatchSubmissionResult( provider_batch_id=batch_id, status=status, - submitted_request_count=request_count, raw_response=dict(raw_result), ) diff --git a/src/mmirage/core/process/batch/orchestrator.py b/src/mmirage/core/process/batch/orchestrator.py index c2a4f88..6d1dbfd 100644 --- a/src/mmirage/core/process/batch/orchestrator.py +++ b/src/mmirage/core/process/batch/orchestrator.py @@ -167,7 +167,6 @@ def _persist_metadata( metadata_record: Dict[str, Any] = { "provider": self.config.provider, - "adapter_version": self.adapter.adapter_version, "chunk_id": chunk_id, "provider_batch_id": parsed_result.provider_batch_id, "status": parsed_result.status, diff --git a/tests/mock_data_vision/data.jsonl b/tests/mock_data_vision/data.jsonl index 0b29b01..7f67b97 100644 --- a/tests/mock_data_vision/data.jsonl +++ b/tests/mock_data_vision/data.jsonl @@ -8,5 +8,4 @@ {"image": "beach.jpg"} {"image": "cat.jpg"} {"image": "dog.jpg"} -{"image": "mountain.jpg"} -{"image": "beach.jpg"} + diff --git a/tests/test_batch_adapter_contracts.py b/tests/test_batch_adapter_contracts.py index deae995..70887ab 100644 --- a/tests/test_batch_adapter_contracts.py +++ b/tests/test_batch_adapter_contracts.py @@ -41,7 +41,6 @@ def parse_submission_result(self, raw_result, request_count): return BatchSubmissionResult( provider_batch_id=str(raw_result["batch_id"]), status=str(raw_result["status"]), - submitted_request_count=request_count, raw_response=raw_result, ) @@ -49,7 +48,6 @@ def check_batch_status(self, provider_batch_id, config): return BatchSubmissionResult( provider_batch_id=provider_batch_id, status="submitted", - submitted_request_count=0, raw_response={"id": provider_batch_id, "status": "submitted"}, ) @@ -119,7 +117,6 @@ def test_complete_adapter_is_interface_compliant(): assert parsed.provider_batch_id == "unit-chunk-1" assert parsed.status == "submitted" - assert parsed.submitted_request_count == 1 def test_factory_resolves_registered_provider(): diff --git a/tests/test_batch_chunking.py b/tests/test_batch_chunking.py index 753cd9a..6451362 100644 --- a/tests/test_batch_chunking.py +++ b/tests/test_batch_chunking.py @@ -31,7 +31,6 @@ def parse_submission_result(self, raw_result, request_count): return BatchSubmissionResult( provider_batch_id=str(raw_result["id"]), status=str(raw_result["status"]), - submitted_request_count=request_count, raw_response=raw_result, ) @@ -39,7 +38,6 @@ def check_batch_status(self, provider_batch_id, config): return BatchSubmissionResult( provider_batch_id=provider_batch_id, status="submitted", - submitted_request_count=0, raw_response={"id": provider_batch_id, "status": "submitted"}, ) diff --git a/tests/test_batch_orchestrator.py b/tests/test_batch_orchestrator.py index 5e9013a..2629e27 100644 --- a/tests/test_batch_orchestrator.py +++ b/tests/test_batch_orchestrator.py @@ -15,14 +15,6 @@ class RecordingAdapter(BatchSubmissionAdapter): def __init__(self) -> None: self.submissions = [] - @property - def adapter_name(self) -> str: - return "recording-adapter" - - @property - def adapter_version(self) -> str: - return "1.2.3" - def build_request(self, custom_id, payload, config): return {"custom_id": custom_id, **dict(payload)} @@ -42,7 +34,6 @@ def parse_submission_result(self, raw_result, request_count): return BatchSubmissionResult( provider_batch_id=str(raw_result["id"]), status=str(raw_result["status"]), - submitted_request_count=request_count, raw_response=raw_result, ) @@ -50,7 +41,6 @@ def check_batch_status(self, provider_batch_id, config): return BatchSubmissionResult( provider_batch_id=provider_batch_id, status="submitted", - submitted_request_count=0, raw_response={"id": provider_batch_id, "status": "submitted"}, ) @@ -129,7 +119,6 @@ def test_orchestrator_writes_provider_neutral_metadata_with_flush_reason(tmp_pat second = json.loads(lines[1]) assert first["provider"] == "unit" - assert first["adapter_version"] == "1.2.3" assert first["flush_reason"] == "full_chunk" assert first["custom_id_to_source_index"] == {"x1": 0} assert isinstance(first["request_hash"], str) and len(first["request_hash"]) == 64 diff --git a/tests/test_batch_status_checker.py b/tests/test_batch_status_checker.py index c0ce83b..0aa52c3 100644 --- a/tests/test_batch_status_checker.py +++ b/tests/test_batch_status_checker.py @@ -56,7 +56,6 @@ def check_batch_status(self, provider_batch_id, config): return BatchSubmissionResult( provider_batch_id=provider_batch_id, status=status, - submitted_request_count=0, raw_response={"id": provider_batch_id, "status": status}, ) diff --git a/tests/test_openai_batch_adapter.py b/tests/test_openai_batch_adapter.py index a11739d..1aa05b2 100644 --- a/tests/test_openai_batch_adapter.py +++ b/tests/test_openai_batch_adapter.py @@ -199,7 +199,6 @@ def test_openai_parse_submission_result_normalizes_payload(): assert isinstance(result, BatchSubmissionResult) assert result.provider_batch_id == "batch_123" assert result.status == "in_progress" - assert result.submitted_request_count == 4 assert result.raw_response == raw @@ -255,7 +254,6 @@ def __init__(self, **kwargs): assert isinstance(result, BatchSubmissionResult) assert result.provider_batch_id == "batch_456" assert result.status == "completed" - assert result.submitted_request_count == 0 def test_openai_check_batch_status_falls_back_to_env_api_key(monkeypatch): From ddca511dd79a57974f84c9bd0a05b243d1a0c7d1 Mon Sep 17 00:00:00 2001 From: legstar67 Date: Tue, 5 May 2026 22:24:58 +0200 Subject: [PATCH 29/45] make --metadata_path optional as already in the config file + tests --- src/mmirage/core/process/batch/collector.py | 27 ++++++-- .../core/process/batch/status_checker.py | 38 +++++++++--- tests/test_batch_collector.py | 62 +++++++++++++++++++ tests/test_batch_status_checker.py | 45 ++++++++++++++ 4 files changed, 159 insertions(+), 13 deletions(-) diff --git a/src/mmirage/core/process/batch/collector.py b/src/mmirage/core/process/batch/collector.py index 0f42303..f388f09 100644 --- a/src/mmirage/core/process/batch/collector.py +++ b/src/mmirage/core/process/batch/collector.py @@ -14,7 +14,10 @@ from typing import Any, Dict, List, Mapping, MutableMapping, Sequence, Tuple from mmirage.config.batch_provider import BatchProviderConfig -from mmirage.core.process.batch.provider_resolution import resolve_provider_configs +from mmirage.core.process.batch.provider_resolution import ( + build_all_provider_configs, + resolve_provider_configs, +) from mmirage.core.process.batch.registry import BatchAdapterFactory @@ -193,8 +196,10 @@ def _build_arg_parser() -> argparse.ArgumentParser: parser.add_argument( "--metadata-path", nargs="+", - required=True, - help="Path(s) to metadata JSONL receipt file(s). Supports multiple files.", + help=( + "Path(s) to metadata JSONL receipt file(s). Supports multiple files. " + "When omitted, uses metadata_output_path from the config batch_provider blocks." + ), ) parser.add_argument( "--output-path", @@ -219,8 +224,22 @@ def main(argv: Sequence[str] | None = None) -> int: args = _build_arg_parser().parse_args(argv) from mmirage.config.utils import load_mmirage_config - records = _read_metadata_records(args.metadata_path) cfg = load_mmirage_config(args.config) + if args.metadata_path: + metadata_paths = args.metadata_path + else: + all_provider_configs = build_all_provider_configs(cfg) + metadata_paths = [ + config.metadata_output_path + for config in all_provider_configs.values() + if config.metadata_output_path + ] + metadata_paths = list(dict.fromkeys(metadata_paths)) + + if not metadata_paths: + raise ValueError("No metadata paths provided and none found in config batch_provider blocks.") + + records = _read_metadata_records(metadata_paths) provider_configs = resolve_provider_configs(records, cfg) rows = collect_and_merge(records, provider_configs, args.output_path) diff --git a/src/mmirage/core/process/batch/status_checker.py b/src/mmirage/core/process/batch/status_checker.py index 6b90760..2045f98 100644 --- a/src/mmirage/core/process/batch/status_checker.py +++ b/src/mmirage/core/process/batch/status_checker.py @@ -13,7 +13,10 @@ from mmirage.config.batch_provider import BatchProviderConfig from mmirage.core.process.batch.adapter import BatchSubmissionResult -from mmirage.core.process.batch.provider_resolution import resolve_provider_configs +from mmirage.core.process.batch.provider_resolution import ( + build_all_provider_configs, + resolve_provider_configs, +) from mmirage.core.process.batch.registry import BatchAdapterFactory @@ -116,8 +119,10 @@ def _build_arg_parser() -> argparse.ArgumentParser: parser.add_argument( "--metadata-path", nargs="+", - required=True, - help="Path(s) to metadata JSONL receipt file(s). Supports multiple files.", + help=( + "Path(s) to metadata JSONL receipt file(s). Supports multiple files. " + "When omitted, uses metadata_output_path from the config batch_provider blocks." + ), ) parser.add_argument( "--config", @@ -136,14 +141,29 @@ def main(argv: Sequence[str] | None = None) -> int: args = _build_arg_parser().parse_args(argv) from mmirage.config.utils import load_mmirage_config - records = _read_metadata_records(args.metadata_path) - pairs = extract_unique_provider_batches(records) - if not pairs: - print(f"No provider batch IDs found in metadata file: {args.metadata_path}") - return 0 - try: cfg = load_mmirage_config(args.config) + if args.metadata_path: + metadata_paths = args.metadata_path + else: + all_provider_configs = build_all_provider_configs(cfg) + metadata_paths = [ + config.metadata_output_path + for config in all_provider_configs.values() + if config.metadata_output_path + ] + metadata_paths = list(dict.fromkeys(metadata_paths)) + + if not metadata_paths: + print("No metadata paths provided and none found in config batch_provider blocks.") + return 1 + + records = _read_metadata_records(metadata_paths) + pairs = extract_unique_provider_batches(records) + if not pairs: + print(f"No provider batch IDs found in metadata file(s): {metadata_paths}") + return 0 + provider_configs = resolve_provider_configs(records, cfg) if not provider_configs: print("No supported provider configurations could be built from metadata.") diff --git a/tests/test_batch_collector.py b/tests/test_batch_collector.py index d96e9da..5444068 100644 --- a/tests/test_batch_collector.py +++ b/tests/test_batch_collector.py @@ -219,6 +219,68 @@ def _fake_collect_and_merge(records, provider_configs, output_path_arg): assert captured["output_path"] == str(output_path) +def test_collector_main_uses_config_metadata_path_when_missing_cli_arg( + tmp_path, monkeypatch +): + from mmirage.core.process.batch import collector + + metadata_path = tmp_path / "receipts.jsonl" + metadata_path.write_text( + json.dumps( + { + "provider": "openai", + "provider_batch_id": "batch_main", + "custom_id_to_source_index": {"c1": 0}, + } + ) + + "\n", + encoding="utf-8", + ) + output_path = tmp_path / "out.jsonl" + config_path = tmp_path / "dummy.yaml" + config_path.write_text("processors: []\n", encoding="utf-8") + + cfg = SimpleNamespace( + processors=[ + SimpleNamespace( + batch_provider={ + "provider": "openai", + "metadata_output_path": str(metadata_path), + } + ) + ] + ) + captured = {} + + monkeypatch.setattr("mmirage.config.utils.load_mmirage_config", lambda path: cfg) + + def _fake_collect_and_merge(records, provider_configs, output_path_arg): + captured["records"] = records + captured["provider_configs"] = provider_configs + captured["output_path"] = output_path_arg + return [{"source_index": 0, "custom_id": "c1", "caption": "ok"}] + + monkeypatch.setattr( + "mmirage.core.process.batch.collector.collect_and_merge", + _fake_collect_and_merge, + ) + + rc = collector.main( + [ + "--output-path", + str(output_path), + "--config", + str(config_path), + ] + ) + + assert rc == 0 + assert len(captured["records"]) == 1 + assert captured["records"][0]["provider"] == "openai" + assert "openai" in captured["provider_configs"] + assert captured["output_path"] == str(output_path) + + def test_collector_main_raises_when_metadata_provider_missing_in_config(tmp_path, monkeypatch): from mmirage.core.process.batch import collector diff --git a/tests/test_batch_status_checker.py b/tests/test_batch_status_checker.py index 0aa52c3..811f947 100644 --- a/tests/test_batch_status_checker.py +++ b/tests/test_batch_status_checker.py @@ -197,3 +197,48 @@ def test_status_checker_main_returns_error_when_credentials_missing( stderr = capsys.readouterr().err assert "Status checker failed:" in stderr assert "Missing credentials for provider 'openai'" in stderr + + +def test_status_checker_main_uses_config_metadata_path_when_missing_cli_arg( + tmp_path, monkeypatch +): + from mmirage.core.process.batch import status_checker + + metadata_path = tmp_path / "receipts.jsonl" + metadata_path.write_text( + '{"provider":"openai","provider_batch_id":"batch_1"}\n', + encoding="utf-8", + ) + config_path = tmp_path / "dummy.yaml" + config_path.write_text("processors: []\n", encoding="utf-8") + + cfg = SimpleNamespace( + processors=[ + SimpleNamespace( + batch_provider={ + "provider": "openai", + "metadata_output_path": str(metadata_path), + } + ) + ] + ) + monkeypatch.setattr("mmirage.config.utils.load_mmirage_config", lambda path: cfg) + + called = {} + + def _fake_run_status_checker(metadata_records, provider_configs, output=None): + called["metadata_records"] = metadata_records + called["provider_configs"] = provider_configs + return [] + + monkeypatch.setattr( + "mmirage.core.process.batch.status_checker.run_status_checker", + _fake_run_status_checker, + ) + + rc = status_checker.main(["--config", str(config_path)]) + + assert rc == 0 + assert len(called["metadata_records"]) == 1 + assert called["metadata_records"][0]["provider"] == "openai" + assert "openai" in called["provider_configs"] From 473c9065cda4de20c0926c2dfce661bdec088e54 Mon Sep 17 00:00:00 2001 From: legstar67 Date: Wed, 6 May 2026 00:23:02 +0200 Subject: [PATCH 30/45] use of a batch receipt base path --- src/mmirage/config/batch_provider.py | 4 +- src/mmirage/core/process/batch/collector.py | 9 +++- .../core/process/batch/metadata_paths.py | 41 +++++++++++++++++++ .../core/process/batch/status_checker.py | 8 +++- tests/test_batch_collector.py | 38 ++++++++++++++++- tests/test_batch_status_checker.py | 34 ++++++++++++++- 6 files changed, 127 insertions(+), 7 deletions(-) create mode 100644 src/mmirage/core/process/batch/metadata_paths.py diff --git a/src/mmirage/config/batch_provider.py b/src/mmirage/config/batch_provider.py index 78c399e..12e0305 100644 --- a/src/mmirage/config/batch_provider.py +++ b/src/mmirage/config/batch_provider.py @@ -46,7 +46,9 @@ class BatchProviderConfig: Defaults to 50 MB. max_requests_per_chunk: Optional hard cap on number of requests in a chunk. If None, no request-count cap is enforced. - metadata_output_path: Path where submission metadata artifacts are saved. + metadata_output_path: Base path where submission metadata receipts are saved. + Submission writes suffixed files like ``.text..jsonl`` and + ``.multimodal..jsonl`` from this base path. retry_policy: Retry policy used by the shared batch layer. oversized_request_policy: Handling policy when a single request exceeds ``max_chunk_bytes``. ``isolate`` creates a dedicated oversized diff --git a/src/mmirage/core/process/batch/collector.py b/src/mmirage/core/process/batch/collector.py index f388f09..d30c277 100644 --- a/src/mmirage/core/process/batch/collector.py +++ b/src/mmirage/core/process/batch/collector.py @@ -14,6 +14,7 @@ from typing import Any, Dict, List, Mapping, MutableMapping, Sequence, Tuple from mmirage.config.batch_provider import BatchProviderConfig +from mmirage.core.process.batch.metadata_paths import resolve_metadata_paths_from_config from mmirage.core.process.batch.provider_resolution import ( build_all_provider_configs, resolve_provider_configs, @@ -198,7 +199,8 @@ def _build_arg_parser() -> argparse.ArgumentParser: nargs="+", help=( "Path(s) to metadata JSONL receipt file(s). Supports multiple files. " - "When omitted, uses metadata_output_path from the config batch_provider blocks." + "When omitted, uses metadata_output_path from the config batch_provider blocks " + "and resolves suffixed receipts like '.text..jsonl'." ), ) parser.add_argument( @@ -235,6 +237,11 @@ def main(argv: Sequence[str] | None = None) -> int: if config.metadata_output_path ] metadata_paths = list(dict.fromkeys(metadata_paths)) + if not metadata_paths: + raise ValueError( + "No metadata paths provided and none found in config batch_provider blocks." + ) + metadata_paths = resolve_metadata_paths_from_config(metadata_paths) if not metadata_paths: raise ValueError("No metadata paths provided and none found in config batch_provider blocks.") diff --git a/src/mmirage/core/process/batch/metadata_paths.py b/src/mmirage/core/process/batch/metadata_paths.py new file mode 100644 index 0000000..0915617 --- /dev/null +++ b/src/mmirage/core/process/batch/metadata_paths.py @@ -0,0 +1,41 @@ +"""Helpers for resolving batch metadata receipt paths.""" + +from __future__ import annotations + +import glob +from typing import List, Sequence + +_METADATA_SUFFIXES = ("text", "multimodal") + + +def _base_path_to_patterns(base_path: str) -> List[str]: + trimmed = base_path[:-6] if base_path.endswith(".jsonl") else base_path + return [f"{trimmed}.{suffix}.*.jsonl" for suffix in _METADATA_SUFFIXES] + + +def resolve_metadata_paths_from_config(metadata_output_paths: Sequence[str]) -> List[str]: + """Return metadata receipt paths for config-provided base paths. + + Submission writes suffixed receipts using .text..jsonl and + .multimodal..jsonl, so we expand base paths into matching globs. + """ + patterns: List[str] = [] + resolved: List[str] = [] + + for base_path in metadata_output_paths: + for pattern in _base_path_to_patterns(base_path): + patterns.append(pattern) + matches = sorted(glob.glob(pattern)) + if matches: + resolved.extend(matches) + + resolved = list(dict.fromkeys(resolved)) + if not resolved: + pattern_list = ", ".join(patterns) if patterns else "" + raise ValueError( + "No metadata receipts matched config metadata_output_path patterns. " + f"Tried: {pattern_list}. Expected suffixed files like " + "'.text..jsonl' or '.multimodal..jsonl'." + ) + + return resolved diff --git a/src/mmirage/core/process/batch/status_checker.py b/src/mmirage/core/process/batch/status_checker.py index 2045f98..43ae1e5 100644 --- a/src/mmirage/core/process/batch/status_checker.py +++ b/src/mmirage/core/process/batch/status_checker.py @@ -13,6 +13,7 @@ from mmirage.config.batch_provider import BatchProviderConfig from mmirage.core.process.batch.adapter import BatchSubmissionResult +from mmirage.core.process.batch.metadata_paths import resolve_metadata_paths_from_config from mmirage.core.process.batch.provider_resolution import ( build_all_provider_configs, resolve_provider_configs, @@ -121,7 +122,8 @@ def _build_arg_parser() -> argparse.ArgumentParser: nargs="+", help=( "Path(s) to metadata JSONL receipt file(s). Supports multiple files. " - "When omitted, uses metadata_output_path from the config batch_provider blocks." + "When omitted, uses metadata_output_path from the config batch_provider blocks " + "and resolves suffixed receipts like '.text..jsonl'." ), ) parser.add_argument( @@ -153,6 +155,10 @@ def main(argv: Sequence[str] | None = None) -> int: if config.metadata_output_path ] metadata_paths = list(dict.fromkeys(metadata_paths)) + if not metadata_paths: + print("No metadata paths provided and none found in config batch_provider blocks.") + return 1 + metadata_paths = resolve_metadata_paths_from_config(metadata_paths) if not metadata_paths: print("No metadata paths provided and none found in config batch_provider blocks.") diff --git a/tests/test_batch_collector.py b/tests/test_batch_collector.py index 5444068..efac86a 100644 --- a/tests/test_batch_collector.py +++ b/tests/test_batch_collector.py @@ -224,7 +224,8 @@ def test_collector_main_uses_config_metadata_path_when_missing_cli_arg( ): from mmirage.core.process.batch import collector - metadata_path = tmp_path / "receipts.jsonl" + metadata_base = tmp_path / "batch_metadata.jsonl" + metadata_path = tmp_path / "batch_metadata.text.abc123.jsonl" metadata_path.write_text( json.dumps( { @@ -245,7 +246,7 @@ def test_collector_main_uses_config_metadata_path_when_missing_cli_arg( SimpleNamespace( batch_provider={ "provider": "openai", - "metadata_output_path": str(metadata_path), + "metadata_output_path": str(metadata_base), } ) ] @@ -281,6 +282,39 @@ def _fake_collect_and_merge(records, provider_configs, output_path_arg): assert captured["output_path"] == str(output_path) +def test_collector_main_raises_when_config_metadata_paths_missing(tmp_path, monkeypatch): + from mmirage.core.process.batch import collector + + metadata_base = tmp_path / "batch_metadata.jsonl" + output_path = tmp_path / "out.jsonl" + config_path = tmp_path / "dummy.yaml" + config_path.write_text("processors: []\n", encoding="utf-8") + + cfg = SimpleNamespace( + processors=[ + SimpleNamespace( + batch_provider={ + "provider": "openai", + "metadata_output_path": str(metadata_base), + } + ) + ] + ) + monkeypatch.setattr("mmirage.config.utils.load_mmirage_config", lambda path: cfg) + + with pytest.raises( + ValueError, match="No metadata receipts matched config metadata_output_path patterns" + ): + collector.main( + [ + "--output-path", + str(output_path), + "--config", + str(config_path), + ] + ) + + def test_collector_main_raises_when_metadata_provider_missing_in_config(tmp_path, monkeypatch): from mmirage.core.process.batch import collector diff --git a/tests/test_batch_status_checker.py b/tests/test_batch_status_checker.py index 811f947..c9c8d8e 100644 --- a/tests/test_batch_status_checker.py +++ b/tests/test_batch_status_checker.py @@ -204,7 +204,8 @@ def test_status_checker_main_uses_config_metadata_path_when_missing_cli_arg( ): from mmirage.core.process.batch import status_checker - metadata_path = tmp_path / "receipts.jsonl" + metadata_base = tmp_path / "batch_metadata.jsonl" + metadata_path = tmp_path / "batch_metadata.text.abc123.jsonl" metadata_path.write_text( '{"provider":"openai","provider_batch_id":"batch_1"}\n', encoding="utf-8", @@ -217,7 +218,7 @@ def test_status_checker_main_uses_config_metadata_path_when_missing_cli_arg( SimpleNamespace( batch_provider={ "provider": "openai", - "metadata_output_path": str(metadata_path), + "metadata_output_path": str(metadata_base), } ) ] @@ -242,3 +243,32 @@ def _fake_run_status_checker(metadata_records, provider_configs, output=None): assert len(called["metadata_records"]) == 1 assert called["metadata_records"][0]["provider"] == "openai" assert "openai" in called["provider_configs"] + + +def test_status_checker_main_returns_error_when_config_metadata_paths_missing( + tmp_path, monkeypatch, capsys +): + from mmirage.core.process.batch import status_checker + + metadata_base = tmp_path / "batch_metadata.jsonl" + config_path = tmp_path / "dummy.yaml" + config_path.write_text("processors: []\n", encoding="utf-8") + + cfg = SimpleNamespace( + processors=[ + SimpleNamespace( + batch_provider={ + "provider": "openai", + "metadata_output_path": str(metadata_base), + } + ) + ] + ) + monkeypatch.setattr("mmirage.config.utils.load_mmirage_config", lambda path: cfg) + + rc = status_checker.main(["--config", str(config_path)]) + + assert rc == 1 + stderr = capsys.readouterr().err + assert "Status checker failed:" in stderr + assert "No metadata receipts matched config metadata_output_path patterns" in stderr From 92219ac5050594d0782e58b28957cd35ac0910ac Mon Sep 17 00:00:00 2001 From: legstar67 Date: Wed, 6 May 2026 07:57:23 +0200 Subject: [PATCH 31/45] solve problem of parsing the openai output --- .../core/process/batch/openai_adapter.py | 36 ++++++++++++++- tests/test_openai_batch_adapter.py | 44 +++++++++++++++++++ 2 files changed, 79 insertions(+), 1 deletion(-) diff --git a/src/mmirage/core/process/batch/openai_adapter.py b/src/mmirage/core/process/batch/openai_adapter.py index 3bfec7b..ea7dd13 100644 --- a/src/mmirage/core/process/batch/openai_adapter.py +++ b/src/mmirage/core/process/batch/openai_adapter.py @@ -183,7 +183,12 @@ def retrieve_results( raw = line.strip() if not raw: continue - rows.append(dict(json.loads(raw))) + row = dict(json.loads(raw)) + if "generated_text" not in row: + generated_text = self._extract_generated_text(row) + if generated_text: + row["generated_text"] = generated_text + rows.append(row) return rows @@ -207,6 +212,35 @@ def _as_openai_config(config: BatchProviderConfig) -> OpenAIBatchConfig: return config raise TypeError("OpenAIBatchAdapter requires OpenAIBatchConfig") + @staticmethod + def _extract_generated_text(row: Mapping[str, Any]) -> str: + response = row.get("response") + if not isinstance(response, Mapping): + return "" + + body = response.get("body") + if not isinstance(body, Mapping): + return "" + + choices = body.get("choices") + if isinstance(choices, list) and choices: + first = choices[0] + if isinstance(first, Mapping): + message = first.get("message") + if isinstance(message, Mapping): + content = message.get("content") + if isinstance(content, str): + return content + text = first.get("text") + if isinstance(text, str): + return text + + text = body.get("text") + if isinstance(text, str): + return text + + return "" + @staticmethod def _create_client(config: OpenAIBatchConfig) -> OpenAI: api_key = (config.credentials.get("api_key", "") or "").strip() diff --git a/tests/test_openai_batch_adapter.py b/tests/test_openai_batch_adapter.py index 1aa05b2..06356d3 100644 --- a/tests/test_openai_batch_adapter.py +++ b/tests/test_openai_batch_adapter.py @@ -352,6 +352,50 @@ def __init__(self, **kwargs): assert rows[1]["custom_id"] == "c2" assert rows[0]["response"]["body"]["text"] == "A" assert rows[1]["response"]["body"]["text"] == "B" + assert rows[0]["generated_text"] == "A" + assert rows[1]["generated_text"] == "B" + + +def test_openai_retrieve_results_prefers_message_content(monkeypatch): + from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter + + class FakeBatches: + def retrieve(self, provider_batch_id): + class _RetrieveResp: + id = provider_batch_id + status = "completed" + output_file_id = "file_output_choices" + + return _RetrieveResp() + + class FakeFiles: + def content(self, output_file_id): + class _ContentResp: + text = ( + '{"custom_id":"c1","response":{"body":{"choices":[' + '{"message":{"content":"{\\"question\\":\\"Q\\",\\"answer\\":\\"A\\"}"}}' + ']}}}\n' + ) + + return _ContentResp() + + class FakeClient: + def __init__(self, **kwargs): + self.batches = FakeBatches() + self.files = FakeFiles() + + monkeypatch.setattr( + "mmirage.core.process.batch.openai_adapter.OpenAI", + FakeClient, + ) + + config = OpenAIBatchConfig(credentials={"api_key": "test-key"}) + adapter = OpenAIBatchAdapter() + + rows = adapter.retrieve_results(provider_batch_id="batch_choices", config=config) + + assert rows[0]["custom_id"] == "c1" + assert rows[0]["generated_text"] == '{"question":"Q","answer":"A"}' def test_openai_retrieve_results_raises_if_batch_not_completed(monkeypatch): From 16b270ab7d7f45cf5c37425632bab567d2b58927 Mon Sep 17 00:00:00 2001 From: legstar67 Date: Mon, 11 May 2026 20:49:12 +0200 Subject: [PATCH 32/45] upgrade of status checker according to the comments --- .../core/process/batch/status_checker.py | 38 ++++----- tests/test_batch_status_checker.py | 78 +++++++++---------- 2 files changed, 59 insertions(+), 57 deletions(-) diff --git a/src/mmirage/core/process/batch/status_checker.py b/src/mmirage/core/process/batch/status_checker.py index 43ae1e5..3ca0cdd 100644 --- a/src/mmirage/core/process/batch/status_checker.py +++ b/src/mmirage/core/process/batch/status_checker.py @@ -20,6 +20,9 @@ ) from mmirage.core.process.batch.registry import BatchAdapterFactory +import logging +logger = logging.getLogger(__name__) + def _normalize_metadata_paths(metadata_paths: str | Sequence[str]) -> List[str]: if isinstance(metadata_paths, str): @@ -77,7 +80,6 @@ def extract_unique_provider_batches(metadata_records: Sequence[Mapping[str, Any] def run_status_checker( metadata_records: Sequence[Mapping[str, Any]], provider_configs: Mapping[str, BatchProviderConfig], - output: TextIO = sys.stdout, ) -> List[BatchSubmissionResult]: """Check batch status for each referenced provider batch. @@ -90,26 +92,26 @@ def run_status_checker( for provider, provider_batch_id in extract_unique_provider_batches(metadata_records): if provider not in provider_configs: - print(f"Skipping batch {provider_batch_id}: no config for provider '{provider}'.", file=output) + logger.info(f"Skipping batch {provider_batch_id}: no config for provider '{provider}'.") provider_counts = counter.setdefault(provider, {}) provider_counts["skipped"] = provider_counts.get("skipped", 0) + 1 - continue - config = provider_configs[provider] - adapter = BatchAdapterFactory.from_config(config) - result = adapter.check_batch_status(provider_batch_id=provider_batch_id, config=config) - results.append(result) + else: + config = provider_configs[provider] + adapter = BatchAdapterFactory.from_config(config) + result = adapter.check_batch_status(provider_batch_id=provider_batch_id, config=config) + results.append(result) - print(f"Batch {provider_batch_id} ({provider}): {result.status}", file=output) - provider_counts = counter.setdefault(provider, {}) - provider_counts[result.status] = provider_counts.get(result.status, 0) + 1 + logger.info(f"Batch {provider_batch_id} ({provider}): {result.status}") + provider_counts = counter.setdefault(provider, {}) + provider_counts[result.status] = provider_counts.get(result.status, 0) + 1 - print("\n------------ Batch status summary ------------", file=output) + logger.info("------------ Batch status summary ------------") for provider, status_counts in counter.items(): - print(f"Total batches for provider '{provider}':", file=output) total = sum(status_counts.values()) + logger.info(f"Provider '{provider}' (Total: {total}):") for status, count in status_counts.items(): - print(f" {status}: {count}/{total}", file=output) + logger.info(f" - {status}: {count}/{total}") return results @@ -156,30 +158,30 @@ def main(argv: Sequence[str] | None = None) -> int: ] metadata_paths = list(dict.fromkeys(metadata_paths)) if not metadata_paths: - print("No metadata paths provided and none found in config batch_provider blocks.") + logger.error("No metadata paths provided and none found in config batch_provider blocks.") return 1 metadata_paths = resolve_metadata_paths_from_config(metadata_paths) if not metadata_paths: - print("No metadata paths provided and none found in config batch_provider blocks.") + logger.error("No metadata paths provided and none found in config batch_provider blocks.") return 1 records = _read_metadata_records(metadata_paths) pairs = extract_unique_provider_batches(records) if not pairs: - print(f"No provider batch IDs found in metadata file(s): {metadata_paths}") + logger.info(f"No provider batch IDs found in metadata file(s): {metadata_paths}") return 0 provider_configs = resolve_provider_configs(records, cfg) if not provider_configs: - print("No supported provider configurations could be built from metadata.") + logger.error("No supported provider configurations could be built from metadata.") return 1 run_status_checker( metadata_records=records, provider_configs=provider_configs, ) except Exception as exc: - print(f"Status checker failed: {exc}", file=sys.stderr) + logger.exception("Status checker failed") return 1 return 0 diff --git a/tests/test_batch_status_checker.py b/tests/test_batch_status_checker.py index c9c8d8e..3e601dd 100644 --- a/tests/test_batch_status_checker.py +++ b/tests/test_batch_status_checker.py @@ -1,4 +1,3 @@ -from io import StringIO from types import SimpleNamespace from mmirage.config.openai_batch import OpenAIBatchConfig @@ -33,6 +32,7 @@ def test_extract_unique_provider_batches_handles_malformed_and_duplicates(tmp_pa def test_run_status_checker_prints_summary_with_factory_dispatch(tmp_path, monkeypatch): from mmirage.core.process.batch.status_checker import _read_metadata_records, run_status_checker + from unittest.mock import patch metadata_path = tmp_path / "receipts.jsonl" metadata_path.write_text( @@ -66,17 +66,16 @@ def check_batch_status(self, provider_batch_id, config): lambda config: fake_adapter, ) - output = StringIO() config_map = { "openai": OpenAIBatchConfig(credentials={"api_key": "k"}), } records = _read_metadata_records(str(metadata_path)) - results = run_status_checker( - metadata_records=records, - provider_configs=config_map, - output=output, - ) + with patch("mmirage.core.process.batch.status_checker.logger") as mock_logger: + results = run_status_checker( + metadata_records=records, + provider_configs=config_map, + ) assert [(r.provider_batch_id, r.status) for r in results] == [ ("batch_1", "completed"), @@ -87,9 +86,10 @@ def check_batch_status(self, provider_batch_id, config): ("batch_2", "openai"), ] - printed = output.getvalue() - assert "Batch batch_1 (openai): completed" in printed - assert "Batch batch_2 (openai): in_progress" in printed + # Verify logger.info was called with expected messages + logger_calls = [call[0][0] for call in mock_logger.info.call_args_list] + assert any("Batch batch_1 (openai): completed" in str(call) for call in logger_calls) + assert any("Batch batch_2 (openai): in_progress" in str(call) for call in logger_calls) def test_status_checker_main_uses_config_and_runs(tmp_path, monkeypatch): @@ -134,9 +134,10 @@ def _fake_run_status_checker(metadata_records, provider_configs, output=None): def test_status_checker_main_returns_error_when_metadata_provider_missing_in_config( - tmp_path, monkeypatch, capsys + tmp_path, monkeypatch ): from mmirage.core.process.batch import status_checker + from unittest.mock import patch metadata_path = tmp_path / "receipts.jsonl" metadata_path.write_text( @@ -150,25 +151,25 @@ def test_status_checker_main_returns_error_when_metadata_provider_missing_in_con cfg = SimpleNamespace(processors=[SimpleNamespace(batch_provider={"provider": "openai"})]) monkeypatch.setattr("mmirage.config.utils.load_mmirage_config", lambda path: cfg) - rc = status_checker.main( - [ - "--metadata-path", - str(metadata_path), - "--config", - str(config_path), - ] - ) + with patch("mmirage.core.process.batch.status_checker.logger") as mock_logger: + rc = status_checker.main( + [ + "--metadata-path", + str(metadata_path), + "--config", + str(config_path), + ] + ) assert rc == 1 - stderr = capsys.readouterr().err - assert "Status checker failed:" in stderr - assert "missing from YAML batch_provider config" in stderr + assert mock_logger.error.called or mock_logger.exception.called def test_status_checker_main_returns_error_when_credentials_missing( - tmp_path, monkeypatch, capsys + tmp_path, monkeypatch ): from mmirage.core.process.batch import status_checker + from unittest.mock import patch metadata_path = tmp_path / "receipts.jsonl" metadata_path.write_text( @@ -184,19 +185,18 @@ def test_status_checker_main_returns_error_when_credentials_missing( monkeypatch.setattr("mmirage.config.utils.load_mmirage_config", lambda path: cfg) monkeypatch.delenv("OPENAI_API_KEY", raising=False) - rc = status_checker.main( - [ - "--metadata-path", - str(metadata_path), - "--config", - str(config_path), - ] - ) + with patch("mmirage.core.process.batch.status_checker.logger") as mock_logger: + rc = status_checker.main( + [ + "--metadata-path", + str(metadata_path), + "--config", + str(config_path), + ] + ) assert rc == 1 - stderr = capsys.readouterr().err - assert "Status checker failed:" in stderr - assert "Missing credentials for provider 'openai'" in stderr + assert mock_logger.error.called or mock_logger.exception.called def test_status_checker_main_uses_config_metadata_path_when_missing_cli_arg( @@ -246,9 +246,10 @@ def _fake_run_status_checker(metadata_records, provider_configs, output=None): def test_status_checker_main_returns_error_when_config_metadata_paths_missing( - tmp_path, monkeypatch, capsys + tmp_path, monkeypatch ): from mmirage.core.process.batch import status_checker + from unittest.mock import patch metadata_base = tmp_path / "batch_metadata.jsonl" config_path = tmp_path / "dummy.yaml" @@ -266,9 +267,8 @@ def test_status_checker_main_returns_error_when_config_metadata_paths_missing( ) monkeypatch.setattr("mmirage.config.utils.load_mmirage_config", lambda path: cfg) - rc = status_checker.main(["--config", str(config_path)]) + with patch("mmirage.core.process.batch.status_checker.logger") as mock_logger: + rc = status_checker.main(["--config", str(config_path)]) assert rc == 1 - stderr = capsys.readouterr().err - assert "Status checker failed:" in stderr - assert "No metadata receipts matched config metadata_output_path patterns" in stderr + assert mock_logger.exception.called From af013e56b6cb0577d5f15cc8355e3674bb98342d Mon Sep 17 00:00:00 2001 From: legstar67 Date: Mon, 11 May 2026 23:23:22 +0200 Subject: [PATCH 33/45] modif according to comments : optimization + deduplication --- src/mmirage/core/process/batch/collector.py | 35 +-------------- .../core/process/batch/metadata_utils.py | 44 +++++++++++++++++++ .../core/process/batch/status_checker.py | 33 +------------- 3 files changed, 47 insertions(+), 65 deletions(-) create mode 100644 src/mmirage/core/process/batch/metadata_utils.py diff --git a/src/mmirage/core/process/batch/collector.py b/src/mmirage/core/process/batch/collector.py index d30c277..aebeeac 100644 --- a/src/mmirage/core/process/batch/collector.py +++ b/src/mmirage/core/process/batch/collector.py @@ -20,40 +20,7 @@ resolve_provider_configs, ) from mmirage.core.process.batch.registry import BatchAdapterFactory - - -def _normalize_metadata_paths(metadata_paths: str | Sequence[str]) -> List[str]: - """Return metadata paths as a concrete list. - - Accepts either a single string or a sequence so the CLI and internal callers - can share the same entry point without special-casing file counts. - """ - if isinstance(metadata_paths, str): - return [metadata_paths] - return [str(path) for path in metadata_paths] - - -def _read_metadata_records(metadata_output_paths: str | Sequence[str]) -> List[Dict[str, Any]]: - """Load valid JSON objects from one or more receipt files. - - Malformed lines are skipped so partially written or noisy receipt files do - not stop collection. Only JSON objects are retained because downstream - resolution depends on keyed metadata. - """ - records: List[Dict[str, Any]] = [] - for metadata_output_path in _normalize_metadata_paths(metadata_output_paths): - with open(metadata_output_path, "r", encoding="utf-8") as f: - for line in f: - raw = line.strip() - if not raw: - continue - try: - parsed = json.loads(raw) - except json.JSONDecodeError: - continue - if isinstance(parsed, dict): - records.append(parsed) - return records +from mmirage.core.process.batch.metadata_utils import _normalize_metadata_paths, _read_metadata_records def _aggregate_batch_mappings( diff --git a/src/mmirage/core/process/batch/metadata_utils.py b/src/mmirage/core/process/batch/metadata_utils.py new file mode 100644 index 0000000..479bffc --- /dev/null +++ b/src/mmirage/core/process/batch/metadata_utils.py @@ -0,0 +1,44 @@ +"""Shared helpers for batch metadata receipt files.""" + +from __future__ import annotations + +import json +import logging +from typing import Any, Dict, List, Sequence + +logger = logging.getLogger(__name__) + + +def _normalize_metadata_paths(metadata_paths: str | Sequence[str]) -> List[str]: + """Return metadata paths as a concrete list.""" + if isinstance(metadata_paths, str): + return [metadata_paths] + return list(metadata_paths) + + +def _read_metadata_records(metadata_output_paths: str | Sequence[str]) -> List[Dict[str, Any]]: + """Load valid JSON objects from one or more receipt files. + + Malformed lines are skipped with a warning so partially written or noisy + receipt files do not stop collection. Only JSON objects are retained + because downstream resolution depends on keyed metadata. + """ + records: List[Dict[str, Any]] = [] + for metadata_output_path in _normalize_metadata_paths(metadata_output_paths): + with open(metadata_output_path, "r", encoding="utf-8") as f: + for line in f: + raw = line.strip() + if not raw: + continue + try: + parsed = json.loads(raw) + except json.JSONDecodeError as exc: + logger.warning( + "Skipping malformed metadata JSON line in %s: %s", + metadata_output_path, + exc, + ) + continue + #if isinstance(parsed, dict): + records.append(parsed) + return records \ No newline at end of file diff --git a/src/mmirage/core/process/batch/status_checker.py b/src/mmirage/core/process/batch/status_checker.py index 3ca0cdd..66cc488 100644 --- a/src/mmirage/core/process/batch/status_checker.py +++ b/src/mmirage/core/process/batch/status_checker.py @@ -7,51 +7,22 @@ from __future__ import annotations import argparse -import json +import logging import sys from typing import Any, Dict, List, Mapping, Sequence, TextIO, Tuple from mmirage.config.batch_provider import BatchProviderConfig from mmirage.core.process.batch.adapter import BatchSubmissionResult from mmirage.core.process.batch.metadata_paths import resolve_metadata_paths_from_config +from mmirage.core.process.batch.metadata_utils import _normalize_metadata_paths, _read_metadata_records from mmirage.core.process.batch.provider_resolution import ( build_all_provider_configs, resolve_provider_configs, ) from mmirage.core.process.batch.registry import BatchAdapterFactory - -import logging logger = logging.getLogger(__name__) -def _normalize_metadata_paths(metadata_paths: str | Sequence[str]) -> List[str]: - if isinstance(metadata_paths, str): - return [metadata_paths] - return [str(path) for path in metadata_paths] - - -def _read_metadata_records(metadata_output_paths: str | Sequence[str]) -> List[Dict[str, str]]: - """Load JSONL metadata records from one or more files. - - Lines that are empty or invalid JSON are ignored to allow best-effort - status checks when receipt files are partially corrupted. - """ - records: List[Dict[str, str]] = [] - for metadata_output_path in _normalize_metadata_paths(metadata_output_paths): - with open(metadata_output_path, "r", encoding="utf-8") as f: - for line in f: - raw = line.strip() - if not raw: - continue - try: - record = json.loads(raw) - except json.JSONDecodeError: - continue - if isinstance(record, dict): - records.append(record) - return records - - def extract_unique_provider_batches(metadata_records: Sequence[Mapping[str, Any]]) -> List[Tuple[str, str]]: """Return unique ``(provider, provider_batch_id)`` pairs. From b9ff47cee4c22261f5e551d632eaa2e628c82157 Mon Sep 17 00:00:00 2001 From: legstar67 Date: Tue, 12 May 2026 01:26:16 +0200 Subject: [PATCH 34/45] new class BatchMetadataRecord instead of using dict --- src/mmirage/core/process/batch/collector.py | 27 +++++----- .../core/process/batch/metadata_utils.py | 50 ++++++++++++++++--- .../core/process/batch/provider_resolution.py | 7 +-- .../core/process/batch/status_checker.py | 27 +++++----- tests/test_batch_collector.py | 7 +-- tests/test_batch_status_checker.py | 4 +- 6 files changed, 80 insertions(+), 42 deletions(-) diff --git a/src/mmirage/core/process/batch/collector.py b/src/mmirage/core/process/batch/collector.py index aebeeac..f972f9c 100644 --- a/src/mmirage/core/process/batch/collector.py +++ b/src/mmirage/core/process/batch/collector.py @@ -15,16 +15,20 @@ from mmirage.config.batch_provider import BatchProviderConfig from mmirage.core.process.batch.metadata_paths import resolve_metadata_paths_from_config +from mmirage.core.process.batch.metadata_utils import ( + BatchMetadataRecord, + _normalize_metadata_paths, + _read_metadata_records, +) from mmirage.core.process.batch.provider_resolution import ( build_all_provider_configs, resolve_provider_configs, ) from mmirage.core.process.batch.registry import BatchAdapterFactory -from mmirage.core.process.batch.metadata_utils import _normalize_metadata_paths, _read_metadata_records def _aggregate_batch_mappings( - records: Sequence[Mapping[str, Any]], + records: Sequence[BatchMetadataRecord], ) -> Dict[Tuple[str, str], Dict[str, int]]: """Group source-index mappings by provider and provider batch ID. @@ -34,28 +38,21 @@ def _aggregate_batch_mappings( aggregated: Dict[Tuple[str, str], Dict[str, int]] = {} for record in records: - provider = str(record.get("provider", "")).strip().lower() - provider_batch_id = str(record.get("provider_batch_id", "")).strip() - mapping = record.get("custom_id_to_source_index", {}) - - if not provider or not provider_batch_id or not isinstance(mapping, dict): - continue + provider = record.provider + provider_batch_id = record.provider_batch_id + mapping = record.custom_id_to_source_index key = (provider, provider_batch_id) - if key not in aggregated: - aggregated[key] = {} + aggregated.setdefault(key, {}) for custom_id, source_index in mapping.items(): - try: - aggregated[key][str(custom_id)] = int(source_index) - except (TypeError, ValueError): - continue + aggregated[key][str(custom_id)] = source_index return aggregated def collect_and_merge( - records: Sequence[Mapping[str, Any]], + records: Sequence[BatchMetadataRecord], provider_configs: Mapping[str, BatchProviderConfig], output_path: str, ) -> List[Dict[str, Any]]: diff --git a/src/mmirage/core/process/batch/metadata_utils.py b/src/mmirage/core/process/batch/metadata_utils.py index 479bffc..a582bdc 100644 --- a/src/mmirage/core/process/batch/metadata_utils.py +++ b/src/mmirage/core/process/batch/metadata_utils.py @@ -4,11 +4,41 @@ import json import logging -from typing import Any, Dict, List, Sequence +from dataclasses import dataclass, field +from typing import Any, Dict, List, Mapping, Sequence logger = logging.getLogger(__name__) +@dataclass(slots=True) +class BatchMetadataRecord: + """Typed batch receipt row shared by collector and status checker.""" + + provider: str + provider_batch_id: str + custom_id_to_source_index: Dict[str, int] = field(default_factory=dict) + + @classmethod + def from_mapping(cls, payload: Mapping[str, Any]) -> "BatchMetadataRecord": + provider = str(payload.get("provider", "")).strip().lower() + provider_batch_id = str(payload.get("provider_batch_id", "")).strip() + + raw_mapping = payload.get("custom_id_to_source_index", {}) + custom_id_to_source_index: Dict[str, int] = {} + if isinstance(raw_mapping, dict): + for custom_id, source_index in raw_mapping.items(): + try: + custom_id_to_source_index[str(custom_id)] = int(source_index) + except (TypeError, ValueError): + continue + + return cls( + provider=provider, + provider_batch_id=provider_batch_id, + custom_id_to_source_index=custom_id_to_source_index, + ) + + def _normalize_metadata_paths(metadata_paths: str | Sequence[str]) -> List[str]: """Return metadata paths as a concrete list.""" if isinstance(metadata_paths, str): @@ -16,14 +46,16 @@ def _normalize_metadata_paths(metadata_paths: str | Sequence[str]) -> List[str]: return list(metadata_paths) -def _read_metadata_records(metadata_output_paths: str | Sequence[str]) -> List[Dict[str, Any]]: +def _read_metadata_records( + metadata_output_paths: str | Sequence[str], +) -> List[BatchMetadataRecord]: """Load valid JSON objects from one or more receipt files. Malformed lines are skipped with a warning so partially written or noisy - receipt files do not stop collection. Only JSON objects are retained - because downstream resolution depends on keyed metadata. + receipt files do not stop collection. Only JSON objects are retained and + converted into typed records with required provider identifiers. """ - records: List[Dict[str, Any]] = [] + records: List[BatchMetadataRecord] = [] for metadata_output_path in _normalize_metadata_paths(metadata_output_paths): with open(metadata_output_path, "r", encoding="utf-8") as f: for line in f: @@ -39,6 +71,10 @@ def _read_metadata_records(metadata_output_paths: str | Sequence[str]) -> List[D exc, ) continue - #if isinstance(parsed, dict): - records.append(parsed) + # defensive check to ensure only dicts are included (useful against partial/corrupt metadata) + if isinstance(parsed, dict): + record = BatchMetadataRecord.from_mapping(parsed) + if not record.provider or not record.provider_batch_id: + continue + records.append(record) return records \ No newline at end of file diff --git a/src/mmirage/core/process/batch/provider_resolution.py b/src/mmirage/core/process/batch/provider_resolution.py index ed66886..5bfb49c 100644 --- a/src/mmirage/core/process/batch/provider_resolution.py +++ b/src/mmirage/core/process/batch/provider_resolution.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Sequence, Type from mmirage.config.batch_provider import BatchProviderConfig +from mmirage.core.process.batch.metadata_utils import BatchMetadataRecord if TYPE_CHECKING: from mmirage.config.config import MMirageConfig @@ -65,11 +66,11 @@ def get_config_cls( ) -def _discover_required_providers(metadata_records: Sequence[Mapping[str, Any]]) -> List[str]: +def _discover_required_providers(metadata_records: Sequence[BatchMetadataRecord]) -> List[str]: providers: List[str] = [] seen = set() for record in metadata_records: - provider = str(record.get("provider", "")).strip().lower() + provider = record.provider if not provider or provider in seen: continue seen.add(provider) @@ -162,7 +163,7 @@ def build_all_provider_configs(cfg: "MMirageConfig") -> Dict[str, BatchProviderC def resolve_provider_configs( - metadata_records: Sequence[Mapping[str, Any]], + metadata_records: Sequence[BatchMetadataRecord], cfg: "MMirageConfig", ) -> Dict[str, BatchProviderConfig]: """Resolve provider configs required by receiver metadata. diff --git a/src/mmirage/core/process/batch/status_checker.py b/src/mmirage/core/process/batch/status_checker.py index 66cc488..7ddbd71 100644 --- a/src/mmirage/core/process/batch/status_checker.py +++ b/src/mmirage/core/process/batch/status_checker.py @@ -14,7 +14,11 @@ from mmirage.config.batch_provider import BatchProviderConfig from mmirage.core.process.batch.adapter import BatchSubmissionResult from mmirage.core.process.batch.metadata_paths import resolve_metadata_paths_from_config -from mmirage.core.process.batch.metadata_utils import _normalize_metadata_paths, _read_metadata_records +from mmirage.core.process.batch.metadata_utils import ( + BatchMetadataRecord, + _normalize_metadata_paths, + _read_metadata_records, +) from mmirage.core.process.batch.provider_resolution import ( build_all_provider_configs, resolve_provider_configs, @@ -23,7 +27,9 @@ logger = logging.getLogger(__name__) -def extract_unique_provider_batches(metadata_records: Sequence[Mapping[str, Any]]) -> List[Tuple[str, str]]: +def extract_unique_provider_batches( + metadata_records: Sequence[BatchMetadataRecord], +) -> List[Tuple[str, str]]: """Return unique ``(provider, provider_batch_id)`` pairs. Normalizes provider names to lowercase and ignores records that do not @@ -33,11 +39,8 @@ def extract_unique_provider_batches(metadata_records: Sequence[Mapping[str, Any] seen = set() for record in metadata_records: - provider = str(record.get("provider", "")).strip().lower() - provider_batch_id = str(record.get("provider_batch_id", "")).strip() - - if not provider or not provider_batch_id: - continue + provider = record.provider + provider_batch_id = record.provider_batch_id pair = (provider, provider_batch_id) if pair in seen: @@ -49,7 +52,7 @@ def extract_unique_provider_batches(metadata_records: Sequence[Mapping[str, Any] def run_status_checker( - metadata_records: Sequence[Mapping[str, Any]], + metadata_records: Sequence[BatchMetadataRecord], provider_configs: Mapping[str, BatchProviderConfig], ) -> List[BatchSubmissionResult]: """Check batch status for each referenced provider batch. @@ -63,7 +66,7 @@ def run_status_checker( for provider, provider_batch_id in extract_unique_provider_batches(metadata_records): if provider not in provider_configs: - logger.info(f"Skipping batch {provider_batch_id}: no config for provider '{provider}'.") + logger.warning(f"Skipping batch {provider_batch_id}: no config for provider '{provider}'.") provider_counts = counter.setdefault(provider, {}) provider_counts["skipped"] = provider_counts.get("skipped", 0) + 1 @@ -77,12 +80,12 @@ def run_status_checker( provider_counts = counter.setdefault(provider, {}) provider_counts[result.status] = provider_counts.get(result.status, 0) + 1 - logger.info("------------ Batch status summary ------------") + print("\n------------ Batch status summary ------------") for provider, status_counts in counter.items(): total = sum(status_counts.values()) - logger.info(f"Provider '{provider}' (Total: {total}):") + print(f"Provider '{provider}' (Total: {total}):") for status, count in status_counts.items(): - logger.info(f" - {status}: {count}/{total}") + print(f" - {status}: {count}/{total}") return results diff --git a/tests/test_batch_collector.py b/tests/test_batch_collector.py index efac86a..cdbb403 100644 --- a/tests/test_batch_collector.py +++ b/tests/test_batch_collector.py @@ -5,7 +5,7 @@ from mmirage.config.openai_batch import OpenAIBatchConfig -def test_collect_and_merge_reconstructs_rows_deterministically(tmp_path, monkeypatch): +def test_collect_and_merge_reconstructs_rows_deterministically(tmp_path, monkeypatch, caplog): from mmirage.core.process.batch.collector import _read_metadata_records, collect_and_merge metadata_path = tmp_path / "receipts.jsonl" @@ -73,6 +73,7 @@ def retrieve_results(self, provider_batch_id, config): provider_configs = {"openai": OpenAIBatchConfig(credentials={"api_key": "k"})} records = _read_metadata_records(str(metadata_path)) + assert "Skipping malformed metadata JSON line" in caplog.text rows = collect_and_merge( records=records, provider_configs=provider_configs, @@ -214,7 +215,7 @@ def _fake_collect_and_merge(records, provider_configs, output_path_arg): assert rc == 0 assert len(captured["records"]) == 1 - assert captured["records"][0]["provider"] == "openai" + assert captured["records"][0].provider == "openai" assert "openai" in captured["provider_configs"] assert captured["output_path"] == str(output_path) @@ -277,7 +278,7 @@ def _fake_collect_and_merge(records, provider_configs, output_path_arg): assert rc == 0 assert len(captured["records"]) == 1 - assert captured["records"][0]["provider"] == "openai" + assert captured["records"][0].provider == "openai" assert "openai" in captured["provider_configs"] assert captured["output_path"] == str(output_path) diff --git a/tests/test_batch_status_checker.py b/tests/test_batch_status_checker.py index 3e601dd..f661f77 100644 --- a/tests/test_batch_status_checker.py +++ b/tests/test_batch_status_checker.py @@ -129,7 +129,7 @@ def _fake_run_status_checker(metadata_records, provider_configs, output=None): assert rc == 0 assert len(called["metadata_records"]) == 1 - assert called["metadata_records"][0]["provider"] == "openai" + assert called["metadata_records"][0].provider == "openai" assert "openai" in called["provider_configs"] @@ -241,7 +241,7 @@ def _fake_run_status_checker(metadata_records, provider_configs, output=None): assert rc == 0 assert len(called["metadata_records"]) == 1 - assert called["metadata_records"][0]["provider"] == "openai" + assert called["metadata_records"][0].provider == "openai" assert "openai" in called["provider_configs"] From bfad50e6f917a6845e13bd93654bd49c882bf3ac Mon Sep 17 00:00:00 2001 From: legstar67 Date: Tue, 12 May 2026 02:52:30 +0200 Subject: [PATCH 35/45] log, tracking and syntax modif following the comments --- src/mmirage/core/process/batch/collector.py | 75 ++++++----- tests/test_batch_collector.py | 132 ++++++++++++++------ 2 files changed, 142 insertions(+), 65 deletions(-) diff --git a/src/mmirage/core/process/batch/collector.py b/src/mmirage/core/process/batch/collector.py index f972f9c..d9c7775 100644 --- a/src/mmirage/core/process/batch/collector.py +++ b/src/mmirage/core/process/batch/collector.py @@ -9,6 +9,7 @@ import argparse import json +import logging import os import sys from typing import Any, Dict, List, Mapping, MutableMapping, Sequence, Tuple @@ -26,6 +27,8 @@ ) from mmirage.core.process.batch.registry import BatchAdapterFactory +logger = logging.getLogger(__name__) + def _aggregate_batch_mappings( records: Sequence[BatchMetadataRecord], @@ -42,6 +45,9 @@ def _aggregate_batch_mappings( provider_batch_id = record.provider_batch_id mapping = record.custom_id_to_source_index + if not provider or not provider_batch_id or not mapping: + continue + key = (provider, provider_batch_id) aggregated.setdefault(key, {}) @@ -95,14 +101,18 @@ def collect_and_merge( custom_id = str(result_row.get("custom_id", "")).strip() if not custom_id or custom_id not in mapping: continue - row_payload = _build_output_payload(result_row) + row_payload = _build_output_payload(result_row, custom_id=custom_id) indexed_rows[custom_id] = { - "source_index": int(mapping.get(custom_id, 0)), + "source_index": int(mapping[custom_id]), "custom_id": custom_id, **row_payload, } - ordered_rows = sorted(indexed_rows.values(), key=lambda row: row.get("source_index", 0)) + # Sort primarily by source_index and secondarily by custom_id to ensure + # deterministic ordering when multiple rows share the same source_index. + ordered_rows = sorted( + indexed_rows.values(), key=lambda row: (row.get("source_index", 0), row.get("custom_id", "")) + ) os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) with open(output_path, "w", encoding="utf-8") as f: @@ -112,7 +122,7 @@ def collect_and_merge( return ordered_rows -def _build_output_payload(result_row: Mapping[str, Any]) -> Dict[str, Any]: +def _build_output_payload(result_row: Mapping[str, Any], custom_id: str = "") -> Dict[str, Any]: """Convert provider content into the receiver's output schema. The collector preserves raw text for opaque generations, but maps structured @@ -126,6 +136,10 @@ def _build_output_payload(result_row: Mapping[str, Any]) -> Dict[str, Any]: try: parsed = json.loads(raw_content) except json.JSONDecodeError: + logger.warning( + f"Failed to parse JSON for result row (custom_id={custom_id}). " + f"Treating as raw text. Content: {raw_content[:100]}" + ) return {"caption": raw_content} if isinstance(parsed, dict) and ("question" in parsed or "answer" in parsed): @@ -190,37 +204,38 @@ def main(argv: Sequence[str] | None = None) -> int: args = _build_arg_parser().parse_args(argv) from mmirage.config.utils import load_mmirage_config - cfg = load_mmirage_config(args.config) - if args.metadata_path: - metadata_paths = args.metadata_path - else: - all_provider_configs = build_all_provider_configs(cfg) - metadata_paths = [ - config.metadata_output_path - for config in all_provider_configs.values() - if config.metadata_output_path - ] - metadata_paths = list(dict.fromkeys(metadata_paths)) + try: + cfg = load_mmirage_config(args.config) + if args.metadata_path: + metadata_paths = args.metadata_path + else: + all_provider_configs = build_all_provider_configs(cfg) + metadata_paths = [ + config.metadata_output_path + for config in all_provider_configs.values() + if config.metadata_output_path + ] + metadata_paths = list(dict.fromkeys(metadata_paths)) + if not metadata_paths: + raise ValueError( + "No metadata paths provided and none found in config batch_provider blocks." + ) + metadata_paths = resolve_metadata_paths_from_config(metadata_paths) + if not metadata_paths: - raise ValueError( - "No metadata paths provided and none found in config batch_provider blocks." - ) - metadata_paths = resolve_metadata_paths_from_config(metadata_paths) + raise ValueError("No metadata paths provided and none found in config batch_provider blocks.") - if not metadata_paths: - raise ValueError("No metadata paths provided and none found in config batch_provider blocks.") + records = _read_metadata_records(metadata_paths) + provider_configs = resolve_provider_configs(records, cfg) - records = _read_metadata_records(metadata_paths) - provider_configs = resolve_provider_configs(records, cfg) + rows = collect_and_merge(records, provider_configs, args.output_path) + print(f"Merged {len(rows)} rows and saved to {args.output_path}") + except Exception as exc: + logger.exception("Collector failed") + return 1 - rows = collect_and_merge(records, provider_configs, args.output_path) - print(f"Merged {len(rows)} rows and saved to {args.output_path}") return 0 if __name__ == "__main__": - try: - raise SystemExit(main()) - except Exception as exc: - print(f"Collector failed: {exc}", file=sys.stderr) - raise SystemExit(1) + raise SystemExit(main()) diff --git a/tests/test_batch_collector.py b/tests/test_batch_collector.py index cdbb403..a7591b0 100644 --- a/tests/test_batch_collector.py +++ b/tests/test_batch_collector.py @@ -283,7 +283,7 @@ def _fake_collect_and_merge(records, provider_configs, output_path_arg): assert captured["output_path"] == str(output_path) -def test_collector_main_raises_when_config_metadata_paths_missing(tmp_path, monkeypatch): +def test_collector_main_raises_when_config_metadata_paths_missing(tmp_path, monkeypatch, caplog): from mmirage.core.process.batch import collector metadata_base = tmp_path / "batch_metadata.jsonl" @@ -303,20 +303,19 @@ def test_collector_main_raises_when_config_metadata_paths_missing(tmp_path, monk ) monkeypatch.setattr("mmirage.config.utils.load_mmirage_config", lambda path: cfg) - with pytest.raises( - ValueError, match="No metadata receipts matched config metadata_output_path patterns" - ): - collector.main( - [ - "--output-path", - str(output_path), - "--config", - str(config_path), - ] - ) + rc = collector.main( + [ + "--output-path", + str(output_path), + "--config", + str(config_path), + ] + ) + assert rc == 1 + assert "No metadata receipts matched config metadata_output_path patterns" in caplog.text -def test_collector_main_raises_when_metadata_provider_missing_in_config(tmp_path, monkeypatch): +def test_collector_main_raises_when_metadata_provider_missing_in_config(tmp_path, monkeypatch, caplog): from mmirage.core.process.batch import collector metadata_path = tmp_path / "receipts.jsonl" @@ -339,17 +338,18 @@ def test_collector_main_raises_when_metadata_provider_missing_in_config(tmp_path cfg = SimpleNamespace(processors=[SimpleNamespace(batch_provider={"provider": "openai"})]) monkeypatch.setattr("mmirage.config.utils.load_mmirage_config", lambda path: cfg) - with pytest.raises(ValueError, match="missing from YAML batch_provider config"): - collector.main( - [ - "--metadata-path", - str(metadata_path), - "--output-path", - str(output_path), - "--config", - str(config_path), - ] - ) + rc = collector.main( + [ + "--metadata-path", + str(metadata_path), + "--output-path", + str(output_path), + "--config", + str(config_path), + ] + ) + assert rc == 1 + assert "missing from YAML batch_provider config" in caplog.text def test_collect_and_merge_routes_multiple_providers(tmp_path, monkeypatch): @@ -431,7 +431,7 @@ def retrieve_results(self, provider_batch_id, config): assert ("batch_unit", "unit") in adapters["unit"].calls -def test_collector_main_raises_for_invalid_batch_provider_config(tmp_path, monkeypatch): +def test_collector_main_raises_for_invalid_batch_provider_config(tmp_path, monkeypatch, caplog): from mmirage.core.process.batch import collector metadata_path = tmp_path / "receipts.jsonl" @@ -457,14 +457,76 @@ def test_collector_main_raises_for_invalid_batch_provider_config(tmp_path, monke ) monkeypatch.setattr("mmirage.config.utils.load_mmirage_config", lambda path: cfg) - with pytest.raises(ValueError, match="batch_endpoint must start with '/'"): - collector.main( - [ - "--metadata-path", - str(metadata_path), - "--output-path", - str(output_path), - "--config", - str(config_path), - ] + rc = collector.main( + [ + "--metadata-path", + str(metadata_path), + "--output-path", + str(output_path), + "--config", + str(config_path), + ] + ) + assert rc == 1 + assert "batch_endpoint must start with '/'" in caplog.text + + +def test_collect_and_merge_tiebreaker_secondary_sort_key(tmp_path, monkeypatch): + from mmirage.core.process.batch.collector import _read_metadata_records, collect_and_merge + + metadata_path = tmp_path / "receipts.jsonl" + metadata_path.write_text( + json.dumps( + { + "provider": "openai", + "provider_batch_id": "batch_tie", + "custom_id_to_source_index": {"a": 0, "b": 0, "c": 1}, + } ) + + "\n", + encoding="utf-8", + ) + + output_path = tmp_path / "merged_tie.jsonl" + + class FakeAdapter: + def retrieve_results(self, provider_batch_id, config): + # Return rows intentionally out-of-order to ensure collector sorts + # deterministically using the secondary key. + return [ + {"custom_id": "b", "generated_text": "B"}, + {"custom_id": "a", "generated_text": "A"}, + {"custom_id": "c", "generated_text": "C"}, + ] + + monkeypatch.setattr( + "mmirage.core.process.batch.collector.BatchAdapterFactory.from_config", + lambda config: FakeAdapter(), + ) + + records = _read_metadata_records(str(metadata_path)) + rows = collect_and_merge( + records=records, + provider_configs={"openai": OpenAIBatchConfig(credentials={"api_key": "k"})}, + output_path=str(output_path), + ) + + assert [r["custom_id"] for r in rows] == ["a", "b", "c"] + + +def test_build_output_payload_logs_malformed_json(caplog): + from mmirage.core.process.batch.collector import _build_output_payload + + malformed_json = '{"question": "incomplete' + result_row = { + "custom_id": "row_123", + "generated_text": malformed_json, + } + + with caplog.at_level("WARNING"): + output = _build_output_payload(result_row, custom_id="row_123") + + assert output == {"caption": malformed_json} + assert "Failed to parse JSON for result row" in caplog.text + assert "custom_id=row_123" in caplog.text + assert "Treating as raw text" in caplog.text From 41d59f22b54e08a73f5b2601ba3aa97ca566219a Mon Sep 17 00:00:00 2001 From: legstar67 Date: Tue, 12 May 2026 03:16:53 +0200 Subject: [PATCH 36/45] small modif following copilot's comments --- src/mmirage/core/process/batch/adapter.py | 4 + src/mmirage/core/process/batch/collector.py | 4 +- .../core/process/batch/openai_adapter.py | 6 + tests/test_batch_collector.py | 140 ++++++++++++++++++ 4 files changed, 152 insertions(+), 2 deletions(-) diff --git a/src/mmirage/core/process/batch/adapter.py b/src/mmirage/core/process/batch/adapter.py index ff0126e..aa5894a 100644 --- a/src/mmirage/core/process/batch/adapter.py +++ b/src/mmirage/core/process/batch/adapter.py @@ -137,6 +137,10 @@ def retrieve_results( ) -> Sequence[Dict[str, Any]]: """Download and parse completed batch results from the provider. + Implementations should normalize each returned row into a plain mapping + and, when a text payload is available, expose it as ``generated_text`` + so downstream collectors can consume a provider-agnostic result shape. + Args: provider_batch_id: Provider-side batch/job identifier. config: Provider configuration containing credentials and endpoint diff --git a/src/mmirage/core/process/batch/collector.py b/src/mmirage/core/process/batch/collector.py index d9c7775..fa9682e 100644 --- a/src/mmirage/core/process/batch/collector.py +++ b/src/mmirage/core/process/batch/collector.py @@ -94,7 +94,7 @@ def collect_and_merge( config=provider_configs[provider], ) - indexed_rows: MutableMapping[str, Dict[str, Any]] = {} + indexed_rows: MutableMapping[Tuple[str, str, str], Dict[str, Any]] = {} for pair, mapping in pair_to_mapping.items(): results = pair_to_results.get(pair, []) for result_row in results: @@ -102,7 +102,7 @@ def collect_and_merge( if not custom_id or custom_id not in mapping: continue row_payload = _build_output_payload(result_row, custom_id=custom_id) - indexed_rows[custom_id] = { + indexed_rows[(pair[0], pair[1], custom_id)] = { "source_index": int(mapping[custom_id]), "custom_id": custom_id, **row_payload, diff --git a/src/mmirage/core/process/batch/openai_adapter.py b/src/mmirage/core/process/batch/openai_adapter.py index ea7dd13..0f6cb67 100644 --- a/src/mmirage/core/process/batch/openai_adapter.py +++ b/src/mmirage/core/process/batch/openai_adapter.py @@ -163,6 +163,12 @@ def retrieve_results( provider_batch_id: str, config: BatchProviderConfig, ) -> Sequence[Dict[str, Any]]: + """Download completed OpenAI batch rows and normalize text into ``generated_text``. + + OpenAI batch outputs can surface the assistant payload in nested + response bodies, so this method flattens the provider-specific shape + before returning rows to the provider-agnostic collector. + """ openai_config = self._as_openai_config(config) client = self._create_client(openai_config) diff --git a/tests/test_batch_collector.py b/tests/test_batch_collector.py index a7591b0..f8e16cd 100644 --- a/tests/test_batch_collector.py +++ b/tests/test_batch_collector.py @@ -167,6 +167,146 @@ def retrieve_results(self, provider_batch_id, config): ] +def test_collect_and_merge_keeps_rows_with_duplicate_custom_ids_across_batches( + tmp_path, monkeypatch +): + from mmirage.core.process.batch.collector import _read_metadata_records, collect_and_merge + + metadata_path = tmp_path / "receipts.jsonl" + metadata_path.write_text( + "\n".join( + [ + json.dumps( + { + "provider": "openai", + "provider_batch_id": "batch_openai", + "custom_id_to_source_index": {"shared": 0}, + } + ), + json.dumps( + { + "provider": "unit", + "provider_batch_id": "batch_unit", + "custom_id_to_source_index": {"shared": 1}, + } + ), + ] + ) + + "\n", + encoding="utf-8", + ) + + output_path = tmp_path / "merged_duplicates.jsonl" + + class OpenAIAdapter: + def retrieve_results(self, provider_batch_id, config): + return [{"custom_id": "shared", "generated_text": "openai"}] + + class UnitAdapter: + def retrieve_results(self, provider_batch_id, config): + return [{"custom_id": "shared", "generated_text": "unit"}] + + adapters = { + "openai": OpenAIAdapter(), + "unit": UnitAdapter(), + } + + monkeypatch.setattr( + "mmirage.core.process.batch.collector.BatchAdapterFactory.from_config", + lambda config: adapters[config.provider], + ) + + records = _read_metadata_records(str(metadata_path)) + rows = collect_and_merge( + records=records, + provider_configs={ + "openai": SimpleNamespace(provider="openai"), + "unit": SimpleNamespace(provider="unit"), + }, + output_path=str(output_path), + ) + + assert [row["source_index"] for row in rows] == [0, 1] + assert [row["custom_id"] for row in rows] == ["shared", "shared"] + assert [row["caption"] for row in rows] == ["openai", "unit"] + + written = [json.loads(line) for line in output_path.read_text(encoding="utf-8").splitlines()] + assert written == rows + + +def test_collect_and_merge_uses_openai_adapter_generated_text(tmp_path, monkeypatch): + from mmirage.core.process.batch.collector import _read_metadata_records, collect_and_merge + from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter + + metadata_path = tmp_path / "receipts.jsonl" + metadata_path.write_text( + json.dumps( + { + "provider": "openai", + "provider_batch_id": "batch_openai", + "custom_id_to_source_index": {"o1": 0}, + } + ) + + "\n", + encoding="utf-8", + ) + + output_path = tmp_path / "merged_openai.jsonl" + + class FakeBatches: + def retrieve(self, provider_batch_id): + class _RetrieveResp: + id = provider_batch_id + status = "completed" + output_file_id = "file_output_1" + + return _RetrieveResp() + + class FakeFiles: + def content(self, output_file_id): + class _ContentResp: + text = ( + '{"custom_id":"o1","response":{"body":{"choices":[{"message":{"content":"{\\"question\\":\\"Q\\",\\"answer\\":\\"A\\"}"}}]}}}\n' + ) + + return _ContentResp() + + class FakeClient: + def __init__(self, **kwargs): + self.batches = FakeBatches() + self.files = FakeFiles() + + monkeypatch.setattr( + "mmirage.core.process.batch.openai_adapter.OpenAI", + FakeClient, + ) + monkeypatch.setattr( + "mmirage.core.process.batch.collector.BatchAdapterFactory.from_config", + lambda config: OpenAIBatchAdapter(), + ) + + records = _read_metadata_records(str(metadata_path)) + rows = collect_and_merge( + records=records, + provider_configs={"openai": OpenAIBatchConfig(credentials={"api_key": "k"})}, + output_path=str(output_path), + ) + + assert rows == [ + { + "source_index": 0, + "custom_id": "o1", + "conversations": [ + {"role": "user", "content": "Q"}, + {"role": "assistant", "content": "A"}, + ], + } + ] + + written = [json.loads(line) for line in output_path.read_text(encoding="utf-8").splitlines()] + assert written == rows + + def test_collector_main_uses_config_and_records(tmp_path, monkeypatch): from mmirage.core.process.batch import collector From eaf089c91926a29f1167b15bf3b416215e734a32 Mon Sep 17 00:00:00 2001 From: legstar67 Date: Tue, 12 May 2026 05:02:37 +0200 Subject: [PATCH 37/45] syntax improved according to comments --- src/mmirage/config/batch_provider.py | 21 ++++++++-- src/mmirage/core/process/batch/chunking.py | 49 ++++++++++------------ 2 files changed, 39 insertions(+), 31 deletions(-) diff --git a/src/mmirage/config/batch_provider.py b/src/mmirage/config/batch_provider.py index 12e0305..deb7692 100644 --- a/src/mmirage/config/batch_provider.py +++ b/src/mmirage/config/batch_provider.py @@ -4,10 +4,18 @@ submission provider (OpenAI, Anthropic, etc.). """ +from enum import Enum from dataclasses import dataclass, field from typing import Any, Dict, Literal, Optional +class OversizedRequestPolicy(str, Enum): + """Policy for handling single requests that exceed the chunk byte limit.""" + + ISOLATE = "isolate" + REJECT = "reject" + + @dataclass class BatchRetryPolicy: """Retry behavior used by provider-neutral batch submission orchestration. @@ -63,7 +71,7 @@ class BatchProviderConfig: max_requests_per_chunk: Optional[int] = None metadata_output_path: str = "" retry_policy: BatchRetryPolicy = field(default_factory=BatchRetryPolicy) - oversized_request_policy: Literal["isolate", "reject"] = "isolate" + oversized_request_policy: OversizedRequestPolicy | str = OversizedRequestPolicy.ISOLATE extras: Dict[str, Any] = field(default_factory=dict) credentials: Dict[str, str] = field(default_factory=dict) @@ -76,5 +84,12 @@ def __post_init__(self) -> None: raise ValueError("max_chunk_bytes must be >= 1") if self.max_requests_per_chunk is not None and self.max_requests_per_chunk < 1: raise ValueError("max_requests_per_chunk must be >= 1 when provided") - if self.oversized_request_policy not in {"isolate", "reject"}: - raise ValueError("oversized_request_policy must be either 'isolate' or 'reject'") + if isinstance(self.oversized_request_policy, str): + try: + self.oversized_request_policy = OversizedRequestPolicy( + self.oversized_request_policy.strip().lower() + ) + except ValueError as exc: + raise ValueError( + "oversized_request_policy must be either 'isolate' or 'reject'" + ) from exc diff --git a/src/mmirage/core/process/batch/chunking.py b/src/mmirage/core/process/batch/chunking.py index b490260..b4c6298 100644 --- a/src/mmirage/core/process/batch/chunking.py +++ b/src/mmirage/core/process/batch/chunking.py @@ -2,9 +2,9 @@ import logging from dataclasses import dataclass -from typing import Any, List, Mapping, Sequence +from typing import Any, Dict, List, Mapping, Sequence -from mmirage.config.batch_provider import BatchProviderConfig +from mmirage.config.batch_provider import BatchProviderConfig, OversizedRequestPolicy from mmirage.core.process.batch.adapter import BatchSubmissionAdapter logger = logging.getLogger(__name__) @@ -14,7 +14,7 @@ class RequestChunk: """Chunk of provider-ready requests with aggregate metadata.""" - requests: List[Mapping[str, Any]] + requests: List[Dict[str, Any]] total_bytes: int has_oversized_request: bool = False @@ -30,19 +30,28 @@ def __init__(self, adapter: BatchSubmissionAdapter, config: BatchProviderConfig) self.adapter = adapter self.config = config - def chunk_requests(self, requests: Sequence[Mapping[str, Any]]) -> List[RequestChunk]: + def chunk_requests(self, requests: Sequence[Dict[str, Any]]) -> List[RequestChunk]: """Chunk requests according to max bytes, max requests, and oversize policy.""" chunks: List[RequestChunk] = [] - current_requests: List[Mapping[str, Any]] = [] + current_requests: List[Dict[str, Any]] = [] current_total_bytes = 0 max_chunk_bytes = self.config.max_chunk_bytes + def append_current_chunk() -> None: + if current_requests: + chunks.append( + RequestChunk( + requests=list(current_requests), + total_bytes=current_total_bytes, + ) + ) + for request in requests: request_size = self.adapter.estimate_request_bytes(request) if request_size > max_chunk_bytes: - if self.config.oversized_request_policy == "reject": + if self.config.oversized_request_policy is OversizedRequestPolicy.REJECT: raise ValueError( "Encountered oversized request: " f"{request_size} bytes exceeds max_chunk_bytes={max_chunk_bytes}" @@ -54,15 +63,9 @@ def chunk_requests(self, requests: Sequence[Mapping[str, Any]]) -> List[RequestC max_chunk_bytes, ) - if current_requests: - chunks.append( - RequestChunk( - requests=list(current_requests), - total_bytes=current_total_bytes, - ) - ) - current_requests = [] - current_total_bytes = 0 + append_current_chunk() + current_requests = [] + current_total_bytes = 0 chunks.append( RequestChunk( @@ -77,12 +80,7 @@ def chunk_requests(self, requests: Sequence[Mapping[str, Any]]) -> List[RequestC would_exceed_count = self._would_exceed_count_limit(current_requests) if current_requests and (would_exceed_bytes or would_exceed_count): - chunks.append( - RequestChunk( - requests=list(current_requests), - total_bytes=current_total_bytes, - ) - ) + append_current_chunk() current_requests = [] current_total_bytes = 0 @@ -90,16 +88,11 @@ def chunk_requests(self, requests: Sequence[Mapping[str, Any]]) -> List[RequestC current_total_bytes += request_size if current_requests: - chunks.append( - RequestChunk( - requests=list(current_requests), - total_bytes=current_total_bytes, - ) - ) + append_current_chunk() return chunks - def _would_exceed_count_limit(self, current_requests: Sequence[Mapping[str, Any]]) -> bool: + def _would_exceed_count_limit(self, current_requests: Sequence[Dict[str, Any]]) -> bool: if self.config.max_requests_per_chunk is None: return False return len(current_requests) >= self.config.max_requests_per_chunk From 91be3deb58560dd52b3e98352d715661e45723be Mon Sep 17 00:00:00 2001 From: legstar67 Date: Tue, 12 May 2026 05:42:29 +0200 Subject: [PATCH 38/45] correction of use of a dict instead of the config class --- src/mmirage/config/utils.py | 6 ++++++ .../core/process/batch/provider_resolution.py | 11 +++++++++- .../core/process/processors/llm/config.py | 3 ++- .../process/processors/llm/llm_processor.py | 16 +++++---------- tests/test_batch_orchestrator.py | 12 +++++------ tests/test_integration_batch_pipeline.py | 20 +++++++++---------- 6 files changed, 39 insertions(+), 29 deletions(-) diff --git a/src/mmirage/config/utils.py b/src/mmirage/config/utils.py index 77e8bcb..8fda777 100644 --- a/src/mmirage/config/utils.py +++ b/src/mmirage/config/utils.py @@ -5,8 +5,10 @@ import yaml import os +from mmirage.config.batch_provider import BatchProviderConfig from mmirage.config.config import MMirageConfig from mmirage.core.process.base import BaseProcessorConfig, ProcessorRegistry, OutputVar +from mmirage.core.process.batch.provider_resolution import resolve_single_provider_config from mmirage.core.loader.base import BaseDataLoaderConfig, DataLoaderRegistry EnvValue: TypeAlias = Union[str, List["EnvValue"], Dict[str, "EnvValue"]] @@ -102,12 +104,16 @@ def output_var_hook(data: Dict[str, Any]) -> OutputVar: clz = ProcessorRegistry.get_output_var_cls(data["type"]) return from_dict(clz, data, config=config) + def batch_provider_hook(data: Dict[str, Any]) -> BatchProviderConfig: + return resolve_single_provider_config(data) + cfg = expand_env_vars(cfg) config = Config( type_hooks={ BaseProcessorConfig: processor_config_hook, BaseDataLoaderConfig: loader_config_hook, OutputVar: output_var_hook, + BatchProviderConfig: batch_provider_hook, } ) cfg_obj = from_dict(MMirageConfig, cast(dict, cfg), config=config) diff --git a/src/mmirage/core/process/batch/provider_resolution.py b/src/mmirage/core/process/batch/provider_resolution.py index 5bfb49c..d28042e 100644 --- a/src/mmirage/core/process/batch/provider_resolution.py +++ b/src/mmirage/core/process/batch/provider_resolution.py @@ -6,6 +6,7 @@ from __future__ import annotations +from dataclasses import asdict from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Sequence, Type from mmirage.config.batch_provider import BatchProviderConfig @@ -86,7 +87,15 @@ def _extract_batch_provider_blocks(cfg: MMirageConfig) -> Dict[str, Dict[str, An """ provider_blocks: Dict[str, Dict[str, Any]] = {} for processor_cfg in cfg.processors: - raw_block = dict(getattr(processor_cfg, "batch_provider", {}) or {}) + raw_provider = getattr(processor_cfg, "batch_provider", None) + if raw_provider is None: + continue + + if isinstance(raw_provider, BatchProviderConfig): + raw_block = asdict(raw_provider) + else: + raw_block = dict(raw_provider or {}) + if not raw_block: continue diff --git a/src/mmirage/core/process/processors/llm/config.py b/src/mmirage/core/process/processors/llm/config.py index 3a1efbf..0a96670 100644 --- a/src/mmirage/core/process/processors/llm/config.py +++ b/src/mmirage/core/process/processors/llm/config.py @@ -7,6 +7,7 @@ from typing import Dict, Optional, Sequence, Type, Any, List from pydantic import BaseModel, create_model +from mmirage.config.batch_provider import BatchProviderConfig from mmirage.core.process.variables import BaseVar, OutputVar from mmirage.core.process.base import BaseProcessorConfig @@ -82,7 +83,7 @@ class SGLangLLMConfig(BaseProcessorConfig): server_args: SGLangServerArgs = field(default_factory=SGLangServerArgs) default_sampling_params: Dict[str, Any] = field(default_factory=dict) chat_template: str = "" # Empty means use tokenizer's default - batch_provider: Dict[str, Any] = field(default_factory=dict) + batch_provider: Optional[BatchProviderConfig] = None @dataclass diff --git a/src/mmirage/core/process/processors/llm/llm_processor.py b/src/mmirage/core/process/processors/llm/llm_processor.py index 0a29600..365c318 100644 --- a/src/mmirage/core/process/processors/llm/llm_processor.py +++ b/src/mmirage/core/process/processors/llm/llm_processor.py @@ -14,7 +14,6 @@ from mmirage.core.process.base import BaseProcessor, ProcessorRegistry from mmirage.core.process.batch.orchestrator import BatchSubmissionOrchestrator -from mmirage.core.process.batch.provider_resolution import resolve_single_provider_config from mmirage.core.process.batch.registry import BatchAdapterFactory from mmirage.core.process.processors.llm.config import LLMOutputVar, SGLangLLMConfig from mmirage.core.process.variables import VariableEnvironment @@ -63,12 +62,8 @@ def __init__(self, engine_args: SGLangLLMConfig, **kwargs) -> None: """ super().__init__(engine_args, **kwargs) - batch_provider_attrs = getattr(engine_args, "batch_provider", None) - if batch_provider_attrs is None: - is_provider_batch_enabled = False - else: - provider_cfg_raw = dict(batch_provider_attrs) - is_provider_batch_enabled = bool(provider_cfg_raw.get("enabled", True)) + batch_provider_cfg = getattr(engine_args, "batch_provider", None) + is_provider_batch_enabled = bool(batch_provider_cfg and batch_provider_cfg.enabled) # In provider-batch mode we only build payloads/metadata and should not # initialize GPU-backed SGLang runtime. @@ -93,14 +88,13 @@ def __init__(self, engine_args: SGLangLLMConfig, **kwargs) -> None: self._setup_batch_runtime() def _setup_batch_runtime(self) -> None: - provider_cfg_raw = dict(getattr(self.config, "batch_provider", {}) or {}) - if not provider_cfg_raw: + provider_cfg = getattr(self.config, "batch_provider", None) + if provider_cfg is None: return - if not provider_cfg_raw.get("enabled", True): + if not provider_cfg.enabled: return - provider_cfg = resolve_single_provider_config(provider_cfg_raw) self._batch_provider_config = provider_cfg self._batch_adapter = BatchAdapterFactory.from_config(provider_cfg) run_id = uuid.uuid4().hex[:6] diff --git a/tests/test_batch_orchestrator.py b/tests/test_batch_orchestrator.py index 2629e27..3ef6808 100644 --- a/tests/test_batch_orchestrator.py +++ b/tests/test_batch_orchestrator.py @@ -146,11 +146,11 @@ def test_llm_processor_initializes_with_custom_provider(tmp_path): config = SGLangLLMConfig( type="llm", server_args=SGLangServerArgs(model_path="dummy-model"), - batch_provider={ - "provider": "unit", - "unit_setting": "custom", - "metadata_output_path": str(tmp_path / "metadata.jsonl"), - }, + batch_provider=UnitBatchConfig( + provider="unit", + unit_setting="custom", + metadata_output_path=str(tmp_path / "metadata.jsonl"), + ), ) processor_cls = ProcessorRegistry.get_processor("llm") @@ -190,7 +190,7 @@ def apply_chat_template(self, *args, **kwargs): config = SGLangLLMConfig( type="llm", server_args=SGLangServerArgs(model_path="dummy-model"), - batch_provider={"enabled": False}, + batch_provider=BatchProviderConfig(provider="openai", enabled=False), ) processor_cls = ProcessorRegistry.get_processor("llm") diff --git a/tests/test_integration_batch_pipeline.py b/tests/test_integration_batch_pipeline.py index 77c73a0..6928aa7 100644 --- a/tests/test_integration_batch_pipeline.py +++ b/tests/test_integration_batch_pipeline.py @@ -3,6 +3,7 @@ from datasets import load_dataset +from mmirage.config.openai_batch import OpenAIBatchConfig from mmirage.core.process import LLMProcessor # Ensures processor registration. from mmirage.core.process.mapper import MMIRAGEMapper from mmirage.core.process.processors.llm.config import LLMOutputVar, SGLangLLMConfig, SGLangServerArgs @@ -83,16 +84,15 @@ def apply_chat_template(self, user_prompt, tokenize=False, add_generation_prompt llm_cfg = SGLangLLMConfig( type="llm", server_args=SGLangServerArgs(model_path="dummy-model"), - batch_provider={ - "enabled": True, - "provider": "openai", - "model": "gpt-4.1-mini", - "max_chunk_bytes": 500, - "max_requests_per_chunk": None, - "metadata_output_path": str(metadata_base), - "credentials": {"api_key": "test-key"}, - "metadata": {"pipeline": "integration-test"}, - }, + batch_provider=OpenAIBatchConfig( + enabled=True, + model="gpt-4.1-mini", + max_chunk_bytes=500, + max_requests_per_chunk=None, + metadata_output_path=str(metadata_base), + credentials={"api_key": "test-key"}, + metadata={"pipeline": "integration-test"}, + ), ) mapper = MMIRAGEMapper( From 08a1dfa83f36f05371ee1274f3c60d661f1ed9f4 Mon Sep 17 00:00:00 2001 From: legstar67 Date: Tue, 12 May 2026 06:47:49 +0200 Subject: [PATCH 39/45] cleaning --- src/mmirage/core/process/batch/adapter.py | 2 -- src/mmirage/core/process/batch/openai_adapter.py | 3 +-- src/mmirage/core/process/batch/orchestrator.py | 10 +++++----- tests/test_batch_adapter_contracts.py | 4 ++-- tests/test_batch_chunking.py | 2 +- tests/test_batch_orchestrator.py | 2 +- tests/test_openai_batch_adapter.py | 2 +- 7 files changed, 11 insertions(+), 14 deletions(-) diff --git a/src/mmirage/core/process/batch/adapter.py b/src/mmirage/core/process/batch/adapter.py index aa5894a..614dfc3 100644 --- a/src/mmirage/core/process/batch/adapter.py +++ b/src/mmirage/core/process/batch/adapter.py @@ -96,13 +96,11 @@ def submit_chunk( def parse_submission_result( self, raw_result: Mapping[str, Any], - request_count: int, ) -> BatchSubmissionResult: """Normalize provider submission output into a shared result model. Args: raw_result: Raw payload returned by ``submit_chunk``. - request_count: Number of requests submitted in the chunk. Returns: A normalized ``BatchSubmissionResult`` for provider-neutral diff --git a/src/mmirage/core/process/batch/openai_adapter.py b/src/mmirage/core/process/batch/openai_adapter.py index 0f6cb67..7e221b5 100644 --- a/src/mmirage/core/process/batch/openai_adapter.py +++ b/src/mmirage/core/process/batch/openai_adapter.py @@ -156,7 +156,7 @@ def check_batch_status( "id": self._read_attr(retrieved, "id"), "status": self._read_attr(retrieved, "status"), } - return self.parse_submission_result(raw_result=raw_result, request_count=0) + return self.parse_submission_result(raw_result=raw_result) def retrieve_results( self, @@ -201,7 +201,6 @@ def retrieve_results( def parse_submission_result( self, raw_result: Mapping[str, Any], - request_count: int, ) -> BatchSubmissionResult: batch_id = str(raw_result.get("id") or raw_result.get("batch_id") or "") status = str(raw_result.get("status") or "unknown") diff --git a/src/mmirage/core/process/batch/orchestrator.py b/src/mmirage/core/process/batch/orchestrator.py index 6d1dbfd..6017c49 100644 --- a/src/mmirage/core/process/batch/orchestrator.py +++ b/src/mmirage/core/process/batch/orchestrator.py @@ -17,7 +17,7 @@ @dataclass class _PendingRequest: request: Mapping[str, Any] - source_index: int + source_index: int # original row index of the data sample within the input dataset class BatchSubmissionOrchestrator: @@ -48,8 +48,8 @@ def add_requests( self._pending.append(_PendingRequest(request=request, source_index=source_index)) return self._emit_ready_chunks( - finalize=False, model_params_snapshot=model_params_snapshot, + finalize=False, ) def finalize( @@ -58,14 +58,15 @@ def finalize( ) -> List[BatchSubmissionResult]: """Flush all remaining requests at end-of-dataset lifecycle.""" return self._emit_ready_chunks( - finalize=True, model_params_snapshot=model_params_snapshot, + finalize=True, + ) def _emit_ready_chunks( self, - finalize: bool, model_params_snapshot: Optional[Mapping[str, Any]], + finalize: bool = False, ) -> List[BatchSubmissionResult]: if not self._pending: return [] @@ -99,7 +100,6 @@ def _emit_ready_chunks( ) parsed_result = self.adapter.parse_submission_result( raw_result=raw_result, - request_count=len(chunk_entries), ) submission_results.append(parsed_result) diff --git a/tests/test_batch_adapter_contracts.py b/tests/test_batch_adapter_contracts.py index 70887ab..85890ef 100644 --- a/tests/test_batch_adapter_contracts.py +++ b/tests/test_batch_adapter_contracts.py @@ -37,7 +37,7 @@ def submit_chunk(self, chunk_id, requests, config): "requests": len(requests), } - def parse_submission_result(self, raw_result, request_count): + def parse_submission_result(self, raw_result): return BatchSubmissionResult( provider_batch_id=str(raw_result["batch_id"]), status=str(raw_result["status"]), @@ -113,7 +113,7 @@ def test_complete_adapter_is_interface_compliant(): assert estimated_bytes > 0 raw_result = adapter.submit_chunk(chunk_id="chunk-1", requests=[request], config=config) - parsed = adapter.parse_submission_result(raw_result=raw_result, request_count=1) + parsed = adapter.parse_submission_result(raw_result=raw_result) assert parsed.provider_batch_id == "unit-chunk-1" assert parsed.status == "submitted" diff --git a/tests/test_batch_chunking.py b/tests/test_batch_chunking.py index 6451362..fcceb8c 100644 --- a/tests/test_batch_chunking.py +++ b/tests/test_batch_chunking.py @@ -27,7 +27,7 @@ def estimate_request_bytes(self, request): def submit_chunk(self, chunk_id, requests, config): return {"id": chunk_id, "status": "submitted"} - def parse_submission_result(self, raw_result, request_count): + def parse_submission_result(self, raw_result): return BatchSubmissionResult( provider_batch_id=str(raw_result["id"]), status=str(raw_result["status"]), diff --git a/tests/test_batch_orchestrator.py b/tests/test_batch_orchestrator.py index 3ef6808..0b16f9b 100644 --- a/tests/test_batch_orchestrator.py +++ b/tests/test_batch_orchestrator.py @@ -30,7 +30,7 @@ def submit_chunk(self, chunk_id, requests, config): ) return {"id": f"batch-{chunk_id}", "status": "submitted"} - def parse_submission_result(self, raw_result, request_count): + def parse_submission_result(self, raw_result): return BatchSubmissionResult( provider_batch_id=str(raw_result["id"]), status=str(raw_result["status"]), diff --git a/tests/test_openai_batch_adapter.py b/tests/test_openai_batch_adapter.py index 06356d3..a03cc2c 100644 --- a/tests/test_openai_batch_adapter.py +++ b/tests/test_openai_batch_adapter.py @@ -194,7 +194,7 @@ def test_openai_parse_submission_result_normalizes_payload(): "input_file_id": "file_123", } - result = adapter.parse_submission_result(raw_result=raw, request_count=4) + result = adapter.parse_submission_result(raw_result=raw) assert isinstance(result, BatchSubmissionResult) assert result.provider_batch_id == "batch_123" From 69d01f16d36b46ebab821127b77babd341124f5c Mon Sep 17 00:00:00 2001 From: legstar67 Date: Tue, 12 May 2026 08:08:45 +0200 Subject: [PATCH 40/45] improve of openai adapter following the comments --- .../core/process/batch/openai_adapter.py | 81 ++++++++++++------- .../core/process/batch/orchestrator.py | 1 - 2 files changed, 50 insertions(+), 32 deletions(-) diff --git a/src/mmirage/core/process/batch/openai_adapter.py b/src/mmirage/core/process/batch/openai_adapter.py index 7e221b5..ff7a45d 100644 --- a/src/mmirage/core/process/batch/openai_adapter.py +++ b/src/mmirage/core/process/batch/openai_adapter.py @@ -8,7 +8,7 @@ import os from typing import Any, Dict, List, Mapping, Sequence -from openai import OpenAI +from openai import AuthenticationError, OpenAI from mmirage.config.batch_provider import BatchProviderConfig from mmirage.config.openai_batch import OpenAIBatchConfig @@ -26,9 +26,17 @@ def build_request( payload: Dict[str, Any], config: BatchProviderConfig, ) -> Mapping[str, Any]: - openai_config = self._as_openai_config(config) + openai_config = self._check_openai_config(config) body = copy.deepcopy(payload) - expected_schema = body.pop("expected_schema", None) + expected_schema = body.get("expected_schema") + if expected_schema is not None and ( + not isinstance(expected_schema, list) + or not all(isinstance(key, str) for key in expected_schema) + ): + raise ValueError( + "expected_schema must be a list of strings, " + f"got {type(expected_schema).__name__}" + ) body.setdefault("model", openai_config.model) self._convert_local_images_to_data_uris(body) @@ -56,7 +64,7 @@ def build_request( } @staticmethod - def _convert_local_images_to_data_uris(body: Mapping[str, Any]) -> None: + def _convert_local_images_to_data_uris(body: Dict[str, Any]) -> None: messages = body.get("messages") if not isinstance(messages, list): return @@ -70,9 +78,7 @@ def _convert_local_images_to_data_uris(body: Mapping[str, Any]) -> None: continue for part in content: - if not isinstance(part, dict): - continue - if part.get("type") != "image_url": + if not isinstance(part, dict) or part.get("type") != "image_url": continue image_url = part.get("image_url") @@ -110,9 +116,8 @@ def submit_chunk( chunk_id: str, requests: Sequence[Mapping[str, Any]], config: BatchProviderConfig, - ) -> Mapping[str, Any]: - openai_config = self._as_openai_config(config) - + ) -> Dict[str, Any]: + openai_config = self._check_openai_config(config) client = self._create_client(openai_config) jsonl_lines = [ @@ -148,7 +153,7 @@ def check_batch_status( provider_batch_id: str, config: BatchProviderConfig, ) -> BatchSubmissionResult: - openai_config = self._as_openai_config(config) + openai_config = self._check_openai_config(config) client = self._create_client(openai_config) retrieved = client.batches.retrieve(provider_batch_id) @@ -169,7 +174,7 @@ def retrieve_results( response bodies, so this method flattens the provider-specific shape before returning rows to the provider-agnostic collector. """ - openai_config = self._as_openai_config(config) + openai_config = self._check_openai_config(config) client = self._create_client(openai_config) retrieved = client.batches.retrieve(provider_batch_id) @@ -203,16 +208,21 @@ def parse_submission_result( raw_result: Mapping[str, Any], ) -> BatchSubmissionResult: batch_id = str(raw_result.get("id") or raw_result.get("batch_id") or "") - status = str(raw_result.get("status") or "unknown") + status = raw_result.get("status", "unknown") return BatchSubmissionResult( provider_batch_id=batch_id, status=status, - raw_response=dict(raw_result), + raw_response=raw_result, ) @staticmethod - def _as_openai_config(config: BatchProviderConfig) -> OpenAIBatchConfig: + def _check_openai_config(config: BatchProviderConfig) -> OpenAIBatchConfig: + """Validate that `config` is an `OpenAIBatchConfig` and return it. + + Raises `TypeError` when the provided `config` is not an + `OpenAIBatchConfig`. + """ if isinstance(config, OpenAIBatchConfig): return config raise TypeError("OpenAIBatchAdapter requires OpenAIBatchConfig") @@ -228,17 +238,21 @@ def _extract_generated_text(row: Mapping[str, Any]) -> str: return "" choices = body.get("choices") - if isinstance(choices, list) and choices: - first = choices[0] - if isinstance(first, Mapping): - message = first.get("message") - if isinstance(message, Mapping): - content = message.get("content") - if isinstance(content, str): - return content - text = first.get("text") - if isinstance(text, str): - return text + if not isinstance(choices, list) or not choices: + text = body.get("text") + return text if isinstance(text, str) else "" + + first = choices[0] + if isinstance(first, Mapping): + message = first.get("message") + if isinstance(message, Mapping): + content = message.get("content") + if isinstance(content, str): + return content + + text = first.get("text") + if isinstance(text, str): + return text text = body.get("text") if isinstance(text, str): @@ -248,7 +262,7 @@ def _extract_generated_text(row: Mapping[str, Any]) -> str: @staticmethod def _create_client(config: OpenAIBatchConfig) -> OpenAI: - api_key = (config.credentials.get("api_key", "") or "").strip() + api_key = config.credentials.get("api_key", "").strip() if not api_key: api_key = os.environ.get("OPENAI_API_KEY", "").strip() @@ -257,10 +271,15 @@ def _create_client(config: OpenAIBatchConfig) -> OpenAI: "OpenAI API key is missing. Provide credentials.api_key or set OPENAI_API_KEY." ) - client_kwargs = {"api_key": api_key} - if config.base_url: - client_kwargs["base_url"] = config.base_url - return OpenAI(**client_kwargs) + try: + client_kwargs = {"api_key": api_key} + if config.base_url: + client_kwargs["base_url"] = config.base_url + return OpenAI(**client_kwargs) + except AuthenticationError as exc: + raise ValueError(f"OpenAI authentication failed: {exc}") from exc + except Exception as exc: + raise ValueError(f"Failed to create OpenAI client: {exc}") from exc @staticmethod def _extract_content_text(content_response: Any) -> str: diff --git a/src/mmirage/core/process/batch/orchestrator.py b/src/mmirage/core/process/batch/orchestrator.py index 6017c49..60ad6d4 100644 --- a/src/mmirage/core/process/batch/orchestrator.py +++ b/src/mmirage/core/process/batch/orchestrator.py @@ -60,7 +60,6 @@ def finalize( return self._emit_ready_chunks( model_params_snapshot=model_params_snapshot, finalize=True, - ) def _emit_ready_chunks( From e81d1a3e7d390ec9fc24c6aac0b531e7916bb9a5 Mon Sep 17 00:00:00 2001 From: legstar67 Date: Tue, 12 May 2026 08:33:20 +0200 Subject: [PATCH 41/45] improvement of openai adapter following the comments --- src/mmirage/core/process/batch/collector.py | 3 + .../core/process/batch/openai_adapter.py | 111 ++++++++---------- 2 files changed, 54 insertions(+), 60 deletions(-) diff --git a/src/mmirage/core/process/batch/collector.py b/src/mmirage/core/process/batch/collector.py index fa9682e..c8929bd 100644 --- a/src/mmirage/core/process/batch/collector.py +++ b/src/mmirage/core/process/batch/collector.py @@ -230,6 +230,9 @@ def main(argv: Sequence[str] | None = None) -> int: rows = collect_and_merge(records, provider_configs, args.output_path) print(f"Merged {len(rows)} rows and saved to {args.output_path}") + except ValueError as exc: + logger.error(str(exc)) + return 1 except Exception as exc: logger.exception("Collector failed") return 1 diff --git a/src/mmirage/core/process/batch/openai_adapter.py b/src/mmirage/core/process/batch/openai_adapter.py index ff7a45d..a490e24 100644 --- a/src/mmirage/core/process/batch/openai_adapter.py +++ b/src/mmirage/core/process/batch/openai_adapter.py @@ -6,6 +6,7 @@ import json import mimetypes import os +import logging from typing import Any, Dict, List, Mapping, Sequence from openai import AuthenticationError, OpenAI @@ -14,6 +15,8 @@ from mmirage.config.openai_batch import OpenAIBatchConfig from mmirage.core.process.batch.adapter import BatchSubmissionAdapter, BatchSubmissionResult +logger = logging.getLogger(__name__) + class OpenAIBatchAdapter(BatchSubmissionAdapter): """Provider adapter for OpenAI Batch API.""" @@ -65,36 +68,23 @@ def build_request( @staticmethod def _convert_local_images_to_data_uris(body: Dict[str, Any]) -> None: - messages = body.get("messages") - if not isinstance(messages, list): - return - - for message in messages: - if not isinstance(message, dict): - continue - - content = message.get("content") - if not isinstance(content, list): - continue - - for part in content: - if not isinstance(part, dict) or part.get("type") != "image_url": - continue - - image_url = part.get("image_url") - if not isinstance(image_url, dict): - continue - - url = image_url.get("url") - if not isinstance(url, str): - continue - - # Keep remote/data URLs untouched. - if url.startswith("http://") or url.startswith("https://") or url.startswith("data:"): - continue - - if os.path.exists(url): - image_url["url"] = OpenAIBatchAdapter._local_file_to_data_uri(url) + # If the payload shape is different, swallow the exception and leave the body untouched. + try: + for message in body["messages"]: + for part in message["content"]: + if part.get("type") != "image_url": + continue + url = part["image_url"]["url"] + if not isinstance(url, str): + continue + # Keep remote/data URLs untouched. + if url.startswith("http://") or url.startswith("https://") or url.startswith("data:"): + continue + if os.path.exists(url): + part["image_url"]["url"] = OpenAIBatchAdapter._local_file_to_data_uri(url) + except (KeyError, IndexError, TypeError, AttributeError): + # Ignore malformed shapes. + pass @staticmethod def _local_file_to_data_uri(path: str) -> str: @@ -178,13 +168,14 @@ def retrieve_results( client = self._create_client(openai_config) retrieved = client.batches.retrieve(provider_batch_id) - status = str(self._read_attr(retrieved, "status") or "unknown") + status = self._read_attr(retrieved, "status") or "unknown" output_file_id = self._read_attr(retrieved, "output_file_id") if status != "completed" or not output_file_id: raise ValueError( - f"Batch '{provider_batch_id}' is not completed or has no output file (status={status})." - ) + f"Batch '{provider_batch_id}' is not completed yet (status={status}). " + "Please retry after the provider marks it completed and produces an output file." + ) from None content_response = client.files.content(output_file_id) jsonl_text = self._extract_content_text(content_response) @@ -207,7 +198,7 @@ def parse_submission_result( self, raw_result: Mapping[str, Any], ) -> BatchSubmissionResult: - batch_id = str(raw_result.get("id") or raw_result.get("batch_id") or "") + batch_id = str(raw_result.get("id") or raw_result.get("batch_id", "")) status = raw_result.get("status", "unknown") return BatchSubmissionResult( @@ -229,34 +220,28 @@ def _check_openai_config(config: BatchProviderConfig) -> OpenAIBatchConfig: @staticmethod def _extract_generated_text(row: Mapping[str, Any]) -> str: - response = row.get("response") - if not isinstance(response, Mapping): - return "" - - body = response.get("body") - if not isinstance(body, Mapping): - return "" - - choices = body.get("choices") - if not isinstance(choices, list) or not choices: - text = body.get("text") - return text if isinstance(text, str) else "" - - first = choices[0] - if isinstance(first, Mapping): - message = first.get("message") - if isinstance(message, Mapping): - content = message.get("content") - if isinstance(content, str): - return content - - text = first.get("text") + # prefer chat `message.content`, then `choices[0].text`, + # then `body.text`. Return empty string if none match. + try: + content = row["response"]["body"]["choices"][0]["message"]["content"] + if isinstance(content, str): + return content + except (KeyError, IndexError, TypeError): + pass + + try: + text = row["response"]["body"]["choices"][0]["text"] if isinstance(text, str): return text + except (KeyError, IndexError, TypeError): + pass - text = body.get("text") - if isinstance(text, str): - return text + try: + body_text = row["response"]["body"]["text"] + if isinstance(body_text, str): + return body_text + except (KeyError, TypeError): + pass return "" @@ -301,7 +286,13 @@ def _extract_content_text(content_response: Any) -> str: if isinstance(content, str): return content - raise ValueError("Unable to parse OpenAI files.content response payload.") + logger.debug( + "Unable to extract content from OpenAI files.content response; tried 'text', 'read()', and 'content' on %s", + type(content_response), + ) + raise ValueError( + "Unable to parse OpenAI files.content response: expected 'text' attribute, 'read()' method, or 'content' attribute" + ) @staticmethod def _read_attr(obj: Any, key: str) -> Any: From c32d63351873308fd4de9ca8e995394569423cc5 Mon Sep 17 00:00:00 2001 From: legstar67 Date: Tue, 12 May 2026 08:58:01 +0200 Subject: [PATCH 42/45] improvement of openai adapter following the comments --- .../core/process/batch/openai_adapter.py | 74 +++++++++---------- 1 file changed, 33 insertions(+), 41 deletions(-) diff --git a/src/mmirage/core/process/batch/openai_adapter.py b/src/mmirage/core/process/batch/openai_adapter.py index a490e24..8df7e8d 100644 --- a/src/mmirage/core/process/batch/openai_adapter.py +++ b/src/mmirage/core/process/batch/openai_adapter.py @@ -124,17 +124,17 @@ def submit_chunk( metadata["chunk_id"] = chunk_id batch_response = client.batches.create( - input_file_id=self._read_attr(file_response, "id"), + input_file_id=file_response.id, endpoint=openai_config.batch_endpoint, completion_window=openai_config.completion_window, metadata=metadata, ) return { - "id": self._read_attr(batch_response, "id"), - "status": self._read_attr(batch_response, "status"), - "endpoint": self._read_attr(batch_response, "endpoint"), - "input_file_id": self._read_attr(file_response, "id"), + "id": batch_response.id, + "status": getattr(batch_response, "status", None), + "endpoint": getattr(batch_response, "endpoint", None), + "input_file_id": file_response.id, "chunk_id": chunk_id, } @@ -145,13 +145,8 @@ def check_batch_status( ) -> BatchSubmissionResult: openai_config = self._check_openai_config(config) client = self._create_client(openai_config) - retrieved = client.batches.retrieve(provider_batch_id) - raw_result = { - "id": self._read_attr(retrieved, "id"), - "status": self._read_attr(retrieved, "status"), - } - return self.parse_submission_result(raw_result=raw_result) + return self.parse_submission_result(raw_result=retrieved) def retrieve_results( self, @@ -168,8 +163,8 @@ def retrieve_results( client = self._create_client(openai_config) retrieved = client.batches.retrieve(provider_batch_id) - status = self._read_attr(retrieved, "status") or "unknown" - output_file_id = self._read_attr(retrieved, "output_file_id") + status = getattr(retrieved, "status", None) or "unknown" + output_file_id = getattr(retrieved, "output_file_id", None) if status != "completed" or not output_file_id: raise ValueError( @@ -198,8 +193,20 @@ def parse_submission_result( self, raw_result: Mapping[str, Any], ) -> BatchSubmissionResult: - batch_id = str(raw_result.get("id") or raw_result.get("batch_id", "")) - status = raw_result.get("status", "unknown") + # Prefer attribute access for OpenAI SDK objects, fall back to mapping access. + def _attr_or_get(obj: Any, attr: str, default: Any = None) -> Any: + try: + val = getattr(obj, attr) + except Exception: + val = None + if val is not None: + return val + if isinstance(obj, Mapping): + return obj.get(attr, default) + return default + + batch_id = str(_attr_or_get(raw_result, "id") or _attr_or_get(raw_result, "batch_id", "")) + status = _attr_or_get(raw_result, "status", "unknown") return BatchSubmissionResult( provider_batch_id=batch_id, @@ -247,9 +254,7 @@ def _extract_generated_text(row: Mapping[str, Any]) -> str: @staticmethod def _create_client(config: OpenAIBatchConfig) -> OpenAI: - api_key = config.credentials.get("api_key", "").strip() - if not api_key: - api_key = os.environ.get("OPENAI_API_KEY", "").strip() + api_key = (config.credentials.get("api_key", "").strip() or os.environ.get("OPENAI_API_KEY", "").strip() ) if not api_key: raise ValueError( @@ -268,34 +273,21 @@ def _create_client(config: OpenAIBatchConfig) -> OpenAI: @staticmethod def _extract_content_text(content_response: Any) -> str: - text = getattr(content_response, "text", None) + # Assume `content_response` is an httpx.Response (OpenAI SDK v1). + # Prefer `.text`, fallback to `.content` bytes decode. + try: + text = content_response.text + except Exception: + text = None + if isinstance(text, str): return text - read = getattr(content_response, "read", None) - if callable(read): - data = read() - if isinstance(data, bytes): - return data.decode("utf-8") - if isinstance(data, str): - return data - content = getattr(content_response, "content", None) if isinstance(content, bytes): return content.decode("utf-8") - if isinstance(content, str): - return content - logger.debug( - "Unable to extract content from OpenAI files.content response; tried 'text', 'read()', and 'content' on %s", - type(content_response), - ) - raise ValueError( - "Unable to parse OpenAI files.content response: expected 'text' attribute, 'read()' method, or 'content' attribute" - ) + logger.debug("Unable to extract content from response of type %s", type(content_response)) + raise ValueError("Unable to parse OpenAI files.content response: missing text or content bytes") - @staticmethod - def _read_attr(obj: Any, key: str) -> Any: - if isinstance(obj, Mapping): - return obj.get(key) - return getattr(obj, key) + # _read_attr removed: code now expects OpenAI SDK v1 response objects with attributes. From 0298f7675661f659af690cb065dfef12e3d5f14b Mon Sep 17 00:00:00 2001 From: legstar67 Date: Tue, 12 May 2026 10:42:14 +0200 Subject: [PATCH 43/45] syntax improved following the comments + test recommended by a copilot comment --- src/mmirage/core/process/batch/adapter.py | 14 +++---- .../core/process/batch/openai_adapter.py | 10 ++--- .../process/processors/llm/llm_processor.py | 19 ++++----- tests/test_batch_orchestrator.py | 39 +++++++++++++++++++ 4 files changed, 61 insertions(+), 21 deletions(-) diff --git a/src/mmirage/core/process/batch/adapter.py b/src/mmirage/core/process/batch/adapter.py index 614dfc3..69479a7 100644 --- a/src/mmirage/core/process/batch/adapter.py +++ b/src/mmirage/core/process/batch/adapter.py @@ -6,7 +6,7 @@ import abc from dataclasses import dataclass, field -from typing import Any, Dict, Mapping, Sequence, Tuple +from typing import Any, Dict, Sequence, Tuple from mmirage.config.batch_provider import BatchProviderConfig @@ -23,7 +23,7 @@ class BatchSubmissionResult: provider_batch_id: str status: str - raw_response: Mapping[str, Any] = field(default_factory=dict) + raw_response: Dict[str, Any] = field(default_factory=dict) class BatchSubmissionAdapter(abc.ABC): @@ -41,7 +41,7 @@ def build_request( custom_id: str, payload: Dict[str, Any], config: BatchProviderConfig, - ) -> Mapping[str, Any]: + ) -> Dict[str, Any]: """Build a single provider-ready request object. Args: @@ -58,7 +58,7 @@ def build_request( raise NotImplementedError() @abc.abstractmethod - def estimate_request_bytes(self, request: Mapping[str, Any]) -> int: + def estimate_request_bytes(self, request: Dict[str, Any]) -> int: """Estimate serialized UTF-8 bytes for a request payload. The estimate must match or safely upper-bound the size produced by the @@ -77,9 +77,9 @@ def estimate_request_bytes(self, request: Mapping[str, Any]) -> int: def submit_chunk( self, chunk_id: str, - requests: Sequence[Mapping[str, Any]], + requests: Sequence[Dict[str, Any]], config: BatchProviderConfig, - ) -> Mapping[str, Any]: + ) -> Dict[str, Any]: """Submit one pre-chunked request group to the provider. Args: @@ -95,7 +95,7 @@ def submit_chunk( @abc.abstractmethod def parse_submission_result( self, - raw_result: Mapping[str, Any], + raw_result: Dict[str, Any], ) -> BatchSubmissionResult: """Normalize provider submission output into a shared result model. diff --git a/src/mmirage/core/process/batch/openai_adapter.py b/src/mmirage/core/process/batch/openai_adapter.py index 8df7e8d..8b28cb8 100644 --- a/src/mmirage/core/process/batch/openai_adapter.py +++ b/src/mmirage/core/process/batch/openai_adapter.py @@ -28,7 +28,7 @@ def build_request( custom_id: str, payload: Dict[str, Any], config: BatchProviderConfig, - ) -> Mapping[str, Any]: + ) -> Dict[str, Any]: openai_config = self._check_openai_config(config) body = copy.deepcopy(payload) expected_schema = body.get("expected_schema") @@ -97,14 +97,14 @@ def _local_file_to_data_uri(path: str) -> str: return f"data:{mime_type};base64,{encoded}" - def estimate_request_bytes(self, request: Mapping[str, Any]) -> int: + def estimate_request_bytes(self, request: Dict[str, Any]) -> int: serialized = json.dumps(request, ensure_ascii=False, separators=(",", ":")) return len(serialized.encode("utf-8")) def submit_chunk( self, chunk_id: str, - requests: Sequence[Mapping[str, Any]], + requests: Sequence[Dict[str, Any]], config: BatchProviderConfig, ) -> Dict[str, Any]: openai_config = self._check_openai_config(config) @@ -191,7 +191,7 @@ def retrieve_results( def parse_submission_result( self, - raw_result: Mapping[str, Any], + raw_result: Dict[str, Any], ) -> BatchSubmissionResult: # Prefer attribute access for OpenAI SDK objects, fall back to mapping access. def _attr_or_get(obj: Any, attr: str, default: Any = None) -> Any: @@ -226,7 +226,7 @@ def _check_openai_config(config: BatchProviderConfig) -> OpenAIBatchConfig: raise TypeError("OpenAIBatchAdapter requires OpenAIBatchConfig") @staticmethod - def _extract_generated_text(row: Mapping[str, Any]) -> str: + def _extract_generated_text(row: Dict[str, Any]) -> str: # prefer chat `message.content`, then `choices[0].text`, # then `body.text`. Return empty string if none match. try: diff --git a/src/mmirage/core/process/processors/llm/llm_processor.py b/src/mmirage/core/process/processors/llm/llm_processor.py index 365c318..896538d 100644 --- a/src/mmirage/core/process/processors/llm/llm_processor.py +++ b/src/mmirage/core/process/processors/llm/llm_processor.py @@ -62,14 +62,15 @@ def __init__(self, engine_args: SGLangLLMConfig, **kwargs) -> None: """ super().__init__(engine_args, **kwargs) - batch_provider_cfg = getattr(engine_args, "batch_provider", None) + batch_provider_cfg = engine_args.batch_provider is_provider_batch_enabled = bool(batch_provider_cfg and batch_provider_cfg.enabled) # In provider-batch mode we only build payloads/metadata and should not # initialize GPU-backed SGLang runtime. - self.llm = None - self.tokenizer = None - if not is_provider_batch_enabled: + if is_provider_batch_enabled: + self.llm = None + self.tokenizer = None + else: self.llm = sgl.Engine(**asdict(engine_args.server_args)) self.tokenizer = AutoTokenizer.from_pretrained( engine_args.server_args.model_path, @@ -88,7 +89,7 @@ def __init__(self, engine_args: SGLangLLMConfig, **kwargs) -> None: self._setup_batch_runtime() def _setup_batch_runtime(self) -> None: - provider_cfg = getattr(self.config, "batch_provider", None) + provider_cfg = self.config.batch_provider if provider_cfg is None: return @@ -122,9 +123,8 @@ def _setup_batch_runtime(self) -> None: def _with_metadata_suffix(path: str, suffix: str, run_id: str) -> str: if not path: return "" - if path.endswith(".jsonl"): - return path[:-6] + f".{suffix}.{run_id}.jsonl" - return f"{path}.{suffix}.{run_id}.jsonl" + base_path = path.removesuffix(".jsonl") + return f"{base_path}.{suffix}.{run_id}.jsonl" @property def batch_mode_enabled(self) -> bool: @@ -385,7 +385,7 @@ def _batch_process_sample( payload=payload, config=self._batch_provider_config, ) - requests.append(dict(request)) + requests.append(request) source_indices.append(self._global_row_offset + global_i) self._text_orchestrator.add_requests( @@ -428,6 +428,7 @@ def _batch_process_sample( } if output_var.output_type == "JSON" and output_var.output_schema: payload["expected_schema"] = list(output_var.output_schema) + custom_id = self._next_custom_id(output_var.name, "multimodal") index_to_custom_id[global_i] = custom_id request = self._batch_adapter.build_request( diff --git a/tests/test_batch_orchestrator.py b/tests/test_batch_orchestrator.py index 0b16f9b..de80340 100644 --- a/tests/test_batch_orchestrator.py +++ b/tests/test_batch_orchestrator.py @@ -199,3 +199,42 @@ def apply_chat_template(self, *args, **kwargs): assert processor.batch_mode_enabled is False assert processor._batch_adapter is None assert processor._batch_provider_config is None + + +def test_llm_processor_uses_sync_runtime_when_batch_provider_omitted(monkeypatch): + class FakeEngine: + def __init__(self, **_kwargs): + return None + + def generate(self, **_kwargs): + raise AssertionError("Synchronous generation should not run in this test") + + def shutdown(self): + return None + + class FakeTokenizer: + def apply_chat_template(self, *args, **kwargs): + return "" + + monkeypatch.setattr( + "mmirage.core.process.processors.llm.llm_processor.sgl.Engine", + FakeEngine, + ) + monkeypatch.setattr( + "mmirage.core.process.processors.llm.llm_processor.AutoTokenizer.from_pretrained", + lambda *args, **kwargs: FakeTokenizer(), + ) + + config = SGLangLLMConfig( + type="llm", + server_args=SGLangServerArgs(model_path="dummy-model"), + ) + + processor_cls = ProcessorRegistry.get_processor("llm") + processor = processor_cls(config) + + assert processor.batch_mode_enabled is False + assert isinstance(processor.llm, FakeEngine) + assert isinstance(processor.tokenizer, FakeTokenizer) + assert processor._batch_adapter is None + assert processor._batch_provider_config is None From 4d4f8f172976ef85bd11b0b11c650f0cadea3696 Mon Sep 17 00:00:00 2001 From: legstar67 Date: Tue, 12 May 2026 12:22:08 +0200 Subject: [PATCH 44/45] working version of llm api provider feature --- src/mmirage/core/process/batch/collector.py | 17 ++- .../core/process/batch/openai_adapter.py | 30 ++++- tests/test_batch_collector.py | 34 +++++ tests/test_openai_batch_adapter.py | 118 ++++++++++++++++++ 4 files changed, 192 insertions(+), 7 deletions(-) diff --git a/src/mmirage/core/process/batch/collector.py b/src/mmirage/core/process/batch/collector.py index c8929bd..b57d8e7 100644 --- a/src/mmirage/core/process/batch/collector.py +++ b/src/mmirage/core/process/batch/collector.py @@ -129,6 +129,13 @@ def _build_output_payload(result_row: Mapping[str, Any], custom_id: str = "") -> question/answer JSON into a conversation format expected by downstream consumers. """ + error_message = str(result_row.get("error_message", "")).strip() + if error_message: + return { + "status": str(result_row.get("status", "error") or "error"), + "error_message": error_message, + } + raw_content = _extract_content_string(result_row) if not raw_content: return {"caption": ""} @@ -136,10 +143,12 @@ def _build_output_payload(result_row: Mapping[str, Any], custom_id: str = "") -> try: parsed = json.loads(raw_content) except json.JSONDecodeError: - logger.warning( - f"Failed to parse JSON for result row (custom_id={custom_id}). " - f"Treating as raw text. Content: {raw_content[:100]}" - ) + stripped_content = raw_content.lstrip() + if stripped_content.startswith(("{", "[")): + logger.warning( + f"Failed to parse JSON for result row (custom_id={custom_id}). " + f"Treating as raw text. Content: {raw_content[:100]}" + ) return {"caption": raw_content} if isinstance(parsed, dict) and ("question" in parsed or "answer" in parsed): diff --git a/src/mmirage/core/process/batch/openai_adapter.py b/src/mmirage/core/process/batch/openai_adapter.py index 8b28cb8..76e3c98 100644 --- a/src/mmirage/core/process/batch/openai_adapter.py +++ b/src/mmirage/core/process/batch/openai_adapter.py @@ -31,7 +31,7 @@ def build_request( ) -> Dict[str, Any]: openai_config = self._check_openai_config(config) body = copy.deepcopy(payload) - expected_schema = body.get("expected_schema") + expected_schema = body.pop("expected_schema", None) # expected_schema needs to be popped from body before submission, as it was in the normalized request but is not an OpenAI API parameter. if expected_schema is not None and ( not isinstance(expected_schema, list) or not all(isinstance(key, str) for key in expected_schema) @@ -165,14 +165,21 @@ def retrieve_results( retrieved = client.batches.retrieve(provider_batch_id) status = getattr(retrieved, "status", None) or "unknown" output_file_id = getattr(retrieved, "output_file_id", None) + error_file_id = getattr(retrieved, "error_file_id", None) - if status != "completed" or not output_file_id: + if status != "completed": raise ValueError( f"Batch '{provider_batch_id}' is not completed yet (status={status}). " "Please retry after the provider marks it completed and produces an output file." ) from None - content_response = client.files.content(output_file_id) + content_file_id = output_file_id or error_file_id + if not content_file_id: + raise ValueError( + f"Batch '{provider_batch_id}' completed, but neither output_file_id nor error_file_id was returned." + ) from None + + content_response = client.files.content(content_file_id) jsonl_text = self._extract_content_text(content_response) rows: List[Dict[str, Any]] = [] @@ -181,6 +188,10 @@ def retrieve_results( if not raw: continue row = dict(json.loads(raw)) + error_message = self._extract_error_message(row) + if error_message: + row.setdefault("status", "error") + row["error_message"] = error_message if "generated_text" not in row: generated_text = self._extract_generated_text(row) if generated_text: @@ -252,6 +263,19 @@ def _extract_generated_text(row: Dict[str, Any]) -> str: return "" + @staticmethod + def _extract_error_message(row: Dict[str, Any]) -> str: + try: + error = row["response"]["body"]["error"] + if isinstance(error, dict): + message = error.get("message") + if isinstance(message, str): + return message + except (KeyError, TypeError): + pass + + return "" + @staticmethod def _create_client(config: OpenAIBatchConfig) -> OpenAI: api_key = (config.credentials.get("api_key", "").strip() or os.environ.get("OPENAI_API_KEY", "").strip() ) diff --git a/tests/test_batch_collector.py b/tests/test_batch_collector.py index f8e16cd..fab738b 100644 --- a/tests/test_batch_collector.py +++ b/tests/test_batch_collector.py @@ -670,3 +670,37 @@ def test_build_output_payload_logs_malformed_json(caplog): assert "Failed to parse JSON for result row" in caplog.text assert "custom_id=row_123" in caplog.text assert "Treating as raw text" in caplog.text + + +def test_build_output_payload_keeps_plain_text_silent(caplog): + from mmirage.core.process.batch.collector import _build_output_payload + + result_row = { + "custom_id": "caption:multimodal:1", + "generated_text": "The image features a solid orange background with the text \"A Cat\" displayed in a bold font.", + } + + with caplog.at_level("WARNING"): + output = _build_output_payload(result_row, custom_id="caption:multimodal:1") + + assert output == { + "caption": "The image features a solid orange background with the text \"A Cat\" displayed in a bold font." + } + assert "Failed to parse JSON for result row" not in caplog.text + + +def test_build_output_payload_preserves_provider_error_status(): + from mmirage.core.process.batch.collector import _build_output_payload + + result_row = { + "custom_id": "formatted_answer:text:50", + "status": "error", + "error_message": "Unrecognized request argument supplied: expected_schema", + } + + output = _build_output_payload(result_row, custom_id="formatted_answer:text:50") + + assert output == { + "status": "error", + "error_message": "Unrecognized request argument supplied: expected_schema", + } diff --git a/tests/test_openai_batch_adapter.py b/tests/test_openai_batch_adapter.py index a03cc2c..6b228d1 100644 --- a/tests/test_openai_batch_adapter.py +++ b/tests/test_openai_batch_adapter.py @@ -44,6 +44,8 @@ def test_openai_build_request_injects_structured_output_format(): request = adapter.build_request(custom_id="row-002", payload=payload, config=config) + assert "expected_schema" not in request["body"] + assert request["body"]["response_format"] == { "type": "json_schema", "json_schema": { @@ -398,6 +400,68 @@ def __init__(self, **kwargs): assert rows[0]["generated_text"] == '{"question":"Q","answer":"A"}' +def test_openai_retrieve_results_normalizes_error_rows(monkeypatch): + from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter + + class FakeBatches: + def retrieve(self, provider_batch_id): + class _RetrieveResp: + id = provider_batch_id + status = "completed" + output_file_id = "file_output_error" + + return _RetrieveResp() + + class FakeFiles: + def content(self, output_file_id): + class _ContentResp: + text = ( + '{"id":"batch_req_1","custom_id":"formatted_answer:text:50",' + '"response":{"status_code":400,"request_id":"req_1",' + '"body":{"error":{"message":"Unrecognized request argument supplied: expected_schema",' + '"type":"invalid_request_error","param":null,"code":null}}},"error":null}\n' + ) + + return _ContentResp() + + class FakeClient: + def __init__(self, **kwargs): + self.batches = FakeBatches() + self.files = FakeFiles() + + monkeypatch.setattr( + "mmirage.core.process.batch.openai_adapter.OpenAI", + FakeClient, + ) + + config = OpenAIBatchConfig(credentials={"api_key": "test-key"}) + adapter = OpenAIBatchAdapter() + + rows = adapter.retrieve_results(provider_batch_id="batch_error", config=config) + + assert rows == [ + { + "id": "batch_req_1", + "custom_id": "formatted_answer:text:50", + "response": { + "status_code": 400, + "request_id": "req_1", + "body": { + "error": { + "message": "Unrecognized request argument supplied: expected_schema", + "type": "invalid_request_error", + "param": None, + "code": None, + } + }, + }, + "error": None, + "status": "error", + "error_message": "Unrecognized request argument supplied: expected_schema", + } + ] + + def test_openai_retrieve_results_raises_if_batch_not_completed(monkeypatch): from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter @@ -430,3 +494,57 @@ def content(self, output_file_id): with pytest.raises(ValueError, match="not completed"): adapter.retrieve_results(provider_batch_id="batch_abc", config=config) + + +def test_openai_retrieve_results_uses_error_file_when_output_missing(monkeypatch): + from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter + + class FakeBatches: + def retrieve(self, provider_batch_id): + class _RetrieveResp: + id = provider_batch_id + status = "completed" + output_file_id = None + error_file_id = "file_error_1" + + return _RetrieveResp() + + captured = {} + + class FakeClient: + def __init__(self, **kwargs): + self.batches = FakeBatches() + + class _Files: + def content(self, output_file_id): + captured["output_file_id"] = output_file_id + + class _ContentResp: + text = ( + '{"custom_id":"c1","response":{"body":{"error":{' + '"message":"Unrecognized request argument supplied: expected_schema"}}}}\n' + ) + + return _ContentResp() + + self.files = _Files() + + monkeypatch.setattr( + "mmirage.core.process.batch.openai_adapter.OpenAI", + FakeClient, + ) + + config = OpenAIBatchConfig(credentials={"api_key": "test-key"}) + adapter = OpenAIBatchAdapter() + + rows = adapter.retrieve_results(provider_batch_id="batch_abc", config=config) + + assert captured["output_file_id"] == "file_error_1" + assert rows == [ + { + "custom_id": "c1", + "response": {"body": {"error": {"message": "Unrecognized request argument supplied: expected_schema"}}}, + "status": "error", + "error_message": "Unrecognized request argument supplied: expected_schema", + } + ] From 645110ad9f4492ef4f25e1dac26d8d0e7039cae3 Mon Sep 17 00:00:00 2001 From: legstar67 Date: Tue, 12 May 2026 13:36:24 +0200 Subject: [PATCH 45/45] small correction --- src/mmirage/core/process/processors/llm/llm_processor.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/mmirage/core/process/processors/llm/llm_processor.py b/src/mmirage/core/process/processors/llm/llm_processor.py index 77e7fc2..8580673 100644 --- a/src/mmirage/core/process/processors/llm/llm_processor.py +++ b/src/mmirage/core/process/processors/llm/llm_processor.py @@ -65,6 +65,7 @@ def __init__(self, engine_args: SGLangLLMConfig, **kwargs) -> None: batch_provider_cfg = engine_args.batch_provider is_provider_batch_enabled = bool(batch_provider_cfg and batch_provider_cfg.enabled) + self._model_load_seconds: float = 0.0 # In provider-batch mode we only build payloads/metadata and should not # initialize GPU-backed SGLang runtime. @@ -77,7 +78,7 @@ def __init__(self, engine_args: SGLangLLMConfig, **kwargs) -> None: server_kwargs.update(extra) _load_start = time.monotonic() self.llm = sgl.Engine(**server_kwargs) - self._model_load_seconds: float = time.monotonic() - _load_start + self._model_load_seconds = time.monotonic() - _load_start self.tokenizer = AutoTokenizer.from_pretrained( engine_args.server_args.model_path, trust_remote_code=getattr(engine_args.server_args, "trust_remote_code", False), @@ -93,7 +94,7 @@ def __init__(self, engine_args: SGLangLLMConfig, **kwargs) -> None: self._batch_request_counter = 0 self._global_row_offset = 0 self._setup_batch_runtime() - + # Cumulative token counts across all generate() calls in this processor's lifetime. self._total_input_tokens: int = 0 self._total_output_tokens: int = 0