diff --git a/configs/config_mock_openai_batch.yaml b/configs/config_mock_openai_batch.yaml new file mode 100644 index 0000000..80f5bd8 --- /dev/null +++ b/configs/config_mock_openai_batch.yaml @@ -0,0 +1,57 @@ +processors: + - type: llm + server_args: + model_path: Qwen/Qwen3-4B-Instruct-2507 + tp_size: 1 + disable_custom_all_reduce: true + default_sampling_params: + temperature: 0.1 + top_p: 0.9 + max_new_tokens: 1024 + custom_params: + chat_template_kwargs: + enable_thinking: false + batch_provider: + enabled: true + provider: openai + model: gpt-4o-mini + max_chunk_bytes: 52428800 + metadata_output_path: tests/output/batch_metadata.jsonl + credentials: + api_key: + +loading_params: + datasets: + - path: tests/mock_data/data.jsonl + type: JSONL + output_dir: tests/output/data_openai_batch + + num_shards: 1 + shard_id: 0 + batch_size: 64 + +processing_params: + inputs: + - name: text + key: text + + outputs: + - name: formatted_answer + type: llm + output_type: JSON + output_schema: + - question + - answer + prompt: | + Generate one question and its corresponding answer using the following text: + ``` + {{ text }} + ``` + + remove_columns: true + output_schema: + conversations: + - role: "user" + content: "{{ formatted_answer.question }}" + - role: "assistant" + content: "{{ formatted_answer.answer }}" \ No newline at end of file diff --git a/configs/config_mock_openai_batch_vision.yaml b/configs/config_mock_openai_batch_vision.yaml new file mode 100644 index 0000000..49f35f1 --- /dev/null +++ b/configs/config_mock_openai_batch_vision.yaml @@ -0,0 +1,47 @@ +processors: + - type: llm + server_args: + model_path: Qwen/Qwen3-VL-8B-Instruct + tp_size: 1 + trust_remote_code: true + chat_template: qwen2-vl + default_sampling_params: + temperature: 0.1 + top_p: 0.9 + max_new_tokens: 512 + batch_provider: + enabled: true + provider: openai + model: gpt-4o-mini + metadata_output_path: tests/output/batch_metadata.jsonl + credentials: + api_key: "" + +loading_params: + 
@dataclass
class BatchRetryPolicy:
    """Retry behavior used by provider-neutral batch submission orchestration.

    Attributes:
        max_attempts: Maximum number of submission attempts for retryable errors.
        initial_backoff_seconds: Delay before the first retry attempt.
        backoff_multiplier: Multiplicative factor for subsequent retry delays.

    Raises:
        ValueError: If a field is below its minimum, or is a non-finite float
            (NaN or infinity).
    """

    max_attempts: int = 3
    initial_backoff_seconds: float = 2.0
    backoff_multiplier: float = 2.0

    def __post_init__(self) -> None:
        # Function-scope import so the module's import block stays untouched.
        import math

        # NaN compares False against every bound and +inf satisfies ">=" checks,
        # so the range tests below would silently accept both. Reject them first.
        for name in ("max_attempts", "initial_backoff_seconds", "backoff_multiplier"):
            value = getattr(self, name)
            if isinstance(value, float) and not math.isfinite(value):
                raise ValueError(f"{name} must be a finite number")

        if self.max_attempts < 1:
            raise ValueError("max_attempts must be >= 1")
        if self.initial_backoff_seconds < 0:
            raise ValueError("initial_backoff_seconds must be >= 0")
        if self.backoff_multiplier < 1:
            raise ValueError("backoff_multiplier must be >= 1")
@dataclass
class OpenAIBatchConfig(BatchProviderConfig):
    """Batch configuration specialised for the OpenAI Batch API.

    Attributes:
        provider: Provider identifier; defaults to ``"openai"``.
        model: Model name placed in each chat completion request body.
        batch_endpoint: Endpoint targeted by the batch job; must start with ``/``.
        completion_window: Completion window sent to OpenAI (only ``"24h"``).
        base_url: Optional base URL, e.g. for an API-compatible gateway.
        metadata: Key/value metadata attached when the batch is created.

    Raises:
        ValueError: If the shared base validation fails, the completion window
            is unsupported, the model is blank, or the endpoint lacks a
            leading slash.
    """

    provider: str = "openai"
    model: str = "gpt-4.1-mini"
    batch_endpoint: str = "/v1/chat/completions"
    completion_window: str = "24h"
    base_url: Optional[str] = None
    metadata: Dict[str, str] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Shared, provider-neutral validation runs first.
        super().__post_init__()

        supported_windows = {"24h"}
        if self.completion_window not in supported_windows:
            raise ValueError(f"completion_window must be one of {supported_windows}")

        if not self.model.strip():
            raise ValueError("model must be a non-empty string")
        if not self.batch_endpoint.startswith("/"):
            raise ValueError("batch_endpoint must start with '/'")

        # Expose the OpenAI-specific fields through the generic ``extras`` mapping
        # for provider-neutral consumers. ``setdefault`` keeps any value the
        # caller already placed there.
        mirrored = [
            ("model", self.model),
            ("batch_endpoint", self.batch_endpoint),
            ("completion_window", self.completion_window),
        ]
        if self.base_url:
            mirrored.append(("base_url", self.base_url))
        if self.metadata:
            # Copy so later edits to ``extras`` do not alias ``self.metadata``.
            mirrored.append(("metadata", dict(self.metadata)))

        for key, value in mirrored:
            self.extras.setdefault(key, value)
@@ -111,12 +113,16 @@ def output_var_hook(data: Dict[str, Any]) -> OutputVar: clz = ProcessorRegistry.get_output_var_cls(data["type"]) return from_dict(clz, data, config=config) + def batch_provider_hook(data: Dict[str, Any]) -> BatchProviderConfig: + return resolve_single_provider_config(data) + cfg = expand_env_vars(cfg) config = Config( type_hooks={ BaseProcessorConfig: processor_config_hook, BaseDataLoaderConfig: loader_config_hook, OutputVar: output_var_hook, + BatchProviderConfig: batch_provider_hook, } ) cfg_obj = from_dict(MMirageConfig, cast(dict, cfg), config=config) diff --git a/src/mmirage/core/process/base.py b/src/mmirage/core/process/base.py index 6e8a283..817332e 100644 --- a/src/mmirage/core/process/base.py +++ b/src/mmirage/core/process/base.py @@ -71,6 +71,10 @@ def batch_process_sample( NotImplementedError: If not implemented by subclass. """ raise NotImplementedError() + + def finalize(self) -> None: + """Optional lifecycle hook; override when a processor buffers state.""" + pass @abc.abstractmethod def get_token_counts(self) -> TokenCounts: diff --git a/src/mmirage/core/process/batch/__init__.py b/src/mmirage/core/process/batch/__init__.py new file mode 100644 index 0000000..31334ba --- /dev/null +++ b/src/mmirage/core/process/batch/__init__.py @@ -0,0 +1,28 @@ +"""Provider-agnostic batch processing contracts and registry.""" + +from mmirage.core.process.batch.adapter import BatchSubmissionAdapter, BatchSubmissionResult +from mmirage.core.process.batch.collector import collect_and_merge +from mmirage.core.process.batch.chunking import BatchRequestChunker, RequestChunk +from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter +from mmirage.core.process.batch.orchestrator import BatchSubmissionOrchestrator +from mmirage.core.process.batch.registry import BatchAdapterFactory, BatchAdapterRegistry +from mmirage.core.process.batch.status_checker import ( + extract_unique_provider_batches, + run_status_checker, +) +from 
@dataclass
class BatchSubmissionResult:
    """Provider-neutral outcome of submitting one chunk to a batch provider.

    Attributes:
        provider_batch_id: Identifier the provider assigned to the job/batch.
        status: Submission status, normalized to a short string.
        raw_response: Untouched provider payload, kept for traceability.
            Defaults to a fresh empty dict per instance.
    """

    provider_batch_id: str
    status: str
    raw_response: Dict[str, Any] = field(default_factory=dict)
+ + Args: + custom_id: Stable request identifier used to map async results back + to source rows. + payload: Provider-neutral request payload assembled by the core + processing layer. + config: Provider configuration contract that may influence request + shaping. + + Returns: + A provider-specific request object represented as a mapping. + """ + raise NotImplementedError() + + @abc.abstractmethod + def estimate_request_bytes(self, request: Dict[str, Any]) -> int: + """Estimate serialized UTF-8 bytes for a request payload. + + The estimate must match or safely upper-bound the size produced by the + serializer used for submission so chunk boundaries are enforced + correctly. + + Args: + request: Provider request object returned by ``build_request``. + + Returns: + Estimated byte size for the serialized request. + """ + raise NotImplementedError() + + @abc.abstractmethod + def submit_chunk( + self, + chunk_id: str, + requests: Sequence[Dict[str, Any]], + config: BatchProviderConfig, + ) -> Dict[str, Any]: + """Submit one pre-chunked request group to the provider. + + Args: + chunk_id: Internal chunk identifier generated by orchestration. + requests: Provider-ready request objects belonging to this chunk. + config: Provider config containing credentials and submission knobs. + + Returns: + Raw provider response payload as a mapping. + """ + raise NotImplementedError() + + @abc.abstractmethod + def parse_submission_result( + self, + raw_result: Dict[str, Any], + ) -> BatchSubmissionResult: + """Normalize provider submission output into a shared result model. + + Args: + raw_result: Raw payload returned by ``submit_chunk``. + + Returns: + A normalized ``BatchSubmissionResult`` for provider-neutral + orchestration and metadata persistence. 
+ """ + raise NotImplementedError() + + @abc.abstractmethod + def check_batch_status( + self, + provider_batch_id: str, + config: BatchProviderConfig, + ) -> BatchSubmissionResult: + """Retrieve and normalize the latest status for a provider batch job. + + Args: + provider_batch_id: Provider-side batch/job identifier to query. + config: Provider configuration containing credentials and endpoint + overrides. + + Returns: + A normalized ``BatchSubmissionResult`` where ``status`` reflects the + latest provider-reported lifecycle state for the batch. + """ + raise NotImplementedError() + + @abc.abstractmethod + def retrieve_results( + self, + provider_batch_id: str, + config: BatchProviderConfig, + ) -> Sequence[Dict[str, Any]]: + """Download and parse completed batch results from the provider. + + Implementations should normalize each returned row into a plain mapping + and, when a text payload is available, expose it as ``generated_text`` + so downstream collectors can consume a provider-agnostic result shape. + + Args: + provider_batch_id: Provider-side batch/job identifier. + config: Provider configuration containing credentials and endpoint + overrides. + + Returns: + Sequence of parsed result rows (provider JSONL records normalized to + dictionaries) preserving provider output order. 
class BatchRequestChunker:
    """Group provider-ready requests into chunks bounded by serialized size."""

    def __init__(self, adapter: BatchSubmissionAdapter, config: BatchProviderConfig) -> None:
        self.adapter = adapter
        self.config = config

    def chunk_requests(self, requests: Sequence[Dict[str, Any]]) -> List[RequestChunk]:
        """Partition ``requests`` into ordered chunks.

        A chunk is closed when adding the next request would push it past
        ``config.max_chunk_bytes`` or when it already holds
        ``config.max_requests_per_chunk`` requests. A single request larger
        than the byte limit is either rejected or placed alone in a chunk
        flagged ``has_oversized_request``, depending on
        ``config.oversized_request_policy``.

        Args:
            requests: Provider-ready request objects, in submission order.

        Returns:
            Chunks in input order; an empty list for empty input.

        Raises:
            ValueError: If a request exceeds the byte limit and the policy is
                ``REJECT``.
        """
        byte_limit = self.config.max_chunk_bytes
        finished: List[RequestChunk] = []
        pending: List[Dict[str, Any]] = []
        pending_bytes = 0

        for request in requests:
            size = self.adapter.estimate_request_bytes(request)

            if size > byte_limit:
                if self.config.oversized_request_policy is OversizedRequestPolicy.REJECT:
                    raise ValueError(
                        "Encountered oversized request: "
                        f"{size} bytes exceeds max_chunk_bytes={byte_limit}"
                    )

                logger.warning(
                    "Encountered oversized request (%s bytes > %s); isolating into its own chunk.",
                    size,
                    byte_limit,
                )

                # Close whatever was accumulating so input order is preserved,
                # then emit the oversized request as a dedicated chunk.
                if pending:
                    finished.append(RequestChunk(requests=pending, total_bytes=pending_bytes))
                pending, pending_bytes = [], 0
                finished.append(
                    RequestChunk(
                        requests=[request],
                        total_bytes=size,
                        has_oversized_request=True,
                    )
                )
                continue

            over_bytes = pending_bytes + size > byte_limit
            over_count = self._would_exceed_count_limit(pending)

            # Never flush an empty chunk: the request then simply starts a new one.
            if pending and (over_bytes or over_count):
                finished.append(RequestChunk(requests=pending, total_bytes=pending_bytes))
                pending, pending_bytes = [], 0

            pending.append(request)
            pending_bytes += size

        if pending:
            finished.append(RequestChunk(requests=pending, total_bytes=pending_bytes))

        return finished

    def _would_exceed_count_limit(self, current_requests: Sequence[Dict[str, Any]]) -> bool:
        """Return True when the chunk already holds the maximum request count."""
        limit = self.config.max_requests_per_chunk
        if limit is None:
            return False
        return len(current_requests) >= limit
def _aggregate_batch_mappings(
    records: Sequence[BatchMetadataRecord],
) -> Dict[Tuple[str, str], Dict[str, int]]:
    """Merge custom-ID mappings per ``(provider, provider_batch_id)`` pair.

    Records lacking a provider, a provider batch ID, or any mapping entries
    are ignored. When several records refer to the same pair, entries from
    later records replace earlier ones that share a custom ID.

    Args:
        records: Parsed receipt records.

    Returns:
        Mapping of ``(provider, provider_batch_id)`` to
        ``{custom_id (as str): source_index}``.
    """
    merged: Dict[Tuple[str, str], Dict[str, int]] = {}

    for record in records:
        if not (record.provider and record.provider_batch_id and record.custom_id_to_source_index):
            continue

        bucket = merged.setdefault((record.provider, record.provider_batch_id), {})
        bucket.update(
            (str(custom_id), source_index)
            for custom_id, source_index in record.custom_id_to_source_index.items()
        )

    return merged
+ output_path: Destination JSONL path for the merged output. + + Returns: + The ordered rows that were written to disk. + + Raises: + ValueError: If a receipt references a provider that cannot be resolved. + """ + pair_to_mapping = _aggregate_batch_mappings(records) + + adapters: Dict[str, Any] = {} + pair_to_results: Dict[Tuple[str, str], Sequence[Dict[str, Any]]] = {} + + for provider, provider_batch_id in pair_to_mapping.keys(): + if provider not in provider_configs: + raise ValueError(f"No provider config found for '{provider}'.") + + if provider not in adapters: + adapters[provider] = BatchAdapterFactory.from_config(provider_configs[provider]) + + pair = (provider, provider_batch_id) + pair_to_results[pair] = adapters[provider].retrieve_results( + provider_batch_id=provider_batch_id, + config=provider_configs[provider], + ) + + indexed_rows: MutableMapping[Tuple[str, str, str], Dict[str, Any]] = {} + for pair, mapping in pair_to_mapping.items(): + results = pair_to_results.get(pair, []) + for result_row in results: + custom_id = str(result_row.get("custom_id", "")).strip() + if not custom_id or custom_id not in mapping: + continue + row_payload = _build_output_payload(result_row, custom_id=custom_id) + indexed_rows[(pair[0], pair[1], custom_id)] = { + "source_index": int(mapping[custom_id]), + "custom_id": custom_id, + **row_payload, + } + + # Sort primarily by source_index and secondarily by custom_id to ensure + # deterministic ordering when multiple rows share the same source_index. 
def _build_output_payload(result_row: Mapping[str, Any], custom_id: str = "") -> Dict[str, Any]:
    """Convert one normalized provider result row into the collector's output schema.

    Resolution order:
      1. A non-empty ``error_message`` yields ``{"status", "error_message"}``.
      2. Empty or missing generated text yields ``{"caption": ""}``.
      3. Text that parses as a JSON object containing ``question`` or
         ``answer`` becomes a two-turn ``conversations`` list.
      4. Anything else is passed through as ``{"caption": <raw text>}``.

    Args:
        result_row: Result row carrying ``generated_text`` and, optionally,
            ``error_message`` / ``status``.
        custom_id: Request identifier, used only in the warning log.

    Returns:
        The payload fields to merge into the output row.
    """
    # An explicit null means "no error". str(None) would produce the truthy
    # text "None" and wrongly flag a successful row as failed.
    raw_error = result_row.get("error_message")
    error_message = "" if raw_error is None else str(raw_error).strip()
    if error_message:
        return {
            "status": str(result_row.get("status", "error") or "error"),
            "error_message": error_message,
        }

    raw_content = _extract_content_string(result_row)
    if not raw_content:
        return {"caption": ""}

    try:
        parsed = json.loads(raw_content)
    except json.JSONDecodeError:
        # Only warn when the text looked like JSON; plain prose is expected here.
        if raw_content.lstrip().startswith(("{", "[")):
            logger.warning(
                "Failed to parse JSON for result row (custom_id=%s). "
                "Treating as raw text. Content: %s",
                custom_id,
                raw_content[:100],
            )
        return {"caption": raw_content}

    if isinstance(parsed, dict) and ("question" in parsed or "answer" in parsed):
        return {
            "conversations": [
                {
                    "role": "user",
                    "content": str(parsed.get("question", "")),
                },
                {
                    "role": "assistant",
                    "content": str(parsed.get("answer", "")),
                },
            ]
        }

    return {"caption": raw_content}


def _extract_content_string(result_row: Mapping[str, Any]) -> str:
    """Return the row's ``generated_text`` as a string.

    Both a missing key and an explicit ``None`` map to ``""``, so an incomplete
    provider response is treated as empty output rather than the literal text
    ``"None"``. Other non-string values are converted with ``str``.
    """
    content = result_row.get("generated_text")
    return "" if content is None else str(content)
# Receipt kinds appended to the base path by the patterns built below.
_METADATA_SUFFIXES = ("text", "multimodal")


def _base_path_to_patterns(base_path: str) -> List[str]:
    """Build the glob patterns that match suffixed receipts for ``base_path``.

    A trailing ``.jsonl`` is dropped from the base path, then one pattern of
    the form ``<base>.<suffix>.*.jsonl`` is produced per entry in
    ``_METADATA_SUFFIXES``.

    The base part is passed through ``glob.escape`` so that ``[``, ``?`` or
    ``*`` in a directory or file name is matched literally instead of being
    read as a wildcard.

    Args:
        base_path: Configured ``metadata_output_path``.

    Returns:
        One glob pattern per metadata suffix, in suffix order.
    """
    trimmed = base_path.removesuffix(".jsonl")
    literal_base = glob.escape(trimmed)
    return [f"{literal_base}.{suffix}.*.jsonl" for suffix in _METADATA_SUFFIXES]
+ """ + patterns: List[str] = [] + resolved: List[str] = [] + + for base_path in metadata_output_paths: + for pattern in _base_path_to_patterns(base_path): + patterns.append(pattern) + matches = sorted(glob.glob(pattern)) + if matches: + resolved.extend(matches) + + resolved = list(dict.fromkeys(resolved)) + if not resolved: + pattern_list = ", ".join(patterns) if patterns else "" + raise ValueError( + "No metadata receipts matched config metadata_output_path patterns. " + f"Tried: {pattern_list}. Expected suffixed files like " + "'.text..jsonl' or '.multimodal..jsonl'." + ) + + return resolved diff --git a/src/mmirage/core/process/batch/metadata_utils.py b/src/mmirage/core/process/batch/metadata_utils.py new file mode 100644 index 0000000..a582bdc --- /dev/null +++ b/src/mmirage/core/process/batch/metadata_utils.py @@ -0,0 +1,80 @@ +"""Shared helpers for batch metadata receipt files.""" + +from __future__ import annotations + +import json +import logging +from dataclasses import dataclass, field +from typing import Any, Dict, List, Mapping, Sequence + +logger = logging.getLogger(__name__) + + +@dataclass(slots=True) +class BatchMetadataRecord: + """Typed batch receipt row shared by collector and status checker.""" + + provider: str + provider_batch_id: str + custom_id_to_source_index: Dict[str, int] = field(default_factory=dict) + + @classmethod + def from_mapping(cls, payload: Mapping[str, Any]) -> "BatchMetadataRecord": + provider = str(payload.get("provider", "")).strip().lower() + provider_batch_id = str(payload.get("provider_batch_id", "")).strip() + + raw_mapping = payload.get("custom_id_to_source_index", {}) + custom_id_to_source_index: Dict[str, int] = {} + if isinstance(raw_mapping, dict): + for custom_id, source_index in raw_mapping.items(): + try: + custom_id_to_source_index[str(custom_id)] = int(source_index) + except (TypeError, ValueError): + continue + + return cls( + provider=provider, + provider_batch_id=provider_batch_id, + 
custom_id_to_source_index=custom_id_to_source_index, + ) + + +def _normalize_metadata_paths(metadata_paths: str | Sequence[str]) -> List[str]: + """Return metadata paths as a concrete list.""" + if isinstance(metadata_paths, str): + return [metadata_paths] + return list(metadata_paths) + + +def _read_metadata_records( + metadata_output_paths: str | Sequence[str], +) -> List[BatchMetadataRecord]: + """Load valid JSON objects from one or more receipt files. + + Malformed lines are skipped with a warning so partially written or noisy + receipt files do not stop collection. Only JSON objects are retained and + converted into typed records with required provider identifiers. + """ + records: List[BatchMetadataRecord] = [] + for metadata_output_path in _normalize_metadata_paths(metadata_output_paths): + with open(metadata_output_path, "r", encoding="utf-8") as f: + for line in f: + raw = line.strip() + if not raw: + continue + try: + parsed = json.loads(raw) + except json.JSONDecodeError as exc: + logger.warning( + "Skipping malformed metadata JSON line in %s: %s", + metadata_output_path, + exc, + ) + continue + # defensive check to ensure only dicts are included (useful against partial/corrupt metadata) + if isinstance(parsed, dict): + record = BatchMetadataRecord.from_mapping(parsed) + if not record.provider or not record.provider_batch_id: + continue + records.append(record) + return records \ No newline at end of file diff --git a/src/mmirage/core/process/batch/openai_adapter.py b/src/mmirage/core/process/batch/openai_adapter.py new file mode 100644 index 0000000..76e3c98 --- /dev/null +++ b/src/mmirage/core/process/batch/openai_adapter.py @@ -0,0 +1,317 @@ +"""Concrete OpenAI implementation of batch submission contracts.""" + +import base64 +import copy +import io +import json +import mimetypes +import os +import logging +from typing import Any, Dict, List, Mapping, Sequence + +from openai import AuthenticationError, OpenAI + +from mmirage.config.batch_provider 
import BatchProviderConfig +from mmirage.config.openai_batch import OpenAIBatchConfig +from mmirage.core.process.batch.adapter import BatchSubmissionAdapter, BatchSubmissionResult + +logger = logging.getLogger(__name__) + + +class OpenAIBatchAdapter(BatchSubmissionAdapter): + """Provider adapter for OpenAI Batch API.""" + + required_credentials = ("api_key",) + + def build_request( + self, + custom_id: str, + payload: Dict[str, Any], + config: BatchProviderConfig, + ) -> Dict[str, Any]: + openai_config = self._check_openai_config(config) + body = copy.deepcopy(payload) + expected_schema = body.pop("expected_schema", None) # expected_schema needs to be popped from body before submission, as it was in the normalized request but is not an OpenAI API parameter. + if expected_schema is not None and ( + not isinstance(expected_schema, list) + or not all(isinstance(key, str) for key in expected_schema) + ): + raise ValueError( + "expected_schema must be a list of strings, " + f"got {type(expected_schema).__name__}" + ) + body.setdefault("model", openai_config.model) + self._convert_local_images_to_data_uris(body) + + if isinstance(expected_schema, list) and all(isinstance(k, str) for k in expected_schema): + properties = {key: {"type": "string"} for key in expected_schema} + body["response_format"] = { + "type": "json_schema", + "json_schema": { + "name": "structured_output", + "strict": True, + "schema": { + "type": "object", + "properties": properties, + "required": expected_schema, + "additionalProperties": False, + }, + }, + } + + return { + "custom_id": custom_id, + "method": "POST", + "url": openai_config.batch_endpoint, + "body": body, + } + + @staticmethod + def _convert_local_images_to_data_uris(body: Dict[str, Any]) -> None: + # If the payload shape is different, swallow the exception and leave the body untouched. 
+ try: + for message in body["messages"]: + for part in message["content"]: + if part.get("type") != "image_url": + continue + url = part["image_url"]["url"] + if not isinstance(url, str): + continue + # Keep remote/data URLs untouched. + if url.startswith("http://") or url.startswith("https://") or url.startswith("data:"): + continue + if os.path.exists(url): + part["image_url"]["url"] = OpenAIBatchAdapter._local_file_to_data_uri(url) + except (KeyError, IndexError, TypeError, AttributeError): + # Ignore malformed shapes. + pass + + @staticmethod + def _local_file_to_data_uri(path: str) -> str: + mime_type, _ = mimetypes.guess_type(path) + if not mime_type: + mime_type = "image/jpeg" + + with open(path, "rb") as f: + encoded = base64.b64encode(f.read()).decode("utf-8") + + return f"data:{mime_type};base64,{encoded}" + + def estimate_request_bytes(self, request: Dict[str, Any]) -> int: + serialized = json.dumps(request, ensure_ascii=False, separators=(",", ":")) + return len(serialized.encode("utf-8")) + + def submit_chunk( + self, + chunk_id: str, + requests: Sequence[Dict[str, Any]], + config: BatchProviderConfig, + ) -> Dict[str, Any]: + openai_config = self._check_openai_config(config) + client = self._create_client(openai_config) + + jsonl_lines = [ + json.dumps(req, ensure_ascii=False, separators=(",", ":")) for req in requests + ] + jsonl_payload = "\n".join(jsonl_lines).encode("utf-8") + + file_response = client.files.create( + file=(f"batch_chunk-{chunk_id}.jsonl", io.BytesIO(jsonl_payload)), + purpose="batch", + ) + + metadata = dict(openai_config.metadata) + metadata["chunk_id"] = chunk_id + + batch_response = client.batches.create( + input_file_id=file_response.id, + endpoint=openai_config.batch_endpoint, + completion_window=openai_config.completion_window, + metadata=metadata, + ) + + return { + "id": batch_response.id, + "status": getattr(batch_response, "status", None), + "endpoint": getattr(batch_response, "endpoint", None), + "input_file_id": 
def retrieve_results(
    self,
    provider_batch_id: str,
    config: BatchProviderConfig,
) -> Sequence[Dict[str, Any]]:
    """Download completed OpenAI batch rows and normalize text into ``generated_text``.

    Rows are read from the batch's output file and, when present, from its
    error file as well, so failed requests of a partially failed batch are
    returned alongside the successful ones (output rows first).

    Each returned row is the parsed JSON object with two optional additions:
    ``error_message`` (plus a default ``status`` of ``"error"``) when an error
    message is found, and ``generated_text`` when text can be extracted from
    the nested response body.

    Args:
        provider_batch_id: Identifier of the batch to download.
        config: Provider config; must be an ``OpenAIBatchConfig``.

    Returns:
        The list of normalized rows.

    Raises:
        ValueError: If the batch status is not ``"completed"`` or the batch
            exposes neither an output nor an error file.
    """
    openai_config = self._check_openai_config(config)
    client = self._create_client(openai_config)

    retrieved = client.batches.retrieve(provider_batch_id)
    status = getattr(retrieved, "status", None) or "unknown"
    if status != "completed":
        raise ValueError(
            f"Batch '{provider_batch_id}' is not completed yet (status={status}). "
            "Please retry after the provider marks it completed and produces an output file."
        )

    candidate_ids = (
        getattr(retrieved, "output_file_id", None),
        getattr(retrieved, "error_file_id", None),
    )
    content_file_ids = [file_id for file_id in candidate_ids if file_id]
    if not content_file_ids:
        raise ValueError(
            f"Batch '{provider_batch_id}' completed, but neither output_file_id nor error_file_id was returned."
        )

    rows: List[Dict[str, Any]] = []
    for content_file_id in content_file_ids:
        content_response = client.files.content(content_file_id)
        jsonl_text = self._extract_content_text(content_response)

        for line in jsonl_text.splitlines():
            raw = line.strip()
            if not raw:
                continue
            try:
                parsed = json.loads(raw)
            except json.JSONDecodeError as exc:
                # One bad line should not discard the rest of the download.
                logger.warning(
                    "Skipping malformed result line in file %s of batch %s: %s",
                    content_file_id,
                    provider_batch_id,
                    exc,
                )
                continue
            if not isinstance(parsed, dict):
                logger.warning(
                    "Skipping non-object result line in file %s of batch %s",
                    content_file_id,
                    provider_batch_id,
                )
                continue

            row = dict(parsed)
            error_message = self._extract_error_message(row)
            if not error_message:
                # Fall back to a top-level ``error`` object on the row itself.
                top_level_error = row.get("error")
                if isinstance(top_level_error, dict) and isinstance(
                    top_level_error.get("message"), str
                ):
                    error_message = top_level_error["message"]
            if error_message:
                row.setdefault("status", "error")
                row["error_message"] = error_message
            if "generated_text" not in row:
                generated_text = self._extract_generated_text(row)
                if generated_text:
                    row["generated_text"] = generated_text
            rows.append(row)

    return rows
+ try: + content = row["response"]["body"]["choices"][0]["message"]["content"] + if isinstance(content, str): + return content + except (KeyError, IndexError, TypeError): + pass + + try: + text = row["response"]["body"]["choices"][0]["text"] + if isinstance(text, str): + return text + except (KeyError, IndexError, TypeError): + pass + + try: + body_text = row["response"]["body"]["text"] + if isinstance(body_text, str): + return body_text + except (KeyError, TypeError): + pass + + return "" + + @staticmethod + def _extract_error_message(row: Dict[str, Any]) -> str: + try: + error = row["response"]["body"]["error"] + if isinstance(error, dict): + message = error.get("message") + if isinstance(message, str): + return message + except (KeyError, TypeError): + pass + + return "" + + @staticmethod + def _create_client(config: OpenAIBatchConfig) -> OpenAI: + api_key = (config.credentials.get("api_key", "").strip() or os.environ.get("OPENAI_API_KEY", "").strip() ) + + if not api_key: + raise ValueError( + "OpenAI API key is missing. Provide credentials.api_key or set OPENAI_API_KEY." + ) + + try: + client_kwargs = {"api_key": api_key} + if config.base_url: + client_kwargs["base_url"] = config.base_url + return OpenAI(**client_kwargs) + except AuthenticationError as exc: + raise ValueError(f"OpenAI authentication failed: {exc}") from exc + except Exception as exc: + raise ValueError(f"Failed to create OpenAI client: {exc}") from exc + + @staticmethod + def _extract_content_text(content_response: Any) -> str: + # Assume `content_response` is an httpx.Response (OpenAI SDK v1). + # Prefer `.text`, fallback to `.content` bytes decode. 
+ try: + text = content_response.text + except Exception: + text = None + + if isinstance(text, str): + return text + + content = getattr(content_response, "content", None) + if isinstance(content, bytes): + return content.decode("utf-8") + + logger.debug("Unable to extract content from response of type %s", type(content_response)) + raise ValueError("Unable to parse OpenAI files.content response: missing text or content bytes") + + # _read_attr removed: code now expects OpenAI SDK v1 response objects with attributes. diff --git a/src/mmirage/core/process/batch/orchestrator.py b/src/mmirage/core/process/batch/orchestrator.py new file mode 100644 index 0000000..60ad6d4 --- /dev/null +++ b/src/mmirage/core/process/batch/orchestrator.py @@ -0,0 +1,185 @@ +"""Stateful provider-agnostic orchestration for batch submission.""" + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime, timezone +import hashlib +import json +import os +from typing import Any, Dict, List, Mapping, Optional, Sequence + +from mmirage.config.batch_provider import BatchProviderConfig +from mmirage.core.process.batch.adapter import BatchSubmissionAdapter, BatchSubmissionResult +from mmirage.core.process.batch.chunking import BatchRequestChunker, RequestChunk + + +@dataclass +class _PendingRequest: + request: Mapping[str, Any] + source_index: int # original row index of the data sample within the input dataset + + +class BatchSubmissionOrchestrator: + """Accumulate requests across map iterations and submit full-ready chunks.""" + + def __init__(self, adapter: BatchSubmissionAdapter, config: BatchProviderConfig) -> None: + self.adapter = adapter + self.config = config + self.chunker = BatchRequestChunker(adapter=adapter, config=config) + self._pending: List[_PendingRequest] = [] + self._chunk_counter = 0 + + @property + def pending_count(self) -> int: + return len(self._pending) + + def add_requests( + self, + requests: Sequence[Mapping[str, Any]], + 
def _emit_ready_chunks(
    self,
    model_params_snapshot: Optional[Mapping[str, Any]],
    finalize: bool = False,
) -> List[BatchSubmissionResult]:
    """Chunk the pending requests, submit the ready chunks, keep the rest.

    All pending requests are re-chunked on every call. With ``finalize=True``
    every chunk is submitted. Otherwise every chunk except the last is
    submitted, and the last one is submitted only if
    ``_is_complete_chunk`` reports it complete; its entries stay pending
    otherwise.

    If ``adapter.submit_chunk`` raises, the entries of the failing chunk and
    of all chunks not yet attempted are put back at the front of the pending
    queue (in their original order) before the exception propagates, so a
    later call can retry them instead of losing them.

    Args:
        model_params_snapshot: Optional parameters recorded in each metadata
            receipt.
        finalize: Whether to flush the trailing partial chunk too.

    Returns:
        One parsed submission result per submitted chunk, in order.
    """
    if not self._pending:
        return []

    pending_requests = [entry.request for entry in self._pending]
    chunks = self.chunker.chunk_requests(pending_requests)
    chunk_groups = self._split_pending_entries_by_chunks(chunks)

    groups_to_submit: List[tuple[List[_PendingRequest], RequestChunk]] = []
    groups_to_keep: List[_PendingRequest] = []

    if finalize:
        groups_to_submit = list(chunk_groups)
    elif chunk_groups:
        groups_to_submit = list(chunk_groups[:-1])
        tail_entries, tail_chunk = chunk_groups[-1]
        if self._is_complete_chunk(tail_chunk):
            groups_to_submit.append((tail_entries, tail_chunk))
        else:
            groups_to_keep = list(tail_entries)

    self._pending = groups_to_keep

    submission_results: List[BatchSubmissionResult] = []
    for position, (chunk_entries, request_chunk) in enumerate(groups_to_submit):
        chunk_id = self._next_chunk_id()
        try:
            raw_result = self.adapter.submit_chunk(
                chunk_id=chunk_id,
                requests=[entry.request for entry in chunk_entries],
                config=self.config,
            )
        except Exception:
            # Re-queue everything that was not submitted, ahead of the kept tail.
            unsent = [
                entry
                for entries, _ in groups_to_submit[position:]
                for entry in entries
            ]
            self._pending = unsent + self._pending
            raise

        parsed_result = self.adapter.parse_submission_result(
            raw_result=raw_result,
        )
        submission_results.append(parsed_result)

        self._persist_metadata(
            chunk_id=chunk_id,
            chunk_entries=chunk_entries,
            chunk=request_chunk,
            parsed_result=parsed_result,
            model_params_snapshot=model_params_snapshot,
            flush_reason="finalize" if finalize else "full_chunk",
        )

    return submission_results
parsed_result.status, + "custom_id_to_source_index": custom_to_source, + "request_hash": request_hash, + "model_params_snapshot": dict(model_params_snapshot or {}), + "submitted_request_count": chunk.total_requests, + "total_bytes": chunk.total_bytes, + "has_oversized_request": chunk.has_oversized_request, + "flush_reason": flush_reason, + "submitted_at_utc": datetime.now(timezone.utc).isoformat(), + } + + metadata_path = self.config.metadata_output_path + os.makedirs(os.path.dirname(metadata_path) or ".", exist_ok=True) + with open(metadata_path, "a", encoding="utf-8") as f: + f.write(json.dumps(metadata_record, ensure_ascii=False) + "\n") diff --git a/src/mmirage/core/process/batch/provider_resolution.py b/src/mmirage/core/process/batch/provider_resolution.py new file mode 100644 index 0000000..d28042e --- /dev/null +++ b/src/mmirage/core/process/batch/provider_resolution.py @@ -0,0 +1,205 @@ +"""Resolve batch provider configs from YAML and metadata inputs. + +These helpers decouple config parsing from metadata inspection so the same +provider configuration logic can be reused by receiver and submission flows. +""" + +from __future__ import annotations + +from dataclasses import asdict +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Sequence, Type + +from mmirage.config.batch_provider import BatchProviderConfig +from mmirage.core.process.batch.metadata_utils import BatchMetadataRecord + +if TYPE_CHECKING: + from mmirage.config.config import MMirageConfig + + +class BatchProviderConfigRegistry: + """Registry for provider-specific batch config classes. + + Built-ins are registered lazily to avoid import cycles and keep config + resolution available in lightweight contexts. 
+ """ + + _registry: Dict[str, Type[BatchProviderConfig]] = {} + _bootstrapped: bool = False + + @classmethod + def _bootstrap_builtin_configs(cls) -> None: + if cls._bootstrapped: + return + + from mmirage.config.openai_batch import OpenAIBatchConfig + + cls.register("openai", OpenAIBatchConfig) + cls._bootstrapped = True + + @classmethod + def register(cls, provider: str, config_cls: Type[BatchProviderConfig]) -> None: + provider_key = provider.strip().lower() + if not provider_key: + raise ValueError("provider must be a non-empty string") + cls._registry[provider_key] = config_cls + + @classmethod + def clear(cls) -> None: + cls._registry.clear() + cls._bootstrapped = False + + @classmethod + def get_config_cls( + cls, + provider: str, + default: Type[BatchProviderConfig] | None = None, + ) -> Type[BatchProviderConfig]: + cls._bootstrap_builtin_configs() + provider_key = provider.strip().lower() + if not provider_key: + raise ValueError("provider must be a non-empty string") + if provider_key in cls._registry: + return cls._registry[provider_key] + if default is not None: + return default + raise ValueError( + f"Unknown batch provider '{provider}'. Available providers: {list(cls._registry.keys())}" + ) + + +def _discover_required_providers(metadata_records: Sequence[BatchMetadataRecord]) -> List[str]: + providers: List[str] = [] + seen = set() + for record in metadata_records: + provider = record.provider + if not provider or provider in seen: + continue + seen.add(provider) + providers.append(provider) + return providers + + +def _extract_batch_provider_blocks(cfg: MMirageConfig) -> Dict[str, Dict[str, Any]]: + """Collect raw batch_provider blocks keyed by provider. + + Raises ValueError on duplicate provider definitions to avoid ambiguous + config resolution. 
+ """ + provider_blocks: Dict[str, Dict[str, Any]] = {} + for processor_cfg in cfg.processors: + raw_provider = getattr(processor_cfg, "batch_provider", None) + if raw_provider is None: + continue + + if isinstance(raw_provider, BatchProviderConfig): + raw_block = asdict(raw_provider) + else: + raw_block = dict(raw_provider or {}) + + if not raw_block: + continue + + provider = str(raw_block.get("provider", "openai")).strip().lower() + if not provider: + continue + + if provider in provider_blocks: + raise ValueError( + f"Duplicate batch_provider blocks found for provider '{provider}' in config processors." + ) + + provider_blocks[provider] = raw_block + + return provider_blocks + + +def _instantiate_provider_config(provider: str, raw_block: Mapping[str, Any]) -> BatchProviderConfig: + """Instantiate the provider config, falling back to the shared base config.""" + payload = dict(raw_block) + payload.setdefault("provider", provider) + + config_cls = BatchProviderConfigRegistry.get_config_cls( + provider, + default=BatchProviderConfig, + ) + return config_cls(**payload) + + +def resolve_single_provider_config(raw_block: Mapping[str, Any]) -> BatchProviderConfig: + """Resolve a single provider config from a raw batch_provider block. + + Defaults to the OpenAI provider for backward compatibility and raises + ValueError for unknown providers or invalid config payloads. 
+ """ + payload = dict(raw_block or {}) + provider = str(payload.get("provider", "openai")).strip().lower() + if not provider: + provider = "openai" + payload["provider"] = provider + + try: + BatchProviderConfigRegistry.get_config_cls(provider) + except ValueError as exc: + raise ValueError(str(exc)) from exc + + try: + return _instantiate_provider_config(provider, payload) + except Exception as exc: + raise ValueError( + f"Failed to instantiate batch provider config for '{provider}': {exc}" + ) from exc + + +def build_all_provider_configs(cfg: "MMirageConfig") -> Dict[str, BatchProviderConfig]: + """Build provider configs for every batch_provider block in the YAML. + + Raises ValueError when any provider config fails to instantiate. + """ + provider_blocks = _extract_batch_provider_blocks(cfg) + if not provider_blocks: + return {} + + resolved: Dict[str, BatchProviderConfig] = {} + for provider, raw_block in provider_blocks.items(): + try: + resolved[provider] = _instantiate_provider_config(provider, raw_block) + except Exception as exc: + raise ValueError( + f"Failed to instantiate batch provider config for '{provider}': {exc}" + ) from exc + + return resolved + + +def resolve_provider_configs( + metadata_records: Sequence[BatchMetadataRecord], + cfg: "MMirageConfig", +) -> Dict[str, BatchProviderConfig]: + """Resolve provider configs required by receiver metadata. + + Args: + metadata_records: Parsed metadata JSONL records used to discover which + providers are required by the receiver command. + cfg: Loaded YAML config object from ``load_mmirage_config``. + + Returns: + Mapping from normalized provider name to instantiated provider config. + + Raises: + ValueError: If metadata references a provider missing from config + processor ``batch_provider`` blocks or if provider config + instantiation fails. 
+ """ + available_configs = build_all_provider_configs(cfg) + required_providers = _discover_required_providers(metadata_records) + if not required_providers: + return {} + + missing = [provider for provider in required_providers if provider not in available_configs] + if missing: + raise ValueError( + "Metadata references provider(s) missing from YAML batch_provider config: " + f"{missing}. Check cfg.processors[*].batch_provider." + ) + + return {provider: available_configs[provider] for provider in required_providers} diff --git a/src/mmirage/core/process/batch/registry.py b/src/mmirage/core/process/batch/registry.py new file mode 100644 index 0000000..862585f --- /dev/null +++ b/src/mmirage/core/process/batch/registry.py @@ -0,0 +1,92 @@ +"""Registry and factory for provider batch adapters.""" + +import os +from typing import Dict, Type + +from mmirage.config.batch_provider import BatchProviderConfig +from mmirage.core.process.batch.adapter import BatchSubmissionAdapter + + +class BatchAdapterRegistry: + """Provider-to-adapter registry with factory helpers. + + This class centralizes provider registration and fail-fast adapter + instantiation with credential validation. + """ + + _registry: Dict[str, Type[BatchSubmissionAdapter]] = dict() + _bootstrapped: bool = False + + @classmethod + def _bootstrap_builtin_adapters(cls) -> None: + if cls._bootstrapped: + return + + # Local import avoids import cycles while ensuring built-ins are available. 
+ from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter + + cls.register("openai", OpenAIBatchAdapter) + cls._bootstrapped = True + + @classmethod + def register(cls, provider: str, adapter_cls: Type[BatchSubmissionAdapter]) -> None: + """Register an adapter class under a provider key.""" + provider_key = provider.strip().lower() + if not provider_key: + raise ValueError("provider must be a non-empty string") + cls._registry[provider_key] = adapter_cls + + @classmethod + def clear(cls) -> None: + """Clear all registered adapters. + + Intended for tests and isolated bootstrapping logic. + """ + cls._registry.clear() + cls._bootstrapped = False + + @classmethod + def resolve(cls, provider: str) -> Type[BatchSubmissionAdapter]: + """Resolve a provider key to a registered adapter class.""" + cls._bootstrap_builtin_adapters() + provider_key = provider.strip().lower() + if provider_key not in cls._registry: + raise ValueError( + f"Unknown batch provider '{provider}'. " + f"Available providers: {list(cls._registry.keys())}" + ) + return cls._registry[provider_key] + + @classmethod + def create(cls, config: BatchProviderConfig) -> BatchSubmissionAdapter: + """Instantiate an adapter for a provider config with credential checks.""" + adapter_cls = cls.resolve(config.provider) + + missing_credentials = [] + for req_key in adapter_cls.required_credentials: + credential_value = (config.credentials.get(req_key, "") or "").strip() + if credential_value: + continue + + env_var = f"{config.provider.upper()}_{req_key.upper()}" + env_value = (os.environ.get(env_var, "") or "").strip() + if env_value: + config.credentials[req_key] = env_value + continue + + missing_credentials.append(req_key) + + if missing_credentials: + raise ValueError( + f"Missing credentials for provider '{config.provider}': {missing_credentials}" + ) + return adapter_cls() + + +class BatchAdapterFactory: + """Compatibility alias around registry-based adapter creation.""" + + @classmethod + def 
def extract_unique_provider_batches(
    metadata_records: Sequence[BatchMetadataRecord],
) -> List[Tuple[str, str]]:
    """Return unique ``(provider, provider_batch_id)`` pairs in first-seen order.

    Provider names are stripped and lower-cased, and records missing either
    the provider or the batch id are ignored, so callers never issue a
    provider call with incomplete metadata.

    Args:
        metadata_records: Records exposing ``provider`` and
            ``provider_batch_id`` attributes.

    Returns:
        De-duplicated pairs, ordered by first occurrence.
    """
    unique_pairs: List[Tuple[str, str]] = []
    seen = set()

    for record in metadata_records:
        provider = str(record.provider or "").strip().lower()
        provider_batch_id = str(record.provider_batch_id or "").strip()
        if not provider or not provider_batch_id:
            # Incomplete receipt: nothing to look up.
            continue

        pair = (provider, provider_batch_id)
        if pair in seen:
            continue
        seen.add(pair)
        unique_pairs.append(pair)

    return unique_pairs
for the status-check entry point.""" + parser = argparse.ArgumentParser(description="Check provider batch statuses from metadata receipts.") + parser.add_argument( + "--metadata-path", + nargs="+", + help=( + "Path(s) to metadata JSONL receipt file(s). Supports multiple files. " + "When omitted, uses metadata_output_path from the config batch_provider blocks " + "and resolves suffixed receipts like '.text..jsonl'." + ), + ) + parser.add_argument( + "--config", + required=True, + help="Path to the YAML configuration file", + ) + return parser + + +def main(argv: Sequence[str] | None = None) -> int: + """CLI entry point that returns a process-style status code. + + Returns 0 on success or no batches found, and 1 on configuration or + provider resolution failures. + """ + args = _build_arg_parser().parse_args(argv) + from mmirage.config.utils import load_mmirage_config + + try: + cfg = load_mmirage_config(args.config) + if args.metadata_path: + metadata_paths = args.metadata_path + else: + all_provider_configs = build_all_provider_configs(cfg) + metadata_paths = [ + config.metadata_output_path + for config in all_provider_configs.values() + if config.metadata_output_path + ] + metadata_paths = list(dict.fromkeys(metadata_paths)) + if not metadata_paths: + logger.error("No metadata paths provided and none found in config batch_provider blocks.") + return 1 + metadata_paths = resolve_metadata_paths_from_config(metadata_paths) + + if not metadata_paths: + logger.error("No metadata paths provided and none found in config batch_provider blocks.") + return 1 + + records = _read_metadata_records(metadata_paths) + pairs = extract_unique_provider_batches(records) + if not pairs: + logger.info(f"No provider batch IDs found in metadata file(s): {metadata_paths}") + return 0 + + provider_configs = resolve_provider_configs(records, cfg) + if not provider_configs: + logger.error("No supported provider configurations could be built from metadata.") + return 1 + run_status_checker( + 
def finalize_processors(self) -> None:
    """Finalize processors that expose a finalize lifecycle hook.

    Iterates ``self.processors.values()`` and calls ``finalize()`` on each
    processor that has a callable ``finalize`` attribute; processors without
    one are skipped, mirroring the ``hasattr`` guard used by
    ``get_load_time``.
    """
    for processor in self.processors.values():
        finalize_hook = getattr(processor, "finalize", None)
        if callable(finalize_hook):
            finalize_hook()
""" server_args: SGLangServerArgs = field(default_factory=SGLangServerArgs) default_sampling_params: Dict[str, Any] = field(default_factory=dict) chat_template: str = "" # Empty means use tokenizer's default + batch_provider: Optional[BatchProviderConfig] = None @dataclass diff --git a/src/mmirage/core/process/processors/llm/llm_processor.py b/src/mmirage/core/process/processors/llm/llm_processor.py index 19cae67..c17c68f 100644 --- a/src/mmirage/core/process/processors/llm/llm_processor.py +++ b/src/mmirage/core/process/processors/llm/llm_processor.py @@ -2,11 +2,12 @@ from __future__ import annotations -from dataclasses import asdict +from dataclasses import asdict, replace import json import logging import time -from typing import Any, List, Tuple +from typing import Any, Dict, List, Optional, Tuple +import uuid import jinja2 try: @@ -17,6 +18,8 @@ from transformers import AutoTokenizer from mmirage.core.process.base import BaseProcessor, ProcessorRegistry, TokenCounts +from mmirage.core.process.batch.orchestrator import BatchSubmissionOrchestrator +from mmirage.core.process.batch.registry import BatchAdapterFactory from mmirage.core.process.processors.llm.config import LLMOutputVar, SGLangLLMConfig from mmirage.core.process.variables import VariableEnvironment @@ -63,24 +66,93 @@ def __init__(self, engine_args: SGLangLLMConfig, **kwargs) -> None: **kwargs: Additional arguments passed to base class. """ super().__init__(engine_args, **kwargs) - if not SGLANG_AVAILABLE: - raise RuntimeError( - "SGLang is not installed. Install with: pip install 'mmirage[gpu]' " - "or, from a source checkout, pip install -e '.[gpu]'" + + batch_provider_cfg = engine_args.batch_provider + is_provider_batch_enabled = bool(batch_provider_cfg and batch_provider_cfg.enabled) + self._model_load_seconds: float = 0.0 + + # In provider-batch mode we only build payloads/metadata and should not + # initialize GPU-backed SGLang runtime. 
+ if is_provider_batch_enabled: + self.llm = None + self.tokenizer = None + else: + if not SGLANG_AVAILABLE: + raise RuntimeError( + "SGLang is not installed. Install with: pip install 'mmirage[gpu]' " + "or, from a source checkout, pip install -e '.[gpu]'" + ) + + server_kwargs = asdict(engine_args.server_args) + extra = server_kwargs.pop("extra_engine_args", {}) or {} + server_kwargs.update(extra) + _load_start = time.monotonic() + self.llm = sgl.Engine(**server_kwargs) + self._model_load_seconds = time.monotonic() - _load_start + self.tokenizer = AutoTokenizer.from_pretrained( + engine_args.server_args.model_path, + trust_remote_code=getattr(engine_args.server_args, "trust_remote_code", False), ) - server_kwargs = asdict(engine_args.server_args) - extra = server_kwargs.pop("extra_engine_args", {}) or {} - server_kwargs.update(extra) - _load_start = time.monotonic() - self.llm = sgl.Engine(**server_kwargs) - self._model_load_seconds: float = time.monotonic() - _load_start - self.tokenizer = AutoTokenizer.from_pretrained( - engine_args.server_args.model_path, - trust_remote_code=getattr(engine_args.server_args, "trust_remote_code", False), - ) self.sampling_params = engine_args.default_sampling_params self.chat_template = engine_args.chat_template + self._batch_adapter = None + self._batch_provider_config = None + self._text_orchestrator: Optional[BatchSubmissionOrchestrator] = None + self._multimodal_orchestrator: Optional[BatchSubmissionOrchestrator] = None + self._batch_request_counter = 0 + self._global_row_offset = 0 + self._setup_batch_runtime() + + # Cumulative token counts across all generate() calls in this processor's lifetime. 
+ self._total_input_tokens: int = 0 + self._total_output_tokens: int = 0 + + def _setup_batch_runtime(self) -> None: + provider_cfg = self.config.batch_provider + if provider_cfg is None: + return + + if not provider_cfg.enabled: + return + + self._batch_provider_config = provider_cfg + self._batch_adapter = BatchAdapterFactory.from_config(provider_cfg) + run_id = uuid.uuid4().hex[:6] + + self._text_orchestrator = BatchSubmissionOrchestrator( + adapter=self._batch_adapter, + config=replace( + provider_cfg, + metadata_output_path=self._with_metadata_suffix( + provider_cfg.metadata_output_path, "text", run_id + ), + ), + ) + self._multimodal_orchestrator = BatchSubmissionOrchestrator( + adapter=self._batch_adapter, + config=replace( + provider_cfg, + metadata_output_path=self._with_metadata_suffix( + provider_cfg.metadata_output_path, "multimodal", run_id + ), + ), + ) + + @staticmethod + def _with_metadata_suffix(path: str, suffix: str, run_id: str) -> str: + if not path: + return "" + base_path = path.removesuffix(".jsonl") + return f"{base_path}.{suffix}.{run_id}.jsonl" + + @property + def batch_mode_enabled(self) -> bool: + return self._text_orchestrator is not None and self._multimodal_orchestrator is not None + + def _next_custom_id(self, output_name: str, modality: str) -> str: + self._batch_request_counter += 1 + return f"{output_name}:{modality}:{self._batch_request_counter}" # Cumulative token counts across all generate() calls in this processor's lifetime. 
self._total_input_tokens: int = 0 self._total_output_tokens: int = 0 @@ -189,6 +261,9 @@ def batch_process_sample( """ nb_samples = len(batch) + if self.batch_mode_enabled: + return self._batch_process_sample(batch=batch, output_var=output_var) + # Prepare sampling params sampling_params_output = self.sampling_params.copy() @@ -316,8 +391,148 @@ def batch_process_sample( return [results[i] for i in range(nb_samples)] + def _batch_process_sample( + self, + batch: List[VariableEnvironment], + output_var: LLMOutputVar, + ) -> List[VariableEnvironment]: + assert self._batch_provider_config is not None + assert self._batch_adapter is not None + assert self._text_orchestrator is not None + assert self._multimodal_orchestrator is not None + + nb_samples = len(batch) + text_only_indices: List[int] = [] + multimodal_indices: List[int] = [] + index_to_custom_id: Dict[int, str] = {} + for i in range(nb_samples): + if batch[i].has_images(): + multimodal_indices.append(i) + else: + text_only_indices.append(i) + + if text_only_indices: + jinja_template = jinja2.Template(output_var.prompt) + requests: List[Dict[str, Any]] = [] + source_indices: List[int] = [] + for global_i in text_only_indices: + base_prompt = jinja_template.render(**batch[global_i].to_dict()) + payload = { + "messages": [ + { + "role": "user", + "content": base_prompt, + } + ] + } + if output_var.output_type == "JSON" and output_var.output_schema: + payload["expected_schema"] = list(output_var.output_schema) + custom_id = self._next_custom_id(output_var.name, "text") + index_to_custom_id[global_i] = custom_id + request = self._batch_adapter.build_request( + custom_id=custom_id, + payload=payload, + config=self._batch_provider_config, + ) + requests.append(request) + source_indices.append(self._global_row_offset + global_i) + + self._text_orchestrator.add_requests( + requests=requests, + source_indices=source_indices, + model_params_snapshot={ + "output_name": output_var.name, + "output_type": 
output_var.output_type, + "modality": "text", + }, + ) + + if multimodal_indices: + requests = [] + source_indices = [] + for global_i in multimodal_indices: + base_prompt, image_data = self.build_multimodal_prompt(output_var.prompt, batch[global_i]) + content: List[Dict[str, Any]] = [{"type": "text", "text": base_prompt}] + + if image_data is not None: + if isinstance(image_data, list): + images = image_data + else: + images = [image_data] + for image_ref in images: + content.append( + { + "type": "image_url", + "image_url": {"url": str(image_ref)}, + } + ) + + payload = { + "messages": [ + { + "role": "user", + "content": content, + } + ] + } + if output_var.output_type == "JSON" and output_var.output_schema: + payload["expected_schema"] = list(output_var.output_schema) + + custom_id = self._next_custom_id(output_var.name, "multimodal") + index_to_custom_id[global_i] = custom_id + request = self._batch_adapter.build_request( + custom_id=custom_id, + payload=payload, + config=self._batch_provider_config, + ) + requests.append(dict(request)) + source_indices.append(self._global_row_offset + global_i) + + self._multimodal_orchestrator.add_requests( + requests=requests, + source_indices=source_indices, + model_params_snapshot={ + "output_name": output_var.name, + "output_type": output_var.output_type, + "modality": "multimodal", + }, + ) + + placeholders: List[VariableEnvironment] = [] + for i in range(nb_samples): + unique_id = index_to_custom_id.get(i, f"unknown:{i}") + placeholder = f"__BATCH_SUBMITTED__:{unique_id}" + placeholders.append(batch[i].with_variable(output_var.name, placeholder)) + + self._global_row_offset += nb_samples + + return placeholders + + def finalize(self) -> None: + if not self.batch_mode_enabled: + return + + assert self._text_orchestrator is not None + assert self._multimodal_orchestrator is not None + + self._text_orchestrator.finalize( + model_params_snapshot={ + "modality": "text", + "phase": "finalize", + } + ) + 
self._multimodal_orchestrator.finalize( + model_params_snapshot={ + "modality": "multimodal", + "phase": "finalize", + } + ) + def shutdown(self) -> None: """Shutdown the LLM engine.""" + if self.llm is None: + return + try: self.llm.shutdown() except Exception as e: diff --git a/src/mmirage/shard_process.py b/src/mmirage/shard_process.py index d12e07d..30f427b 100644 --- a/src/mmirage/shard_process.py +++ b/src/mmirage/shard_process.py @@ -15,6 +15,7 @@ from mmirage.core.loader.utils import load_datasets_from_configs from mmirage.core.process.mapper import MMIRAGEMapper from mmirage.core.writer.renderer import TemplateRenderer + from mmirage.shard_utils import ( GpuUtilizationPoller, ShardStats, @@ -177,6 +178,8 @@ def main(): }, remove_columns=remove_columns, ) + # Drain stateful batch accumulators once this dataset map iteration finishes. + mapper.finalize_processors() ds_processed_all.append(ds_processed) for ds_idx, (ds_config, ds_processed) in enumerate(zip(datasets_config, ds_processed_all)): diff --git a/tests/mock_data_vision/data.jsonl b/tests/mock_data_vision/data.jsonl index 0b29b01..7f67b97 100644 --- a/tests/mock_data_vision/data.jsonl +++ b/tests/mock_data_vision/data.jsonl @@ -8,5 +8,4 @@ {"image": "beach.jpg"} {"image": "cat.jpg"} {"image": "dog.jpg"} -{"image": "mountain.jpg"} -{"image": "beach.jpg"} + diff --git a/tests/test_batch_adapter_contracts.py b/tests/test_batch_adapter_contracts.py new file mode 100644 index 0000000..85890ef --- /dev/null +++ b/tests/test_batch_adapter_contracts.py @@ -0,0 +1,198 @@ +from dataclasses import dataclass + +import pytest + +from mmirage.config.batch_provider import BatchProviderConfig +from mmirage.config.openai_batch import OpenAIBatchConfig +from mmirage.core.process.batch.adapter import BatchSubmissionAdapter, BatchSubmissionResult +from mmirage.core.process.batch.provider_resolution import ( + BatchProviderConfigRegistry, + resolve_single_provider_config, +) +from mmirage.core.process.batch.registry 
import BatchAdapterFactory, BatchAdapterRegistry + + +class CompleteTestAdapter(BatchSubmissionAdapter): + required_credentials = tuple() + + @property + def adapter_name(self) -> str: + return "complete-test-adapter" + + @property + def adapter_version(self) -> str: + return "1.0.0" + + def build_request(self, custom_id, payload, config): + return {"custom_id": custom_id, "payload": dict(payload), "provider": config.provider} + + def estimate_request_bytes(self, request): + # Deterministic approximation for tests. + return len(str(request).encode("utf-8")) + + def submit_chunk(self, chunk_id, requests, config): + return { + "batch_id": f"{config.provider}-{chunk_id}", + "status": "submitted", + "requests": len(requests), + } + + def parse_submission_result(self, raw_result): + return BatchSubmissionResult( + provider_batch_id=str(raw_result["batch_id"]), + status=str(raw_result["status"]), + raw_response=raw_result, + ) + + def check_batch_status(self, provider_batch_id, config): + return BatchSubmissionResult( + provider_batch_id=provider_batch_id, + status="submitted", + raw_response={"id": provider_batch_id, "status": "submitted"}, + ) + + def retrieve_results(self, provider_batch_id, config): + return [] + + +class CredentialedTestAdapter(CompleteTestAdapter): + required_credentials = ("api_key",) + + +class IncompleteTestAdapter(BatchSubmissionAdapter): + @property + def adapter_name(self) -> str: + return "incomplete" + + @property + def adapter_version(self) -> str: + return "0.0.0" + + def build_request(self, custom_id, payload, config): + return {} + + def estimate_request_bytes(self, request): + return 0 + + def submit_chunk(self, chunk_id, requests, config): + return {} + + +@pytest.fixture(autouse=True) +def clear_batch_adapter_registry(): + BatchAdapterRegistry.clear() + yield + BatchAdapterRegistry.clear() + + +@pytest.fixture(autouse=True) +def clear_batch_provider_registry(): + BatchProviderConfigRegistry.clear() + yield + 
BatchProviderConfigRegistry.clear() + + +def test_adapter_interface_is_abstract(): + with pytest.raises(TypeError): + BatchSubmissionAdapter() + + +def test_incomplete_adapter_fails_interface_compliance(): + with pytest.raises(TypeError): + IncompleteTestAdapter() + + +def test_complete_adapter_is_interface_compliant(): + adapter = CompleteTestAdapter() + config = BatchProviderConfig(provider="unit") + + request = adapter.build_request(custom_id="req-1", payload={"x": 1}, config=config) + assert request["custom_id"] == "req-1" + + estimated_bytes = adapter.estimate_request_bytes(request) + assert estimated_bytes > 0 + + raw_result = adapter.submit_chunk(chunk_id="chunk-1", requests=[request], config=config) + parsed = adapter.parse_submission_result(raw_result=raw_result) + + assert parsed.provider_batch_id == "unit-chunk-1" + assert parsed.status == "submitted" + + +def test_factory_resolves_registered_provider(): + BatchAdapterRegistry.register("unit", CompleteTestAdapter) + config = BatchProviderConfig(provider="unit") + + adapter = BatchAdapterFactory.from_config(config) + + assert isinstance(adapter, CompleteTestAdapter) + + +def test_factory_raises_for_unknown_provider(): + config = BatchProviderConfig(provider="not-registered") + + with pytest.raises(ValueError, match="Unknown batch provider"): + BatchAdapterFactory.from_config(config) + + +def test_factory_raises_for_missing_required_credentials(): + BatchAdapterRegistry.register("unit", CredentialedTestAdapter) + config = BatchProviderConfig(provider="unit", credentials={}) + + with pytest.raises(ValueError, match="Missing credentials"): + BatchAdapterFactory.from_config(config) + + +def test_factory_creates_adapter_when_credentials_are_present(): + BatchAdapterRegistry.register("unit", CredentialedTestAdapter) + config = BatchProviderConfig(provider="unit", credentials={"api_key": "secret"}) + + adapter = BatchAdapterFactory.from_config(config) + + assert isinstance(adapter, CredentialedTestAdapter) + + 
+def test_factory_resolves_missing_credentials_from_environment(monkeypatch): + BatchAdapterRegistry.register("unit", CredentialedTestAdapter) + monkeypatch.setenv("UNIT_API_KEY", "from-env") + config = BatchProviderConfig(provider="unit", credentials={}) + + adapter = BatchAdapterFactory.from_config(config) + + assert isinstance(adapter, CredentialedTestAdapter) + assert config.credentials["api_key"] == "from-env" + + +@dataclass +class UnitBatchConfig(BatchProviderConfig): + provider: str = "unit" + unit_setting: str = "default" + + def __post_init__(self) -> None: + super().__post_init__() + if not self.unit_setting.strip(): + raise ValueError("unit_setting must be a non-empty string") + + +def test_resolve_single_provider_config_defaults_to_openai(): + config = resolve_single_provider_config({}) + + assert isinstance(config, OpenAIBatchConfig) + assert config.provider == "openai" + + +def test_resolve_single_provider_config_resolves_custom_provider(): + BatchProviderConfigRegistry.register("unit", UnitBatchConfig) + + config = resolve_single_provider_config( + {"provider": "unit", "unit_setting": "custom"} + ) + + assert isinstance(config, UnitBatchConfig) + assert config.provider == "unit" + assert config.unit_setting == "custom" + + +def test_resolve_single_provider_config_raises_for_unknown_provider(): + with pytest.raises(ValueError, match="Unknown batch provider"): + resolve_single_provider_config({"provider": "not-registered"}) diff --git a/tests/test_batch_chunking.py b/tests/test_batch_chunking.py new file mode 100644 index 0000000..fcceb8c --- /dev/null +++ b/tests/test_batch_chunking.py @@ -0,0 +1,140 @@ +import pytest + +from mmirage.config.batch_provider import BatchProviderConfig +from mmirage.core.process.batch.adapter import BatchSubmissionAdapter, BatchSubmissionResult + + +class SizeAwareTestAdapter(BatchSubmissionAdapter): + def __init__(self) -> None: + self.estimate_calls = [] + + @property + def adapter_name(self) -> str: + return 
"size-aware-test-adapter" + + @property + def adapter_version(self) -> str: + return "1.0.0" + + def build_request(self, custom_id, payload, config): + return {"custom_id": custom_id, **dict(payload)} + + def estimate_request_bytes(self, request): + size = int(request["size_bytes"]) + self.estimate_calls.append(size) + return size + + def submit_chunk(self, chunk_id, requests, config): + return {"id": chunk_id, "status": "submitted"} + + def parse_submission_result(self, raw_result): + return BatchSubmissionResult( + provider_batch_id=str(raw_result["id"]), + status=str(raw_result["status"]), + raw_response=raw_result, + ) + + def check_batch_status(self, provider_batch_id, config): + return BatchSubmissionResult( + provider_batch_id=provider_batch_id, + status="submitted", + raw_response={"id": provider_batch_id, "status": "submitted"}, + ) + + def retrieve_results(self, provider_batch_id, config): + return [] + + +def _sizes_from_chunks(chunks): + return [[request["size_bytes"] for request in chunk.requests] for chunk in chunks] + + +def test_chunker_splits_when_byte_limit_is_reached(): + from mmirage.core.process.batch.chunking import BatchRequestChunker + + adapter = SizeAwareTestAdapter() + config = BatchProviderConfig(provider="unit", max_chunk_bytes=10) + requests = [ + {"custom_id": "r1", "size_bytes": 4}, + {"custom_id": "r2", "size_bytes": 4}, + {"custom_id": "r3", "size_bytes": 4}, + ] + + chunks = BatchRequestChunker(adapter, config).chunk_requests(requests) + + assert _sizes_from_chunks(chunks) == [[4, 4], [4]] + assert [chunk.total_bytes for chunk in chunks] == [8, 4] + assert adapter.estimate_calls == [4, 4, 4] + + +def test_chunker_splits_when_max_requests_per_chunk_is_reached(): + from mmirage.core.process.batch.chunking import BatchRequestChunker + + adapter = SizeAwareTestAdapter() + config = BatchProviderConfig( + provider="unit", + max_chunk_bytes=10_000, + max_requests_per_chunk=2, + ) + requests = [ + {"custom_id": "r1", "size_bytes": 1}, + 
{"custom_id": "r2", "size_bytes": 1}, + {"custom_id": "r3", "size_bytes": 1}, + {"custom_id": "r4", "size_bytes": 1}, + {"custom_id": "r5", "size_bytes": 1}, + ] + + chunks = BatchRequestChunker(adapter, config).chunk_requests(requests) + + assert _sizes_from_chunks(chunks) == [[1, 1], [1, 1], [1]] + assert [chunk.total_requests for chunk in chunks] == [2, 2, 1] + + +def test_chunker_honors_exact_byte_boundary_without_flushing_early(): + from mmirage.core.process.batch.chunking import BatchRequestChunker + + adapter = SizeAwareTestAdapter() + config = BatchProviderConfig(provider="unit", max_chunk_bytes=10) + requests = [ + {"custom_id": "r1", "size_bytes": 6}, + {"custom_id": "r2", "size_bytes": 4}, + {"custom_id": "r3", "size_bytes": 1}, + ] + + chunks = BatchRequestChunker(adapter, config).chunk_requests(requests) + + assert _sizes_from_chunks(chunks) == [[6, 4], [1]] + assert [chunk.total_bytes for chunk in chunks] == [10, 1] + + +def test_chunker_isolates_oversized_single_request_by_default(caplog): + from mmirage.core.process.batch.chunking import BatchRequestChunker + + adapter = SizeAwareTestAdapter() + config = BatchProviderConfig(provider="unit", max_chunk_bytes=10) + requests = [ + {"custom_id": "r1", "size_bytes": 3}, + {"custom_id": "r2", "size_bytes": 25}, + {"custom_id": "r3", "size_bytes": 3}, + ] + + chunks = BatchRequestChunker(adapter, config).chunk_requests(requests) + + assert _sizes_from_chunks(chunks) == [[3], [25], [3]] + assert [chunk.has_oversized_request for chunk in chunks] == [False, True, False] + assert "oversized request" in caplog.text.lower() + + +def test_chunker_rejects_oversized_single_request_when_policy_is_reject(): + from mmirage.core.process.batch.chunking import BatchRequestChunker + + adapter = SizeAwareTestAdapter() + config = BatchProviderConfig( + provider="unit", + max_chunk_bytes=10, + oversized_request_policy="reject", + ) + requests = [{"custom_id": "r1", "size_bytes": 11}] + + with pytest.raises(ValueError, 
match="oversized request"): + BatchRequestChunker(adapter, config).chunk_requests(requests) diff --git a/tests/test_batch_collector.py b/tests/test_batch_collector.py new file mode 100644 index 0000000..fab738b --- /dev/null +++ b/tests/test_batch_collector.py @@ -0,0 +1,706 @@ +import json +from types import SimpleNamespace + +import pytest +from mmirage.config.openai_batch import OpenAIBatchConfig + + +def test_collect_and_merge_reconstructs_rows_deterministically(tmp_path, monkeypatch, caplog): + from mmirage.core.process.batch.collector import _read_metadata_records, collect_and_merge + + metadata_path = tmp_path / "receipts.jsonl" + metadata_path.write_text( + "\n".join( + [ + json.dumps( + { + "provider": "openai", + "provider_batch_id": "batch_1", + "custom_id_to_source_index": {"c1": 2, "c2": 0}, + } + ), + json.dumps( + { + "provider": "openai", + "provider_batch_id": "batch_1", + "custom_id_to_source_index": {"c1": 2, "c2": 0}, + } + ), + "malformed-line", + json.dumps( + { + "provider": "openai", + "provider_batch_id": "batch_2", + "custom_id_to_source_index": {"c3": 1}, + } + ), + ] + ), + encoding="utf-8", + ) + + output_path = tmp_path / "merged.jsonl" + + class FakeAdapter: + def __init__(self): + self.calls = [] + + def retrieve_results(self, provider_batch_id, config): + self.calls.append((provider_batch_id, config.provider)) + if provider_batch_id == "batch_1": + return [ + { + "custom_id": "c1", + "generated_text": '{"question":"q2","answer":"a2"}', + }, + { + "custom_id": "c2", + "generated_text": '{"question":"q0","answer":"a0"}', + }, + ] + return [ + { + "custom_id": "c3", + "generated_text": '{"question":"q1","answer":"a1"}', + } + ] + + fake_adapter = FakeAdapter() + monkeypatch.setattr( + "mmirage.core.process.batch.collector.BatchAdapterFactory.from_config", + lambda config: fake_adapter, + ) + + provider_configs = {"openai": OpenAIBatchConfig(credentials={"api_key": "k"})} + records = _read_metadata_records(str(metadata_path)) + assert 
"Skipping malformed metadata JSON line" in caplog.text + rows = collect_and_merge( + records=records, + provider_configs=provider_configs, + output_path=str(output_path), + ) + + assert [r["source_index"] for r in rows] == [0, 1, 2] + assert [r["custom_id"] for r in rows] == ["c2", "c3", "c1"] + assert [r["conversations"][0]["content"] for r in rows] == ["q0", "q1", "q2"] + assert [r["conversations"][1]["content"] for r in rows] == ["a0", "a1", "a2"] + assert fake_adapter.calls == [("batch_1", "openai"), ("batch_2", "openai")] + + written = [json.loads(line) for line in output_path.read_text(encoding="utf-8").splitlines()] + assert [r["source_index"] for r in written] == [0, 1, 2] + assert [r["conversations"][0]["content"] for r in written] == ["q0", "q1", "q2"] + + +def test_collect_and_merge_raises_for_missing_provider_config(tmp_path): + from mmirage.core.process.batch.collector import _read_metadata_records, collect_and_merge + + metadata_path = tmp_path / "receipts.jsonl" + metadata_path.write_text( + json.dumps( + { + "provider": "openai", + "provider_batch_id": "batch_1", + "custom_id_to_source_index": {"c1": 0}, + } + ) + + "\n", + encoding="utf-8", + ) + + try: + records = _read_metadata_records(str(metadata_path)) + collect_and_merge( + records=records, + provider_configs={}, + output_path=str(tmp_path / "out.jsonl"), + ) + assert False, "Expected ValueError" + except ValueError as e: + assert "No provider config" in str(e) + + +def test_collect_and_merge_outputs_caption_for_plain_text_content(tmp_path, monkeypatch): + from mmirage.core.process.batch.collector import _read_metadata_records, collect_and_merge + + metadata_path = tmp_path / "receipts.jsonl" + metadata_path.write_text( + json.dumps( + { + "provider": "openai", + "provider_batch_id": "batch_plain", + "custom_id_to_source_index": {"img_1": 0}, + } + ) + + "\n", + encoding="utf-8", + ) + + output_path = tmp_path / "merged_plain.jsonl" + + class FakeAdapter: + def retrieve_results(self, 
provider_batch_id, config): + return [ + { + "custom_id": "img_1", + "generated_text": "A black cat sitting on a sofa.", + } + ] + + monkeypatch.setattr( + "mmirage.core.process.batch.collector.BatchAdapterFactory.from_config", + lambda config: FakeAdapter(), + ) + + records = _read_metadata_records(str(metadata_path)) + rows = collect_and_merge( + records=records, + provider_configs={"openai": OpenAIBatchConfig(credentials={"api_key": "k"})}, + output_path=str(output_path), + ) + + assert rows == [ + { + "source_index": 0, + "custom_id": "img_1", + "caption": "A black cat sitting on a sofa.", + } + ] + + +def test_collect_and_merge_keeps_rows_with_duplicate_custom_ids_across_batches( + tmp_path, monkeypatch +): + from mmirage.core.process.batch.collector import _read_metadata_records, collect_and_merge + + metadata_path = tmp_path / "receipts.jsonl" + metadata_path.write_text( + "\n".join( + [ + json.dumps( + { + "provider": "openai", + "provider_batch_id": "batch_openai", + "custom_id_to_source_index": {"shared": 0}, + } + ), + json.dumps( + { + "provider": "unit", + "provider_batch_id": "batch_unit", + "custom_id_to_source_index": {"shared": 1}, + } + ), + ] + ) + + "\n", + encoding="utf-8", + ) + + output_path = tmp_path / "merged_duplicates.jsonl" + + class OpenAIAdapter: + def retrieve_results(self, provider_batch_id, config): + return [{"custom_id": "shared", "generated_text": "openai"}] + + class UnitAdapter: + def retrieve_results(self, provider_batch_id, config): + return [{"custom_id": "shared", "generated_text": "unit"}] + + adapters = { + "openai": OpenAIAdapter(), + "unit": UnitAdapter(), + } + + monkeypatch.setattr( + "mmirage.core.process.batch.collector.BatchAdapterFactory.from_config", + lambda config: adapters[config.provider], + ) + + records = _read_metadata_records(str(metadata_path)) + rows = collect_and_merge( + records=records, + provider_configs={ + "openai": SimpleNamespace(provider="openai"), + "unit": SimpleNamespace(provider="unit"), + 
}, + output_path=str(output_path), + ) + + assert [row["source_index"] for row in rows] == [0, 1] + assert [row["custom_id"] for row in rows] == ["shared", "shared"] + assert [row["caption"] for row in rows] == ["openai", "unit"] + + written = [json.loads(line) for line in output_path.read_text(encoding="utf-8").splitlines()] + assert written == rows + + +def test_collect_and_merge_uses_openai_adapter_generated_text(tmp_path, monkeypatch): + from mmirage.core.process.batch.collector import _read_metadata_records, collect_and_merge + from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter + + metadata_path = tmp_path / "receipts.jsonl" + metadata_path.write_text( + json.dumps( + { + "provider": "openai", + "provider_batch_id": "batch_openai", + "custom_id_to_source_index": {"o1": 0}, + } + ) + + "\n", + encoding="utf-8", + ) + + output_path = tmp_path / "merged_openai.jsonl" + + class FakeBatches: + def retrieve(self, provider_batch_id): + class _RetrieveResp: + id = provider_batch_id + status = "completed" + output_file_id = "file_output_1" + + return _RetrieveResp() + + class FakeFiles: + def content(self, output_file_id): + class _ContentResp: + text = ( + '{"custom_id":"o1","response":{"body":{"choices":[{"message":{"content":"{\\"question\\":\\"Q\\",\\"answer\\":\\"A\\"}"}}]}}}\n' + ) + + return _ContentResp() + + class FakeClient: + def __init__(self, **kwargs): + self.batches = FakeBatches() + self.files = FakeFiles() + + monkeypatch.setattr( + "mmirage.core.process.batch.openai_adapter.OpenAI", + FakeClient, + ) + monkeypatch.setattr( + "mmirage.core.process.batch.collector.BatchAdapterFactory.from_config", + lambda config: OpenAIBatchAdapter(), + ) + + records = _read_metadata_records(str(metadata_path)) + rows = collect_and_merge( + records=records, + provider_configs={"openai": OpenAIBatchConfig(credentials={"api_key": "k"})}, + output_path=str(output_path), + ) + + assert rows == [ + { + "source_index": 0, + "custom_id": "o1", + 
"conversations": [ + {"role": "user", "content": "Q"}, + {"role": "assistant", "content": "A"}, + ], + } + ] + + written = [json.loads(line) for line in output_path.read_text(encoding="utf-8").splitlines()] + assert written == rows + + +def test_collector_main_uses_config_and_records(tmp_path, monkeypatch): + from mmirage.core.process.batch import collector + + metadata_path = tmp_path / "receipts.jsonl" + metadata_path.write_text( + json.dumps( + { + "provider": "openai", + "provider_batch_id": "batch_main", + "custom_id_to_source_index": {"c1": 0}, + } + ) + + "\n", + encoding="utf-8", + ) + output_path = tmp_path / "out.jsonl" + config_path = tmp_path / "dummy.yaml" + config_path.write_text("processors: []\n", encoding="utf-8") + + cfg = SimpleNamespace(processors=[SimpleNamespace(batch_provider={"provider": "openai"})]) + captured = {} + + monkeypatch.setattr("mmirage.config.utils.load_mmirage_config", lambda path: cfg) + + def _fake_collect_and_merge(records, provider_configs, output_path_arg): + captured["records"] = records + captured["provider_configs"] = provider_configs + captured["output_path"] = output_path_arg + return [{"source_index": 0, "custom_id": "c1", "caption": "ok"}] + + monkeypatch.setattr( + "mmirage.core.process.batch.collector.collect_and_merge", + _fake_collect_and_merge, + ) + + rc = collector.main( + [ + "--metadata-path", + str(metadata_path), + "--output-path", + str(output_path), + "--config", + str(config_path), + ] + ) + + assert rc == 0 + assert len(captured["records"]) == 1 + assert captured["records"][0].provider == "openai" + assert "openai" in captured["provider_configs"] + assert captured["output_path"] == str(output_path) + + +def test_collector_main_uses_config_metadata_path_when_missing_cli_arg( + tmp_path, monkeypatch +): + from mmirage.core.process.batch import collector + + metadata_base = tmp_path / "batch_metadata.jsonl" + metadata_path = tmp_path / "batch_metadata.text.abc123.jsonl" + metadata_path.write_text( + 
json.dumps( + { + "provider": "openai", + "provider_batch_id": "batch_main", + "custom_id_to_source_index": {"c1": 0}, + } + ) + + "\n", + encoding="utf-8", + ) + output_path = tmp_path / "out.jsonl" + config_path = tmp_path / "dummy.yaml" + config_path.write_text("processors: []\n", encoding="utf-8") + + cfg = SimpleNamespace( + processors=[ + SimpleNamespace( + batch_provider={ + "provider": "openai", + "metadata_output_path": str(metadata_base), + } + ) + ] + ) + captured = {} + + monkeypatch.setattr("mmirage.config.utils.load_mmirage_config", lambda path: cfg) + + def _fake_collect_and_merge(records, provider_configs, output_path_arg): + captured["records"] = records + captured["provider_configs"] = provider_configs + captured["output_path"] = output_path_arg + return [{"source_index": 0, "custom_id": "c1", "caption": "ok"}] + + monkeypatch.setattr( + "mmirage.core.process.batch.collector.collect_and_merge", + _fake_collect_and_merge, + ) + + rc = collector.main( + [ + "--output-path", + str(output_path), + "--config", + str(config_path), + ] + ) + + assert rc == 0 + assert len(captured["records"]) == 1 + assert captured["records"][0].provider == "openai" + assert "openai" in captured["provider_configs"] + assert captured["output_path"] == str(output_path) + + +def test_collector_main_raises_when_config_metadata_paths_missing(tmp_path, monkeypatch, caplog): + from mmirage.core.process.batch import collector + + metadata_base = tmp_path / "batch_metadata.jsonl" + output_path = tmp_path / "out.jsonl" + config_path = tmp_path / "dummy.yaml" + config_path.write_text("processors: []\n", encoding="utf-8") + + cfg = SimpleNamespace( + processors=[ + SimpleNamespace( + batch_provider={ + "provider": "openai", + "metadata_output_path": str(metadata_base), + } + ) + ] + ) + monkeypatch.setattr("mmirage.config.utils.load_mmirage_config", lambda path: cfg) + + rc = collector.main( + [ + "--output-path", + str(output_path), + "--config", + str(config_path), + ] + ) + assert 
rc == 1 + assert "No metadata receipts matched config metadata_output_path patterns" in caplog.text + + +def test_collector_main_raises_when_metadata_provider_missing_in_config(tmp_path, monkeypatch, caplog): + from mmirage.core.process.batch import collector + + metadata_path = tmp_path / "receipts.jsonl" + metadata_path.write_text( + json.dumps( + { + "provider": "mistral", + "provider_batch_id": "batch_mistral", + "custom_id_to_source_index": {"m1": 0}, + } + ) + + "\n", + encoding="utf-8", + ) + output_path = tmp_path / "out.jsonl" + config_path = tmp_path / "dummy.yaml" + config_path.write_text("processors: []\n", encoding="utf-8") + + # Config intentionally only defines openai, not mistral. + cfg = SimpleNamespace(processors=[SimpleNamespace(batch_provider={"provider": "openai"})]) + monkeypatch.setattr("mmirage.config.utils.load_mmirage_config", lambda path: cfg) + + rc = collector.main( + [ + "--metadata-path", + str(metadata_path), + "--output-path", + str(output_path), + "--config", + str(config_path), + ] + ) + assert rc == 1 + assert "missing from YAML batch_provider config" in caplog.text + + +def test_collect_and_merge_routes_multiple_providers(tmp_path, monkeypatch): + from mmirage.core.process.batch.collector import _read_metadata_records, collect_and_merge + from mmirage.core.process.batch.provider_resolution import resolve_provider_configs + + metadata_path = tmp_path / "receipts.jsonl" + metadata_path.write_text( + "\n".join( + [ + json.dumps( + { + "provider": "openai", + "provider_batch_id": "batch_openai", + "custom_id_to_source_index": {"o1": 1}, + } + ), + json.dumps( + { + "provider": "unit", + "provider_batch_id": "batch_unit", + "custom_id_to_source_index": {"u1": 0}, + } + ), + ] + ) + + "\n", + encoding="utf-8", + ) + + cfg = SimpleNamespace( + processors=[ + SimpleNamespace( + batch_provider={"provider": "openai", "credentials": {"api_key": "k"}} + ), + SimpleNamespace(batch_provider={"provider": "unit"}), + ] + ) + + records = 
_read_metadata_records(str(metadata_path)) + provider_configs = resolve_provider_configs(records, cfg) + + class OpenAIAdapter: + def __init__(self): + self.calls = [] + + def retrieve_results(self, provider_batch_id, config): + self.calls.append((provider_batch_id, config.provider)) + return [{"custom_id": "o1", "generated_text": "openai"}] + + class UnitAdapter: + def __init__(self): + self.calls = [] + + def retrieve_results(self, provider_batch_id, config): + self.calls.append((provider_batch_id, config.provider)) + return [{"custom_id": "u1", "generated_text": "unit"}] + + adapters = { + "openai": OpenAIAdapter(), + "unit": UnitAdapter(), + } + + monkeypatch.setattr( + "mmirage.core.process.batch.collector.BatchAdapterFactory.from_config", + lambda config: adapters[config.provider], + ) + + output_path = tmp_path / "merged.jsonl" + rows = collect_and_merge( + records=records, + provider_configs=provider_configs, + output_path=str(output_path), + ) + + assert [row["custom_id"] for row in rows] == ["u1", "o1"] + assert [row["caption"] for row in rows] == ["unit", "openai"] + assert ("batch_openai", "openai") in adapters["openai"].calls + assert ("batch_unit", "unit") in adapters["unit"].calls + + +def test_collector_main_raises_for_invalid_batch_provider_config(tmp_path, monkeypatch, caplog): + from mmirage.core.process.batch import collector + + metadata_path = tmp_path / "receipts.jsonl" + metadata_path.write_text( + json.dumps( + { + "provider": "openai", + "provider_batch_id": "batch_1", + "custom_id_to_source_index": {"c1": 0}, + } + ) + + "\n", + encoding="utf-8", + ) + output_path = tmp_path / "out.jsonl" + config_path = tmp_path / "dummy.yaml" + config_path.write_text("processors: []\n", encoding="utf-8") + + cfg = SimpleNamespace( + processors=[ + SimpleNamespace(batch_provider={"provider": "openai", "batch_endpoint": "v1"}) + ] + ) + monkeypatch.setattr("mmirage.config.utils.load_mmirage_config", lambda path: cfg) + + rc = collector.main( + [ + 
"--metadata-path", + str(metadata_path), + "--output-path", + str(output_path), + "--config", + str(config_path), + ] + ) + assert rc == 1 + assert "batch_endpoint must start with '/'" in caplog.text + + +def test_collect_and_merge_tiebreaker_secondary_sort_key(tmp_path, monkeypatch): + from mmirage.core.process.batch.collector import _read_metadata_records, collect_and_merge + + metadata_path = tmp_path / "receipts.jsonl" + metadata_path.write_text( + json.dumps( + { + "provider": "openai", + "provider_batch_id": "batch_tie", + "custom_id_to_source_index": {"a": 0, "b": 0, "c": 1}, + } + ) + + "\n", + encoding="utf-8", + ) + + output_path = tmp_path / "merged_tie.jsonl" + + class FakeAdapter: + def retrieve_results(self, provider_batch_id, config): + # Return rows intentionally out-of-order to ensure collector sorts + # deterministically using the secondary key. + return [ + {"custom_id": "b", "generated_text": "B"}, + {"custom_id": "a", "generated_text": "A"}, + {"custom_id": "c", "generated_text": "C"}, + ] + + monkeypatch.setattr( + "mmirage.core.process.batch.collector.BatchAdapterFactory.from_config", + lambda config: FakeAdapter(), + ) + + records = _read_metadata_records(str(metadata_path)) + rows = collect_and_merge( + records=records, + provider_configs={"openai": OpenAIBatchConfig(credentials={"api_key": "k"})}, + output_path=str(output_path), + ) + + assert [r["custom_id"] for r in rows] == ["a", "b", "c"] + + +def test_build_output_payload_logs_malformed_json(caplog): + from mmirage.core.process.batch.collector import _build_output_payload + + malformed_json = '{"question": "incomplete' + result_row = { + "custom_id": "row_123", + "generated_text": malformed_json, + } + + with caplog.at_level("WARNING"): + output = _build_output_payload(result_row, custom_id="row_123") + + assert output == {"caption": malformed_json} + assert "Failed to parse JSON for result row" in caplog.text + assert "custom_id=row_123" in caplog.text + assert "Treating as raw text" 
in caplog.text + + +def test_build_output_payload_keeps_plain_text_silent(caplog): + from mmirage.core.process.batch.collector import _build_output_payload + + result_row = { + "custom_id": "caption:multimodal:1", + "generated_text": "The image features a solid orange background with the text \"A Cat\" displayed in a bold font.", + } + + with caplog.at_level("WARNING"): + output = _build_output_payload(result_row, custom_id="caption:multimodal:1") + + assert output == { + "caption": "The image features a solid orange background with the text \"A Cat\" displayed in a bold font." + } + assert "Failed to parse JSON for result row" not in caplog.text + + +def test_build_output_payload_preserves_provider_error_status(): + from mmirage.core.process.batch.collector import _build_output_payload + + result_row = { + "custom_id": "formatted_answer:text:50", + "status": "error", + "error_message": "Unrecognized request argument supplied: expected_schema", + } + + output = _build_output_payload(result_row, custom_id="formatted_answer:text:50") + + assert output == { + "status": "error", + "error_message": "Unrecognized request argument supplied: expected_schema", + } diff --git a/tests/test_batch_orchestrator.py b/tests/test_batch_orchestrator.py new file mode 100644 index 0000000..de80340 --- /dev/null +++ b/tests/test_batch_orchestrator.py @@ -0,0 +1,240 @@ +from dataclasses import dataclass +import json + +import pytest + +from mmirage.config.batch_provider import BatchProviderConfig +from mmirage.core.process.batch.adapter import BatchSubmissionAdapter, BatchSubmissionResult +from mmirage.core.process.batch.provider_resolution import BatchProviderConfigRegistry +from mmirage.core.process.batch.registry import BatchAdapterRegistry +from mmirage.core.process.base import ProcessorRegistry +from mmirage.core.process.processors.llm.config import SGLangLLMConfig, SGLangServerArgs + + +class RecordingAdapter(BatchSubmissionAdapter): + def __init__(self) -> None: + 
self.submissions = [] + + def build_request(self, custom_id, payload, config): + return {"custom_id": custom_id, **dict(payload)} + + def estimate_request_bytes(self, request): + return int(request["size_bytes"]) + + def submit_chunk(self, chunk_id, requests, config): + self.submissions.append( + { + "chunk_id": chunk_id, + "requests": list(requests), + } + ) + return {"id": f"batch-{chunk_id}", "status": "submitted"} + + def parse_submission_result(self, raw_result): + return BatchSubmissionResult( + provider_batch_id=str(raw_result["id"]), + status=str(raw_result["status"]), + raw_response=raw_result, + ) + + def check_batch_status(self, provider_batch_id, config): + return BatchSubmissionResult( + provider_batch_id=provider_batch_id, + status="submitted", + raw_response={"id": provider_batch_id, "status": "submitted"}, + ) + + def retrieve_results(self, provider_batch_id, config): + return [] + + +@pytest.fixture(autouse=True) +def clear_batch_registries(): + BatchProviderConfigRegistry.clear() + BatchAdapterRegistry.clear() + yield + BatchProviderConfigRegistry.clear() + BatchAdapterRegistry.clear() + + +def test_orchestrator_buffers_across_iterations_and_avoids_tiny_midstream_flush(tmp_path): + from mmirage.core.process.batch.orchestrator import BatchSubmissionOrchestrator + + adapter = RecordingAdapter() + config = BatchProviderConfig( + provider="unit", + max_chunk_bytes=10, + metadata_output_path=str(tmp_path / "metadata.jsonl"), + ) + orchestrator = BatchSubmissionOrchestrator(adapter=adapter, config=config) + + # Iteration 1: only 9 bytes total, should remain buffered and submit nothing. + r1 = [{"custom_id": "a", "size_bytes": 6}, {"custom_id": "b", "size_bytes": 3}] + out1 = orchestrator.add_requests(r1, [10, 11], {"phase": "iter1"}) + assert out1 == [] + assert len(adapter.submissions) == 0 + + # Iteration 2: appending 2 bytes should emit one full chunk [6,3] and keep [2]. 
+ r2 = [{"custom_id": "c", "size_bytes": 2}] + out2 = orchestrator.add_requests(r2, [12], {"phase": "iter2"}) + assert len(out2) == 1 + assert len(adapter.submissions) == 1 + assert [x["size_bytes"] for x in adapter.submissions[0]["requests"]] == [6, 3] + assert orchestrator.pending_count == 1 + + # Finalize: emits the remaining tiny tail exactly once. + out3 = orchestrator.finalize({"phase": "finalize"}) + assert len(out3) == 1 + assert len(adapter.submissions) == 2 + assert [x["size_bytes"] for x in adapter.submissions[1]["requests"]] == [2] + assert orchestrator.pending_count == 0 + + +def test_orchestrator_writes_provider_neutral_metadata_with_flush_reason(tmp_path): + from mmirage.core.process.batch.orchestrator import BatchSubmissionOrchestrator + + metadata_path = tmp_path / "batch_metadata.jsonl" + adapter = RecordingAdapter() + config = BatchProviderConfig( + provider="unit", + max_chunk_bytes=10, + metadata_output_path=str(metadata_path), + ) + orchestrator = BatchSubmissionOrchestrator(adapter=adapter, config=config) + + orchestrator.add_requests( + requests=[ + {"custom_id": "x1", "size_bytes": 8}, + {"custom_id": "x2", "size_bytes": 8}, + ], + source_indices=[0, 1], + model_params_snapshot={"model": "unit-model"}, + ) + orchestrator.finalize({"model": "unit-model"}) + + lines = metadata_path.read_text(encoding="utf-8").strip().splitlines() + assert len(lines) == 2 + + first = json.loads(lines[0]) + second = json.loads(lines[1]) + + assert first["provider"] == "unit" + assert first["flush_reason"] == "full_chunk" + assert first["custom_id_to_source_index"] == {"x1": 0} + assert isinstance(first["request_hash"], str) and len(first["request_hash"]) == 64 + + assert second["flush_reason"] == "finalize" + assert second["custom_id_to_source_index"] == {"x2": 1} + assert second["provider_batch_id"].startswith("batch-chunk-") + + +@dataclass +class UnitBatchConfig(BatchProviderConfig): + provider: str = "unit" + unit_setting: str = "default" + + def 
__post_init__(self) -> None: + super().__post_init__() + if not self.unit_setting.strip(): + raise ValueError("unit_setting must be a non-empty string") + + +def test_llm_processor_initializes_with_custom_provider(tmp_path): + BatchProviderConfigRegistry.register("unit", UnitBatchConfig) + BatchAdapterRegistry.register("unit", RecordingAdapter) + + config = SGLangLLMConfig( + type="llm", + server_args=SGLangServerArgs(model_path="dummy-model"), + batch_provider=UnitBatchConfig( + provider="unit", + unit_setting="custom", + metadata_output_path=str(tmp_path / "metadata.jsonl"), + ), + ) + + processor_cls = ProcessorRegistry.get_processor("llm") + processor = processor_cls(config) + + assert processor.batch_mode_enabled is True + assert isinstance(processor._batch_provider_config, UnitBatchConfig) + assert processor._batch_provider_config.provider == "unit" + assert processor._batch_provider_config.unit_setting == "custom" + assert isinstance(processor._batch_adapter, RecordingAdapter) + + +def test_llm_processor_skips_batch_setup_when_disabled(monkeypatch): + class FakeEngine: + def __init__(self, **_kwargs): + return None + + def generate(self, **_kwargs): + raise AssertionError("Synchronous generation should not run in this test") + + def shutdown(self): + return None + + class FakeTokenizer: + def apply_chat_template(self, *args, **kwargs): + return "" + + monkeypatch.setattr( + "mmirage.core.process.processors.llm.llm_processor.sgl.Engine", + FakeEngine, + ) + monkeypatch.setattr( + "mmirage.core.process.processors.llm.llm_processor.AutoTokenizer.from_pretrained", + lambda *args, **kwargs: FakeTokenizer(), + ) + + config = SGLangLLMConfig( + type="llm", + server_args=SGLangServerArgs(model_path="dummy-model"), + batch_provider=BatchProviderConfig(provider="openai", enabled=False), + ) + + processor_cls = ProcessorRegistry.get_processor("llm") + processor = processor_cls(config) + + assert processor.batch_mode_enabled is False + assert processor._batch_adapter is 
None + assert processor._batch_provider_config is None + + +def test_llm_processor_uses_sync_runtime_when_batch_provider_omitted(monkeypatch): + class FakeEngine: + def __init__(self, **_kwargs): + return None + + def generate(self, **_kwargs): + raise AssertionError("Synchronous generation should not run in this test") + + def shutdown(self): + return None + + class FakeTokenizer: + def apply_chat_template(self, *args, **kwargs): + return "" + + monkeypatch.setattr( + "mmirage.core.process.processors.llm.llm_processor.sgl.Engine", + FakeEngine, + ) + monkeypatch.setattr( + "mmirage.core.process.processors.llm.llm_processor.AutoTokenizer.from_pretrained", + lambda *args, **kwargs: FakeTokenizer(), + ) + + config = SGLangLLMConfig( + type="llm", + server_args=SGLangServerArgs(model_path="dummy-model"), + ) + + processor_cls = ProcessorRegistry.get_processor("llm") + processor = processor_cls(config) + + assert processor.batch_mode_enabled is False + assert isinstance(processor.llm, FakeEngine) + assert isinstance(processor.tokenizer, FakeTokenizer) + assert processor._batch_adapter is None + assert processor._batch_provider_config is None diff --git a/tests/test_batch_status_checker.py b/tests/test_batch_status_checker.py new file mode 100644 index 0000000..f661f77 --- /dev/null +++ b/tests/test_batch_status_checker.py @@ -0,0 +1,274 @@ +from types import SimpleNamespace + +from mmirage.config.openai_batch import OpenAIBatchConfig +from mmirage.core.process.batch.adapter import BatchSubmissionResult + + +def test_extract_unique_provider_batches_handles_malformed_and_duplicates(tmp_path): + from mmirage.core.process.batch.status_checker import ( + _read_metadata_records, + extract_unique_provider_batches, + ) + + metadata_path = tmp_path / "receipts.jsonl" + metadata_path.write_text( + "\n".join( + [ + '{"provider":"openai","provider_batch_id":"batch_1"}', + '{"provider":"openai","provider_batch_id":"batch_1"}', + "not-json", + '{"provider":"openai"}', + 
'{"provider":"openai","provider_batch_id":"batch_2"}', + "", + ] + ), + encoding="utf-8", + ) + + pairs = extract_unique_provider_batches(_read_metadata_records(str(metadata_path))) + + assert pairs == [("openai", "batch_1"), ("openai", "batch_2")] + + +def test_run_status_checker_prints_summary_with_factory_dispatch(tmp_path, monkeypatch): + from mmirage.core.process.batch.status_checker import _read_metadata_records, run_status_checker + from unittest.mock import patch + + metadata_path = tmp_path / "receipts.jsonl" + metadata_path.write_text( + "\n".join( + [ + '{"provider":"openai","provider_batch_id":"batch_1"}', + '{"provider":"openai","provider_batch_id":"batch_2"}', + '{"provider":"openai","provider_batch_id":"batch_1"}', + ] + ), + encoding="utf-8", + ) + + class FakeAdapter: + def __init__(self): + self.calls = [] + + def check_batch_status(self, provider_batch_id, config): + self.calls.append((provider_batch_id, config.provider)) + status = "completed" if provider_batch_id == "batch_1" else "in_progress" + return BatchSubmissionResult( + provider_batch_id=provider_batch_id, + status=status, + raw_response={"id": provider_batch_id, "status": status}, + ) + + fake_adapter = FakeAdapter() + + monkeypatch.setattr( + "mmirage.core.process.batch.status_checker.BatchAdapterFactory.from_config", + lambda config: fake_adapter, + ) + + config_map = { + "openai": OpenAIBatchConfig(credentials={"api_key": "k"}), + } + records = _read_metadata_records(str(metadata_path)) + + with patch("mmirage.core.process.batch.status_checker.logger") as mock_logger: + results = run_status_checker( + metadata_records=records, + provider_configs=config_map, + ) + + assert [(r.provider_batch_id, r.status) for r in results] == [ + ("batch_1", "completed"), + ("batch_2", "in_progress"), + ] + assert fake_adapter.calls == [ + ("batch_1", "openai"), + ("batch_2", "openai"), + ] + + # Verify logger.info was called with expected messages + logger_calls = [call[0][0] for call in 
mock_logger.info.call_args_list] + assert any("Batch batch_1 (openai): completed" in str(call) for call in logger_calls) + assert any("Batch batch_2 (openai): in_progress" in str(call) for call in logger_calls) + + +def test_status_checker_main_uses_config_and_runs(tmp_path, monkeypatch): + from mmirage.core.process.batch import status_checker + + metadata_path = tmp_path / "receipts.jsonl" + metadata_path.write_text( + '{"provider":"openai","provider_batch_id":"batch_1"}\n', + encoding="utf-8", + ) + config_path = tmp_path / "dummy.yaml" + config_path.write_text("processors: []\n", encoding="utf-8") + + cfg = SimpleNamespace(processors=[SimpleNamespace(batch_provider={"provider": "openai"})]) + monkeypatch.setattr("mmirage.config.utils.load_mmirage_config", lambda path: cfg) + + called = {} + + def _fake_run_status_checker(metadata_records, provider_configs, output=None): + called["metadata_records"] = metadata_records + called["provider_configs"] = provider_configs + return [] + + monkeypatch.setattr( + "mmirage.core.process.batch.status_checker.run_status_checker", + _fake_run_status_checker, + ) + + rc = status_checker.main( + [ + "--metadata-path", + str(metadata_path), + "--config", + str(config_path), + ] + ) + + assert rc == 0 + assert len(called["metadata_records"]) == 1 + assert called["metadata_records"][0].provider == "openai" + assert "openai" in called["provider_configs"] + + +def test_status_checker_main_returns_error_when_metadata_provider_missing_in_config( + tmp_path, monkeypatch +): + from mmirage.core.process.batch import status_checker + from unittest.mock import patch + + metadata_path = tmp_path / "receipts.jsonl" + metadata_path.write_text( + '{"provider":"mistral","provider_batch_id":"batch_m1"}\n', + encoding="utf-8", + ) + config_path = tmp_path / "dummy.yaml" + config_path.write_text("processors: []\n", encoding="utf-8") + + # Config intentionally only defines openai, not mistral. 
+ cfg = SimpleNamespace(processors=[SimpleNamespace(batch_provider={"provider": "openai"})]) + monkeypatch.setattr("mmirage.config.utils.load_mmirage_config", lambda path: cfg) + + with patch("mmirage.core.process.batch.status_checker.logger") as mock_logger: + rc = status_checker.main( + [ + "--metadata-path", + str(metadata_path), + "--config", + str(config_path), + ] + ) + + assert rc == 1 + assert mock_logger.error.called or mock_logger.exception.called + + +def test_status_checker_main_returns_error_when_credentials_missing( + tmp_path, monkeypatch +): + from mmirage.core.process.batch import status_checker + from unittest.mock import patch + + metadata_path = tmp_path / "receipts.jsonl" + metadata_path.write_text( + '{"provider":"openai","provider_batch_id":"batch_1"}\n', + encoding="utf-8", + ) + config_path = tmp_path / "dummy.yaml" + config_path.write_text("processors: []\n", encoding="utf-8") + + cfg = SimpleNamespace( + processors=[SimpleNamespace(batch_provider={"provider": "openai", "credentials": {}})] + ) + monkeypatch.setattr("mmirage.config.utils.load_mmirage_config", lambda path: cfg) + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + + with patch("mmirage.core.process.batch.status_checker.logger") as mock_logger: + rc = status_checker.main( + [ + "--metadata-path", + str(metadata_path), + "--config", + str(config_path), + ] + ) + + assert rc == 1 + assert mock_logger.error.called or mock_logger.exception.called + + +def test_status_checker_main_uses_config_metadata_path_when_missing_cli_arg( + tmp_path, monkeypatch +): + from mmirage.core.process.batch import status_checker + + metadata_base = tmp_path / "batch_metadata.jsonl" + metadata_path = tmp_path / "batch_metadata.text.abc123.jsonl" + metadata_path.write_text( + '{"provider":"openai","provider_batch_id":"batch_1"}\n', + encoding="utf-8", + ) + config_path = tmp_path / "dummy.yaml" + config_path.write_text("processors: []\n", encoding="utf-8") + + cfg = SimpleNamespace( + processors=[ 
+ SimpleNamespace( + batch_provider={ + "provider": "openai", + "metadata_output_path": str(metadata_base), + } + ) + ] + ) + monkeypatch.setattr("mmirage.config.utils.load_mmirage_config", lambda path: cfg) + + called = {} + + def _fake_run_status_checker(metadata_records, provider_configs, output=None): + called["metadata_records"] = metadata_records + called["provider_configs"] = provider_configs + return [] + + monkeypatch.setattr( + "mmirage.core.process.batch.status_checker.run_status_checker", + _fake_run_status_checker, + ) + + rc = status_checker.main(["--config", str(config_path)]) + + assert rc == 0 + assert len(called["metadata_records"]) == 1 + assert called["metadata_records"][0].provider == "openai" + assert "openai" in called["provider_configs"] + + +def test_status_checker_main_returns_error_when_config_metadata_paths_missing( + tmp_path, monkeypatch +): + from mmirage.core.process.batch import status_checker + from unittest.mock import patch + + metadata_base = tmp_path / "batch_metadata.jsonl" + config_path = tmp_path / "dummy.yaml" + config_path.write_text("processors: []\n", encoding="utf-8") + + cfg = SimpleNamespace( + processors=[ + SimpleNamespace( + batch_provider={ + "provider": "openai", + "metadata_output_path": str(metadata_base), + } + ) + ] + ) + monkeypatch.setattr("mmirage.config.utils.load_mmirage_config", lambda path: cfg) + + with patch("mmirage.core.process.batch.status_checker.logger") as mock_logger: + rc = status_checker.main(["--config", str(config_path)]) + + assert rc == 1 + assert mock_logger.exception.called diff --git a/tests/test_integration_batch_pipeline.py b/tests/test_integration_batch_pipeline.py new file mode 100644 index 0000000..6928aa7 --- /dev/null +++ b/tests/test_integration_batch_pipeline.py @@ -0,0 +1,157 @@ +import json +from pathlib import Path + +from datasets import load_dataset + +from mmirage.config.openai_batch import OpenAIBatchConfig +from mmirage.core.process import LLMProcessor # Ensures 
processor registration. +from mmirage.core.process.mapper import MMIRAGEMapper +from mmirage.core.process.processors.llm.config import LLMOutputVar, SGLangLLMConfig, SGLangServerArgs +from mmirage.core.process.variables import InputVar +from mmirage.core.writer.renderer import TemplateRenderer + + +def test_integration_batch_pipeline_with_stateful_accumulator(monkeypatch, tmp_path): + captured = { + "file_uploads": [], + "batch_creates": [], + "engine_init_calls": 0, + } + + class FakeFiles: + def create(self, *, file, purpose): + file_name, file_obj = file + payload = file_obj.read().decode("utf-8") + captured["file_uploads"].append( + { + "file_name": file_name, + "purpose": purpose, + "payload": payload, + } + ) + + class _FileResp: + id = f"file_{len(captured['file_uploads'])}" + + return _FileResp() + + class FakeBatches: + def create(self, **kwargs): + captured["batch_creates"].append(kwargs) + + class _BatchResp: + id = f"batch_{len(captured['batch_creates'])}" + status = "validating" + endpoint = kwargs["endpoint"] + + return _BatchResp() + + class FakeOpenAIClient: + def __init__(self, **_kwargs): + self.files = FakeFiles() + self.batches = FakeBatches() + + class FakeEngine: + def __init__(self, **_kwargs): + captured["engine_init_calls"] += 1 + + def generate(self, **_kwargs): + raise AssertionError("Synchronous generation path should not run in batch mode") + + def shutdown(self): + return None + + class FakeTokenizer: + def apply_chat_template(self, user_prompt, tokenize=False, add_generation_prompt=True): + assert tokenize is False + assert add_generation_prompt is True + return user_prompt[0]["content"] + + monkeypatch.setattr( + "mmirage.core.process.batch.openai_adapter.OpenAI", + FakeOpenAIClient, + ) + monkeypatch.setattr( + "mmirage.core.process.processors.llm.llm_processor.sgl.Engine", + FakeEngine, + ) + monkeypatch.setattr( + "mmirage.core.process.processors.llm.llm_processor.AutoTokenizer.from_pretrained", + lambda *args, **kwargs: 
FakeTokenizer(), + ) + + metadata_base = tmp_path / "batch_receipts.jsonl" + llm_cfg = SGLangLLMConfig( + type="llm", + server_args=SGLangServerArgs(model_path="dummy-model"), + batch_provider=OpenAIBatchConfig( + enabled=True, + model="gpt-4.1-mini", + max_chunk_bytes=500, + max_requests_per_chunk=None, + metadata_output_path=str(metadata_base), + credentials={"api_key": "test-key"}, + metadata={"pipeline": "integration-test"}, + ), + ) + + mapper = MMIRAGEMapper( + processor_configs=[llm_cfg], + input_vars=[InputVar(name="text", key="text")], + output_vars=[ + LLMOutputVar( + name="answer", + type="llm", + prompt="{{ text }}", + output_type="plain", + ) + ], + ) + renderer = TemplateRenderer(output_schema={"answer": "{{ answer }}"}) + + data_path = Path(__file__).parent / "mock_data" / "data.jsonl" + dataset = load_dataset("json", data_files=str(data_path), split="train") + + def rewrite_batch(batch, mapper, renderer): + envs = mapper.rewrite_batch(batch) + return renderer.batch_render(envs) + + ds_out = dataset.map( + rewrite_batch, + batched=True, + batch_size=7, + fn_kwargs={"mapper": mapper, "renderer": renderer}, + load_from_cache_file=False, + ) + + # Explicit lifecycle flush required by the architecture. + mapper.finalize_processors() + + # 1) Multiple provider submissions prove byte-based chunking with carry-over. + assert captured["engine_init_calls"] == 0 + assert len(captured["file_uploads"]) > 1 + assert len(captured["batch_creates"]) > 1 + + # 2) Map output is placeholder-based and does not wait for completion. + answers = ds_out["answer"] + assert len(answers) == len(dataset) + assert all(isinstance(v, str) and v.startswith("__BATCH_SUBMITTED__:answer:") for v in answers) + + # 3) Metadata receipts are written and include both full_chunk and finalize flush reasons. 
+ metadata_text_matches = sorted(tmp_path.glob("batch_receipts.text.*.jsonl")) + assert len(metadata_text_matches) == 1 + metadata_text_path = metadata_text_matches[0] + + records = [ + json.loads(line) + for line in metadata_text_path.read_text(encoding="utf-8").splitlines() + if line.strip() + ] + assert len(records) > 1 + + flush_reasons = {record["flush_reason"] for record in records} + assert "full_chunk" in flush_reasons + assert "finalize" in flush_reasons + + assert all(record["provider"] == "openai" for record in records) + assert all("custom_id_to_source_index" in record for record in records) diff --git a/tests/test_integration_receiver.py b/tests/test_integration_receiver.py new file mode 100644 index 0000000..7e44692 --- /dev/null +++ b/tests/test_integration_receiver.py @@ -0,0 +1,72 @@ +import json + +from mmirage.config.openai_batch import OpenAIBatchConfig + + +def test_integration_receiver_reads_receipt_and_writes_merged_output(tmp_path, monkeypatch): + from mmirage.core.process.batch.collector import _read_metadata_records, collect_and_merge + + metadata_path = tmp_path / "receipt.text.jsonl" + metadata_path.write_text( + "\n".join( + [ + json.dumps( + { + "provider": "openai", + "provider_batch_id": "batch_a", + "custom_id_to_source_index": {"id_a": 1, "id_b": 0}, + } + ), + json.dumps( + { + "provider": "openai", + "provider_batch_id": "batch_b", + "custom_id_to_source_index": {"id_c": 2}, + } + ), + ] + ), + encoding="utf-8", + ) + + output_path = tmp_path / "merged.jsonl" + + class FakeAdapter: + def retrieve_results(self, provider_batch_id, config): + if provider_batch_id == "batch_a": + return [ + { + "custom_id": "id_a", + "generated_text": '{"question":"What is id_a?","answer":"one"}', + }, + { + "custom_id": "id_b", + "generated_text": '{"question":"What is id_b?","answer":"zero"}', + }, + ] + return [ + { + "custom_id": "id_c", + "generated_text": '{"question":"What is id_c?","answer":"two"}', + } + ] + + monkeypatch.setattr( + 
"mmirage.core.process.batch.collector.BatchAdapterFactory.from_config", + lambda config: FakeAdapter(), + ) + + records = _read_metadata_records(str(metadata_path)) + rows = collect_and_merge( + records=records, + provider_configs={"openai": OpenAIBatchConfig(credentials={"api_key": "test"})}, + output_path=str(output_path), + ) + + assert [r["source_index"] for r in rows] == [0, 1, 2] + assert [r["custom_id"] for r in rows] == ["id_b", "id_a", "id_c"] + assert [r["conversations"][1]["content"] for r in rows] == ["zero", "one", "two"] + + written = [json.loads(line) for line in output_path.read_text(encoding="utf-8").splitlines()] + assert [r["custom_id"] for r in written] == ["id_b", "id_a", "id_c"] + assert [r["conversations"][1]["content"] for r in written] == ["zero", "one", "two"] diff --git a/tests/test_openai_batch_adapter.py b/tests/test_openai_batch_adapter.py new file mode 100644 index 0000000..6b228d1 --- /dev/null +++ b/tests/test_openai_batch_adapter.py @@ -0,0 +1,550 @@ +import json +import base64 + +import pytest + +from mmirage.config.openai_batch import OpenAIBatchConfig +from mmirage.core.process.batch.adapter import BatchSubmissionResult +from mmirage.core.process.batch.registry import BatchAdapterFactory, BatchAdapterRegistry + + +def test_openai_build_request_matches_expected_structure(): + from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter + + config = OpenAIBatchConfig(model="gpt-4.1-mini") + adapter = OpenAIBatchAdapter() + payload = { + "messages": [{"role": "user", "content": "hello"}], + "temperature": 0, + } + + request = adapter.build_request(custom_id="row-001", payload=payload, config=config) + + assert request == { + "custom_id": "row-001", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": "gpt-4.1-mini", + "messages": [{"role": "user", "content": "hello"}], + "temperature": 0, + }, + } + + +def test_openai_build_request_injects_structured_output_format(): + from 
mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter + + config = OpenAIBatchConfig(model="gpt-4.1-mini") + adapter = OpenAIBatchAdapter() + payload = { + "messages": [{"role": "user", "content": "hello"}], + "expected_schema": ["question", "answer"], + } + + request = adapter.build_request(custom_id="row-002", payload=payload, config=config) + + assert "expected_schema" not in request["body"] + + assert request["body"]["response_format"] == { + "type": "json_schema", + "json_schema": { + "name": "structured_output", + "strict": True, + "schema": { + "type": "object", + "properties": { + "question": {"type": "string"}, + "answer": {"type": "string"}, + }, + "required": ["question", "answer"], + "additionalProperties": False, + }, + }, + } + + +def test_openai_build_request_converts_local_image_path_to_data_uri(tmp_path): + from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter + + image_bytes = b"\xff\xd8\xff\xe0testjpeg" + image_path = tmp_path / "sample.jpg" + image_path.write_bytes(image_bytes) + + config = OpenAIBatchConfig(model="gpt-4o-mini") + adapter = OpenAIBatchAdapter() + payload = { + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "describe image"}, + {"type": "image_url", "image_url": {"url": str(image_path)}}, + ], + } + ] + } + + request = adapter.build_request(custom_id="vision-1", payload=payload, config=config) + + url = request["body"]["messages"][0]["content"][1]["image_url"]["url"] + assert url.startswith("data:image/jpeg;base64,") + + encoded = url.split(",", 1)[1] + assert base64.b64decode(encoded) == image_bytes + + +def test_openai_estimate_request_bytes_matches_utf8_json_size(): + from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter + + adapter = OpenAIBatchAdapter() + request = { + "custom_id": "accented", + "method": "POST", + "url": "/v1/chat/completions", + "body": {"message": "caf\u00e9"}, + } + + expected = len( + json.dumps(request, 
ensure_ascii=False, separators=(",", ":")).encode("utf-8") + ) + + assert adapter.estimate_request_bytes(request) == expected + + +def test_openai_submit_chunk_uses_mocked_openai_client(monkeypatch): + from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter + + captured = {} + + class FakeFiles: + def create(self, *, file, purpose): + file_name, file_obj = file + assert file_name == "batch_chunk-chunk-01.jsonl" + assert purpose == "batch" + file_content = file_obj.read().decode("utf-8") + captured["jsonl"] = file_content + + class _FileResp: + id = "file_123" + + return _FileResp() + + class FakeBatches: + def create(self, **kwargs): + captured["batch_create_kwargs"] = kwargs + + class _BatchResp: + id = "batch_123" + status = "validating" + endpoint = kwargs["endpoint"] + + return _BatchResp() + + class FakeClient: + def __init__(self, **kwargs): + captured["client_kwargs"] = kwargs + self.files = FakeFiles() + self.batches = FakeBatches() + + monkeypatch.setattr( + "mmirage.core.process.batch.openai_adapter.OpenAI", + FakeClient, + ) + + config = OpenAIBatchConfig( + model="gpt-4.1-mini", + completion_window="24h", + batch_endpoint="/v1/chat/completions", + metadata={"pipeline": "unit"}, + credentials={"api_key": "test-key"}, + ) + adapter = OpenAIBatchAdapter() + requests = [ + adapter.build_request( + custom_id="r1", + payload={"messages": [{"role": "user", "content": "Hi"}]}, + config=config, + ) + ] + + raw_result = adapter.submit_chunk(chunk_id="chunk-01", requests=requests, config=config) + + assert captured["client_kwargs"]["api_key"] == "test-key" + assert captured["batch_create_kwargs"] == { + "input_file_id": "file_123", + "endpoint": "/v1/chat/completions", + "completion_window": "24h", + "metadata": {"pipeline": "unit", "chunk_id": "chunk-01"}, + } + assert raw_result["id"] == "batch_123" + assert raw_result["status"] == "validating" + assert raw_result["input_file_id"] == "file_123" + + jsonl_lines = [line for line in 
captured["jsonl"].split("\n") if line.strip()] + assert len(jsonl_lines) == 1 + assert json.loads(jsonl_lines[0]) == requests[0] + + +def test_openai_parse_submission_result_normalizes_payload(): + from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter + + adapter = OpenAIBatchAdapter() + raw = { + "id": "batch_123", + "status": "in_progress", + "input_file_id": "file_123", + } + + result = adapter.parse_submission_result(raw_result=raw) + + assert isinstance(result, BatchSubmissionResult) + assert result.provider_batch_id == "batch_123" + assert result.status == "in_progress" + assert result.raw_response == raw + + +def test_factory_resolves_openai_adapter_from_registry(): + from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter + + BatchAdapterRegistry.clear() + config = OpenAIBatchConfig(model="gpt-4.1-mini", credentials={"api_key": "key"}) + + adapter = BatchAdapterFactory.from_config(config) + + assert isinstance(adapter, OpenAIBatchAdapter) + + +def test_openai_check_batch_status_uses_mocked_openai_client(monkeypatch): + from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter + + captured = {} + + class FakeBatches: + def retrieve(self, provider_batch_id): + captured["retrieved_id"] = provider_batch_id + + class _RetrieveResp: + id = provider_batch_id + status = "completed" + + return _RetrieveResp() + + class FakeClient: + def __init__(self, **kwargs): + captured["client_kwargs"] = kwargs + self.batches = FakeBatches() + + monkeypatch.setattr( + "mmirage.core.process.batch.openai_adapter.OpenAI", + FakeClient, + ) + + config = OpenAIBatchConfig( + credentials={"api_key": "test-key"}, + base_url="https://example.test/v1", + ) + adapter = OpenAIBatchAdapter() + + result = adapter.check_batch_status(provider_batch_id="batch_456", config=config) + + assert captured["client_kwargs"] == { + "api_key": "test-key", + "base_url": "https://example.test/v1", + } + assert captured["retrieved_id"] == "batch_456" + 
assert isinstance(result, BatchSubmissionResult) + assert result.provider_batch_id == "batch_456" + assert result.status == "completed" + + +def test_openai_check_batch_status_falls_back_to_env_api_key(monkeypatch): + from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter + + captured = {} + + class FakeBatches: + def retrieve(self, provider_batch_id): + class _RetrieveResp: + id = provider_batch_id + status = "completed" + + return _RetrieveResp() + + class FakeClient: + def __init__(self, **kwargs): + captured["client_kwargs"] = kwargs + self.batches = FakeBatches() + + monkeypatch.setattr( + "mmirage.core.process.batch.openai_adapter.OpenAI", + FakeClient, + ) + monkeypatch.setenv("OPENAI_API_KEY", "env-test-key") + + config = OpenAIBatchConfig(credentials={}) + adapter = OpenAIBatchAdapter() + + result = adapter.check_batch_status(provider_batch_id="batch_env", config=config) + + assert captured["client_kwargs"]["api_key"] == "env-test-key" + assert result.provider_batch_id == "batch_env" + + +def test_openai_check_batch_status_raises_when_no_api_key(monkeypatch): + from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter + + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + + config = OpenAIBatchConfig(credentials={}) + adapter = OpenAIBatchAdapter() + + with pytest.raises(ValueError, match="OpenAI API key is missing"): + adapter.check_batch_status(provider_batch_id="batch_missing", config=config) + + +def test_openai_retrieve_results_downloads_and_parses_jsonl(monkeypatch): + from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter + + captured = {} + + class FakeBatches: + def retrieve(self, provider_batch_id): + captured["retrieved_id"] = provider_batch_id + + class _RetrieveResp: + id = provider_batch_id + status = "completed" + output_file_id = "file_output_1" + + return _RetrieveResp() + + class FakeFiles: + def content(self, output_file_id): + captured["output_file_id"] = output_file_id + + class 
_ContentResp: + text = ( + '{"custom_id":"c1","response":{"body":{"text":"A"}}}\n' + '{"custom_id":"c2","response":{"body":{"text":"B"}}}\n' + ) + + return _ContentResp() + + class FakeClient: + def __init__(self, **kwargs): + captured["client_kwargs"] = kwargs + self.batches = FakeBatches() + self.files = FakeFiles() + + monkeypatch.setattr( + "mmirage.core.process.batch.openai_adapter.OpenAI", + FakeClient, + ) + + config = OpenAIBatchConfig(credentials={"api_key": "test-key"}) + adapter = OpenAIBatchAdapter() + + rows = adapter.retrieve_results(provider_batch_id="batch_abc", config=config) + + assert captured["retrieved_id"] == "batch_abc" + assert captured["output_file_id"] == "file_output_1" + assert len(rows) == 2 + assert rows[0]["custom_id"] == "c1" + assert rows[1]["custom_id"] == "c2" + assert rows[0]["response"]["body"]["text"] == "A" + assert rows[1]["response"]["body"]["text"] == "B" + assert rows[0]["generated_text"] == "A" + assert rows[1]["generated_text"] == "B" + + +def test_openai_retrieve_results_prefers_message_content(monkeypatch): + from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter + + class FakeBatches: + def retrieve(self, provider_batch_id): + class _RetrieveResp: + id = provider_batch_id + status = "completed" + output_file_id = "file_output_choices" + + return _RetrieveResp() + + class FakeFiles: + def content(self, output_file_id): + class _ContentResp: + text = ( + '{"custom_id":"c1","response":{"body":{"choices":[' + '{"message":{"content":"{\\"question\\":\\"Q\\",\\"answer\\":\\"A\\"}"}}' + ']}}}\n' + ) + + return _ContentResp() + + class FakeClient: + def __init__(self, **kwargs): + self.batches = FakeBatches() + self.files = FakeFiles() + + monkeypatch.setattr( + "mmirage.core.process.batch.openai_adapter.OpenAI", + FakeClient, + ) + + config = OpenAIBatchConfig(credentials={"api_key": "test-key"}) + adapter = OpenAIBatchAdapter() + + rows = adapter.retrieve_results(provider_batch_id="batch_choices", 
config=config) + + assert rows[0]["custom_id"] == "c1" + assert rows[0]["generated_text"] == '{"question":"Q","answer":"A"}' + + +def test_openai_retrieve_results_normalizes_error_rows(monkeypatch): + from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter + + class FakeBatches: + def retrieve(self, provider_batch_id): + class _RetrieveResp: + id = provider_batch_id + status = "completed" + output_file_id = "file_output_error" + + return _RetrieveResp() + + class FakeFiles: + def content(self, output_file_id): + class _ContentResp: + text = ( + '{"id":"batch_req_1","custom_id":"formatted_answer:text:50",' + '"response":{"status_code":400,"request_id":"req_1",' + '"body":{"error":{"message":"Unrecognized request argument supplied: expected_schema",' + '"type":"invalid_request_error","param":null,"code":null}}},"error":null}\n' + ) + + return _ContentResp() + + class FakeClient: + def __init__(self, **kwargs): + self.batches = FakeBatches() + self.files = FakeFiles() + + monkeypatch.setattr( + "mmirage.core.process.batch.openai_adapter.OpenAI", + FakeClient, + ) + + config = OpenAIBatchConfig(credentials={"api_key": "test-key"}) + adapter = OpenAIBatchAdapter() + + rows = adapter.retrieve_results(provider_batch_id="batch_error", config=config) + + assert rows == [ + { + "id": "batch_req_1", + "custom_id": "formatted_answer:text:50", + "response": { + "status_code": 400, + "request_id": "req_1", + "body": { + "error": { + "message": "Unrecognized request argument supplied: expected_schema", + "type": "invalid_request_error", + "param": None, + "code": None, + } + }, + }, + "error": None, + "status": "error", + "error_message": "Unrecognized request argument supplied: expected_schema", + } + ] + + +def test_openai_retrieve_results_raises_if_batch_not_completed(monkeypatch): + from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter + + class FakeBatches: + def retrieve(self, provider_batch_id): + class _RetrieveResp: + id = 
provider_batch_id + status = "in_progress" + output_file_id = None + + return _RetrieveResp() + + class FakeClient: + def __init__(self, **kwargs): + self.batches = FakeBatches() + + class _Files: + def content(self, output_file_id): + raise AssertionError("content() should not be called when batch is not completed") + + self.files = _Files() + + monkeypatch.setattr( + "mmirage.core.process.batch.openai_adapter.OpenAI", + FakeClient, + ) + + config = OpenAIBatchConfig(credentials={"api_key": "test-key"}) + adapter = OpenAIBatchAdapter() + + with pytest.raises(ValueError, match="not completed"): + adapter.retrieve_results(provider_batch_id="batch_abc", config=config) + + +def test_openai_retrieve_results_uses_error_file_when_output_missing(monkeypatch): + from mmirage.core.process.batch.openai_adapter import OpenAIBatchAdapter + + class FakeBatches: + def retrieve(self, provider_batch_id): + class _RetrieveResp: + id = provider_batch_id + status = "completed" + output_file_id = None + error_file_id = "file_error_1" + + return _RetrieveResp() + + captured = {} + + class FakeClient: + def __init__(self, **kwargs): + self.batches = FakeBatches() + + class _Files: + def content(self, output_file_id): + captured["output_file_id"] = output_file_id + + class _ContentResp: + text = ( + '{"custom_id":"c1","response":{"body":{"error":{' + '"message":"Unrecognized request argument supplied: expected_schema"}}}}\n' + ) + + return _ContentResp() + + self.files = _Files() + + monkeypatch.setattr( + "mmirage.core.process.batch.openai_adapter.OpenAI", + FakeClient, + ) + + config = OpenAIBatchConfig(credentials={"api_key": "test-key"}) + adapter = OpenAIBatchAdapter() + + rows = adapter.retrieve_results(provider_batch_id="batch_abc", config=config) + + assert captured["output_file_id"] == "file_error_1" + assert rows == [ + { + "custom_id": "c1", + "response": {"body": {"error": {"message": "Unrecognized request argument supplied: expected_schema"}}}, + "status": "error", + 
"error_message": "Unrecognized request argument supplied: expected_schema", + } + ]