diff --git a/README.md b/README.md index 92357ae..54e5616 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ For testing and scripts that make use of the library, it is advised to create a - Image inputs for multimodal processing - Parallelizable with multi-node support - The training pipeline uses distributed inference with sharding -- Support a variety of LLMs and VLMs (Vision-Language Models) +- Support a variety of LLMs, VLMs (Vision-Language Models), and image generation models - Support any dataset schemas (configurable with the YAML format) - The ability to either output a JSON (or any other structured format) or plain text - Modular architecture with pluggable processors, loaders, and writers @@ -123,7 +123,7 @@ execution_params: Configuration explanation: -- `processors`: List of processor configurations. Currently supports `llm` type for LLM-based generation. +- `processors`: List of processor configurations. Supports `llm` (text/VLM generation) and `image_gen` (text-to-image generation). - `loading_params`: Parameters for loading and sharding datasets. - `state_dir`: Optional shared directory for shard status/retry state. Defaults to `~/.cache/MMIRAGE/state_dir`. - `datasets`: List of dataset configurations with path, type, and output directory. @@ -191,6 +191,66 @@ execution_params: retry: false ``` +### Image generation: Text-to-image pipeline + +MMIRAGE also supports image generation with Diffusers models: + +```yaml +processors: + - type: image_gen + pipeline_args: + model_path: stable-diffusion-v1-5/stable-diffusion-v1-5 + torch_dtype: float16 + device: auto + enable_attention_slicing: true + default_sampling_params: + num_inference_steps: 20 + guidance_scale: 7.5 + output_dir: /path/to/generated/images + file_format: png + +loading_params: + state_dir: /path/to/state/dir + datasets: + - path: /path/to/prompts.jsonl + type: JSONL + output_dir: /path/to/output/shards + num_shards: 1 + shard_id: 0 + batch_size: 8 + +processing_params: + inputs: + - name: prompt_text + key: text + + outputs: + - name: generated_image + type: image_gen + output_mode: path # "path" or "pil" + filename_template: "generated_{{ __shard_id }}_{{ __sample_index }}_{{ __source_hash }}" + width: 512 + height: 512 + prompt: | + Create an illustration of: + {{ prompt_text }} + + remove_columns: false + output_schema: + text: "{{ prompt_text }}" + image: "{{ generated_image }}" + +execution_params: + mode: local + retry: false +``` + +Install optional image generation dependencies before running this config: + +```bash +pip install -e .[image_gen] +``` + Key multimodal features: - `chat_template`: Specify the VLM chat template (e.g., `qwen2-vl`) - `type: image`: Mark input variables as images diff --git a/configs/config_image_gen_diffusers.yaml b/configs/config_image_gen_diffusers.yaml new file mode 100644 index 0000000..7e184d3 --- /dev/null +++ b/configs/config_image_gen_diffusers.yaml @@ -0,0 +1,121 @@ +# Local Diffusers image generation example +# Runs an in-process Diffusers pipeline from a local model path. + +processors: + - type: image_gen + backend: diffusers + pipeline_args: + model_path: stable-diffusion-v1-5/stable-diffusion-v1-5 + device: auto # auto: multi-GPU via device_map, single GPU, or CPU + torch_dtype: float16 + enable_attention_slicing: false + default_sampling_params: + num_inference_steps: 30 + guidance_scale: 4.0 + parallel_inference: true + parallel_chunk_size: 4 # samples per batched GPU call + output_dir: /users/qchapp/meditron/MIRAGE/tests/output/image_gen/generated_images + file_format: png + +loading_params: + state_dir: /users/qchapp/meditron/MIRAGE/tests/output/image_gen/_pipeline_state + datasets: + - path: /users/qchapp/meditron/MIRAGE/tests/mock_data_image_gen/data.jsonl + type: JSONL + output_dir: /users/qchapp/meditron/MIRAGE/tests/output/image_gen + num_shards: 4 # set by Slurm array job + shard_id: ${SLURM_ARRAY_TASK_ID} # set by Slurm array job + batch_size: 64 + +processing_params: + inputs: + - name: text + key: caption + + outputs: + - name: generated_image + type: image_gen + output_mode: path + # __sample_index is shard-local; combine with __shard_id for global uniqueness + filename_template: "img_{{ __shard_id }}_{{ __sample_index }}_{{ __source_hash }}" + width: 1024 + height: 1024 + seed: 42 # shard-aware: effective = 42 + shard_id * 1_000_000_000 + sample_index + prompt: | + A photorealistic image of: {{ text }} + + remove_columns: false + output_schema: + caption: "{{ text }}" + image: "{{ generated_image }}" + +execution_params: + # Execution mode: "local" or "slurm" + # - local: Run directly on this machine + # - slurm: Submit jobs to SLURM cluster + mode: slurm + + # Whether the canonical `run` command should automatically retry failed shards. + # - false: submit one run only + # - true: submit, wait, and keep retrying failed shards until success or retry budget exhaustion + retry: false + + # Maximum number of times to retry a failed shard (default: 3) + max_retries: 3 + + # ========================================================================== + # SLURM CONFIGURATION (only used when mode: slurm) + # ========================================================================== + + # HPC account/partition to charge jobs to (REQUIRED for SLURM mode) + account: a127 + + # SLURM job name (default: "mmirage-sharded") + job_name: mmirage-sharded + + # Optional SLURM reservation name (leave blank or omit to not use) + # reservation: "sai-a127" + + # Number of nodes (default: 1) + nodes: 1 + + # Number of tasks per node (default: 1) + ntasks_per_node: 1 + + # Number of GPUs per node (default: 4) + gpus: 4 + + # Number of CPUs per task (default: 288) + cpus_per_task: 288 + + # Job time limit in HH:MM:SS format (default: "11:59:59") + time_limit: "11:59:59" + + # ========================================================================== + # PATH CONFIGURATION + # ========================================================================== + # These support environment variables ($VAR or ${VAR}) and home directory (~) + + # Project root directory (used as base for relative paths) + # If not set, uses current working directory + # project_root: "/path/to/project" + + # Directory for SLURM output and error files (default: ~/reports) + report_dir: "/users/${USER}/reports" + + # HuggingFace cache directory (default: ~/hf) + hf_home: "/capstor/store/cscs/swissai/a127/homes/${USER}/hf" + + # EDF environment file path for cluster-specific setup + edf_env: "/users/${USER}/.edf/sglang.toml" + + # ========================================================================== + # JOB MONITORING (for "submit" and retry orchestration) + # ========================================================================== + + # Seconds to wait between checking job status (default: 30) + poll_interval_seconds: 30 + + # Seconds to wait after job completes before checking results (default: 60) + # This allows filesystem to settle on distributed systems + settle_time_seconds: 60 diff --git a/configs/config_image_gen_sglang.yaml b/configs/config_image_gen_sglang.yaml new file mode 100644 index 0000000..1e7e417 --- /dev/null +++ b/configs/config_image_gen_sglang.yaml @@ -0,0 +1,142 @@ +# SGLang Diffusion server image generation example +# +# Two launch modes are available: +# +# launch_mode: managed (recommended) +# MMIRAGE automatically starts a SGLang server on each worker node, +# waits until it is ready, runs the pipeline, and shuts it down afterwards. +# No manual server management is needed. +# +# launch_mode: external +# You are responsible for starting the SGLang server before the pipeline +# runs. Use this if you want to reuse a long-running server or need +# fine-grained control over server startup. +# +# MMIRAGE handles all dataset sharding, prompt rendering, filename rendering, +# and result saving. The SGLang server is only responsible for generating +# image pixels. + +processors: + - type: image_gen + backend: sglang + sglang: + launch_mode: managed # MMIRAGE starts/stops the server automatically + model_path: stable-diffusion-v1-5/stable-diffusion-v1-5 + port: 30010 + num_gpus: 4 # passed as --tp to sglang.launch_server + # dtype: float16 # optional: --dtype flag + startup_timeout_seconds: 180 # seconds to wait for the server to become ready + # extra_server_args: # any additional --flags for sglang.launch_server + # - "--mem-fraction-static" + # - "0.9" + api_key: EMPTY # unauthenticated local server + timeout_seconds: 900 + default_sampling_params: + num_inference_steps: 30 + guidance_scale: 4.0 + parallel_inference: true + parallel_chunk_size: 4 # concurrent requests per chunk (sequential per sample inside the backend) + output_dir: /users/qchapp/meditron/MIRAGE/tests/output/image_gen/generated_images + file_format: png + +loading_params: + state_dir: /users/qchapp/meditron/MIRAGE/tests/output/image_gen/_pipeline_state + datasets: + - path: /users/qchapp/meditron/MIRAGE/tests/mock_data_image_gen/data.jsonl + type: JSONL + output_dir: /users/qchapp/meditron/MIRAGE/tests/output/image_gen + num_shards: 4 # each Slurm task starts its own server on localhost + shard_id: ${SLURM_ARRAY_TASK_ID} + batch_size: 64 + +processing_params: + inputs: + - name: text + key: caption + + outputs: + - name: generated_image + type: image_gen + output_mode: path + filename_template: "img_{{ __shard_id }}_{{ __sample_index }}_{{ __source_hash }}" + width: 1024 + height: 1024 + seed: 42 # shard-aware: effective seed = 42 + shard_id * 1_000_000_000 + sample_index + prompt: | + A photorealistic image of: {{ text }} + + remove_columns: false + output_schema: + caption: "{{ text }}" + image: "{{ generated_image }}" + +execution_params: + # Execution mode: "local" or "slurm" + # - local: Run directly on this machine + # - slurm: Submit jobs to SLURM cluster + mode: slurm + + # Whether the canonical `run` command should automatically retry failed shards. + # - false: submit one run only + # - true: submit, wait, and keep retrying failed shards until success or retry budget exhaustion + retry: false + + # Maximum number of times to retry a failed shard (default: 3) + max_retries: 3 + + # ========================================================================== + # SLURM CONFIGURATION (only used when mode: slurm) + # ========================================================================== + + # HPC account/partition to charge jobs to (REQUIRED for SLURM mode) + account: a127 + + # SLURM job name (default: "mmirage-sharded") + job_name: mmirage-sharded + + # Optional SLURM reservation name (leave blank or omit to not use) + # reservation: "sai-a127" + + # Number of nodes (default: 1) + nodes: 1 + + # Number of tasks per node (default: 1) + ntasks_per_node: 1 + + # Number of GPUs per node (default: 4) + gpus: 4 + + # Number of CPUs per task (default: 288) + cpus_per_task: 288 + + # Job time limit in HH:MM:SS format (default: "11:59:59") + time_limit: "11:59:59" + + # ========================================================================== + # PATH CONFIGURATION + # ========================================================================== + # These support environment variables ($VAR or ${VAR}) and home directory (~) + + # Project root directory (used as base for relative paths) + # If not set, uses current working directory + # project_root: "/path/to/project" + + # Directory for SLURM output and error files (default: ~/reports) + report_dir: "/users/${USER}/reports" + + # HuggingFace cache directory (default: ~/hf) + hf_home: "/capstor/store/cscs/swissai/a127/homes/${USER}/hf" + + # EDF environment file path for cluster-specific setup + edf_env: "/users/${USER}/.edf/sglang.toml" + + # ========================================================================== + # JOB MONITORING (for "submit" and retry orchestration) + # ========================================================================== + + # Seconds to wait between checking job status (default: 30) + poll_interval_seconds: 30 + + # Seconds to wait after job completes before checking results (default: 60) + # This allows filesystem to settle on distributed systems + settle_time_seconds: 60 diff --git a/configs/config_mock_image_gen_local.yaml b/configs/config_mock_image_gen_local.yaml new file mode 100644 index 0000000..b778547 --- /dev/null +++ b/configs/config_mock_image_gen_local.yaml @@ -0,0 +1,54 @@ +processors: + - type: image_gen + backend: diffusers + pipeline_args: + model_path: stable-diffusion-v1-5/stable-diffusion-v1-5 + torch_dtype: float16 + device: auto + enable_attention_slicing: false + default_sampling_params: + num_inference_steps: 20 + guidance_scale: 7.5 + parallel_inference: true + parallel_chunk_size: 4 + output_dir: tests/output/image_gen/generated_images + file_format: png + +loading_params: + state_dir: tests/output/image_gen/_pipeline_state + datasets: + - path: tests/mock_data_image_gen/data.jsonl + type: JSONL + output_dir: tests/output/image_gen + + num_shards: 1 + shard_id: 0 + batch_size: 4 + +processing_params: + inputs: + - name: text + key: text + + outputs: + - name: generated_image + type: image_gen + output_mode: path + filename_template: "generated_{{ __shard_id }}_{{ __sample_index }}_{{ __source_hash }}" + width: 512 + height: 512 + prompt: | + Create a clean and detailed illustration for: + {{ text }} + + remove_columns: false + output_schema: + prompt_source: "{{ text }}" + image: "{{ generated_image }}" + +execution_params: + mode: local + retry: false + merge: false + report_dir: ~/reports + hf_home: ~/hf diff --git a/pyproject.toml b/pyproject.toml index 5804d4e..eb88a93 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,11 @@ dev = [ "ipykernel", "pytest", ] +image_gen = [ + "diffusers>=0.31.0", + "accelerate>=0.33.0", + "safetensors>=0.4.4", +] [project.scripts] mmirage = "mmirage.cli:main" diff --git a/src/mmirage/config/utils.py b/src/mmirage/config/utils.py index e69c2d3..53e9791 100644 --- a/src/mmirage/config/utils.py +++ b/src/mmirage/config/utils.py @@ -6,6 +6,7 @@ import os from mmirage.config.config import MMirageConfig +from mmirage.core.process.processors.image_gen.config import ImageOutputMode from mmirage.core.process.base import BaseProcessorConfig, ProcessorRegistry, OutputVar from mmirage.core.loader.base import BaseDataLoaderConfig, DataLoaderRegistry @@ -15,6 +16,7 @@ # to construct config/output-var objects from YAML without importing heavy # processor implementations (e.g. torch/transformers). import mmirage.core.process.processors.llm.config # noqa: F401 +import mmirage.core.process.processors.image_gen.config # noqa: F401 import mmirage.core.loader.jsonl # noqa: F401 import mmirage.core.loader.local_hf # noqa: F401 @@ -98,6 +100,11 @@ def expand_env_vars(obj: EnvValue) -> EnvValue: return os.path.expandvars(obj) else: return obj + + def image_output_mode_hook(value: Any) -> ImageOutputMode: + if isinstance(value, ImageOutputMode): + return value + return ImageOutputMode(value) def processor_config_hook(data: Dict[str, Any]) -> BaseProcessorConfig: clz = ProcessorRegistry.get_config_cls(data["type"]) @@ -114,6 +121,7 @@ def output_var_hook(data: Dict[str, Any]) -> OutputVar: cfg = expand_env_vars(cfg) config = Config( type_hooks={ + ImageOutputMode: image_output_mode_hook, BaseProcessorConfig: processor_config_hook, BaseDataLoaderConfig: loader_config_hook, OutputVar: output_var_hook, diff --git a/src/mmirage/core/process/base.py b/src/mmirage/core/process/base.py index 988bae7..44a28ef 100644 --- a/src/mmirage/core/process/base.py +++ b/src/mmirage/core/process/base.py @@ -37,14 +37,28 @@ class BaseProcessor(abc.ABC, Generic[C]): config: Configuration object for this processor. """ - def __init__(self, config: BaseProcessorConfig) -> None: + def __init__(self, config: BaseProcessorConfig, shard_id: int = 0, **kwargs) -> None: """Initialize the processor with configuration. Args: config: Configuration object for this processor. + shard_id: Optional shard identifier accepted for compatibility + with callers that forward it during processor construction. + **kwargs: Additional keyword arguments. Any unexpected keyword + arguments will raise ``TypeError``. + + Raises: + TypeError: If unexpected keyword arguments are provided. """ + if kwargs: + unexpected_args = ", ".join(sorted(kwargs)) + raise TypeError( + f"Unexpected keyword argument(s) for " + f"{self.__class__.__name__}: {unexpected_args}" + ) super().__init__() self.config = config + self.shard_id = shard_id @abc.abstractmethod def batch_process_sample( @@ -64,6 +78,13 @@ def batch_process_sample( """ raise NotImplementedError() + def shutdown(self) -> None: + """Release any resources held by this processor. + + Override in subclasses that hold GPU memory, open file handles, or + network connections. The default implementation is a no-op. + """ + class ProcessorRegistry: """Registry for managing and accessing available processors. @@ -84,7 +105,10 @@ class ProcessorRegistry: # Import processor implementations lazily because they may depend on heavy # libraries (torch/transformers). Config/output-var types are registered via # mmirage.config.utils importing the relevant config modules. - _lazy_processor_imports = {"llm": "mmirage.core.process.processors.llm.llm_processor"} + _lazy_processor_imports = { + "llm": "mmirage.core.process.processors.llm.llm_processor", + "image_gen": "mmirage.core.process.processors.image_gen.image_gen_processor", + } @classmethod def register_types( diff --git a/src/mmirage/core/process/mapper.py b/src/mmirage/core/process/mapper.py index 5310150..e3f64ef 100644 --- a/src/mmirage/core/process/mapper.py +++ b/src/mmirage/core/process/mapper.py @@ -29,6 +29,7 @@ def __init__( processor_configs: List[BaseProcessorConfig], input_vars: List[InputVar], output_vars: List[OutputVar], + shard_id: int = 0, ) -> None: """Initialize the MMIRAGE mapper. @@ -36,6 +37,7 @@ def __init__( processor_configs: List of processor configurations. input_vars: List of input variable definitions. output_vars: List of output variable definitions. + shard_id: Shard index for this worker, forwarded to processors. """ self.processors: Dict[str, BaseProcessor] = dict() self.input_vars = input_vars @@ -45,7 +47,7 @@ def __init__( processor_cls = AutoProcessor.from_name(config.type) logger.info(f"✅ Successfully loaded processor of type {config.type}") - self.processors[config.type] = processor_cls(config) + self.processors[config.type] = processor_cls(config, shard_id=shard_id) def validate_vars(self) -> bool: """Validate that all output variables are computable. @@ -103,3 +105,8 @@ def rewrite_batch( ) return batch_environment + + def shutdown(self) -> None: + """Shut down all processors and release their resources.""" + for processor in self.processors.values(): + processor.shutdown() diff --git a/src/mmirage/core/process/processors/image_gen/__init__.py b/src/mmirage/core/process/processors/image_gen/__init__.py new file mode 100644 index 0000000..58c65fa --- /dev/null +++ b/src/mmirage/core/process/processors/image_gen/__init__.py @@ -0,0 +1,6 @@ +"""Image generation processor implementation. + +This module provides a dedicated processor for text-to-image generation tasks +using Diffusers pipelines. It can emit either saved image paths or in-memory +PIL images. +""" diff --git a/src/mmirage/core/process/processors/image_gen/backends/__init__.py b/src/mmirage/core/process/processors/image_gen/backends/__init__.py new file mode 100644 index 0000000..ea013dd --- /dev/null +++ b/src/mmirage/core/process/processors/image_gen/backends/__init__.py @@ -0,0 +1,5 @@ +"""Image generation backends for MMIRAGE.""" + +from mmirage.core.process.processors.image_gen.backends.base import ImageGenerationBackend + +__all__ = ["ImageGenerationBackend"] diff --git a/src/mmirage/core/process/processors/image_gen/backends/base.py b/src/mmirage/core/process/processors/image_gen/backends/base.py new file mode 100644 index 0000000..138c7d9 --- /dev/null +++ b/src/mmirage/core/process/processors/image_gen/backends/base.py @@ -0,0 +1,51 @@ +"""Image generation backend protocol for MMIRAGE.""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +try: + from typing import Protocol, runtime_checkable +except ImportError: # pragma: no cover + from typing_extensions import Protocol, runtime_checkable # type: ignore + + +@runtime_checkable +class ImageGenerationBackend(Protocol): + """Protocol for pluggable image generation backends. + + All backends receive pre-rendered prompts and pre-computed per-sample seeds + from the processor. The processor handles all Jinja template rendering, + filename generation, and result bookkeeping; the backend is responsible + only for turning prompts + params into PIL images. + """ + + def generate_batch( + self, + prompts: List[str], + negative_prompts: Optional[List[Optional[str]]], + params: Dict[str, Any], + seeds: List[Optional[int]], + ) -> List[Any]: + """Generate one image per prompt. + + Args: + prompts: Positive prompt strings, one per sample. + negative_prompts: Optional list of negative prompts aligned with + ``prompts``. ``None`` means no negative prompts at all; + individual ``None`` elements mean no negative prompt for that + sample. + params: Shared generation kwargs (width, height, + num_inference_steps, guidance_scale, …). + seeds: Per-sample integer seeds for deterministic generation, or + ``None`` elements for unseeded samples. The list is always + the same length as ``prompts``. + + Returns: + List of ``PIL.Image`` objects, one per prompt, in the same order. + """ + ... + + def shutdown(self) -> None: + """Release any resources held by the backend.""" + ... diff --git a/src/mmirage/core/process/processors/image_gen/backends/diffusers_backend.py b/src/mmirage/core/process/processors/image_gen/backends/diffusers_backend.py new file mode 100644 index 0000000..0364f4f --- /dev/null +++ b/src/mmirage/core/process/processors/image_gen/backends/diffusers_backend.py @@ -0,0 +1,274 @@ +"""Diffusers-based image generation backend.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + +from mmirage.core.process.processors.image_gen.config import DiffusersPipelineArgs + +if TYPE_CHECKING: # pragma: no cover + import torch + from diffusers import DiffusionPipeline + +logger = logging.getLogger(__name__) + + +class DiffusersImageBackend: + """Image generation backend using an in-process Diffusers pipeline. + + This backend loads a local or cached Diffusers pipeline once at + construction time and keeps it in memory for repeated batched generation. + """ + + def __init__(self, pipeline_args: DiffusersPipelineArgs) -> None: + try: + import torch + from diffusers import DiffusionPipeline # type: ignore[import-not-found] + except ImportError as exc: # pragma: no cover + raise RuntimeError( + "diffusers backend requires optional dependencies. " + "Install with: pip install -e .[image_gen]" + ) from exc + + self._torch = torch + self._pipeline, self._generator_device = self._build_pipeline( + DiffusionPipeline, + pipeline_args, + ) + + # ------------------------------------------------------------------ + # Pipeline construction + # ------------------------------------------------------------------ + + def _build_pipeline( + self, + pipeline_cls: type[DiffusionPipeline], + args: DiffusersPipelineArgs, + ) -> Tuple[DiffusionPipeline, str]: + """Load and configure the Diffusers pipeline. + + Returns: + ``(pipeline, generator_device)`` where ``generator_device`` is the + device string used for ``torch.Generator`` objects. + + Notes: + ``device_map`` is only local-process model/component placement. It + is not Slurm multi-node distribution. + """ + load_kwargs = self._build_load_kwargs(args) + + placement_device = args.device + generator_device = args.device + use_device_map = False + + if placement_device == "auto": + placement_device, generator_device, use_device_map = self._resolve_auto_device(args) + + if use_device_map: + device_map = getattr(args, "device_map", None) or "balanced" + load_kwargs["device_map"] = device_map + + logger.info( + "device='auto': using Diffusers device_map=%r across %d visible GPUs", + device_map, + self._torch.cuda.device_count(), + ) + + pipeline = pipeline_cls.from_pretrained(args.model_path, **load_kwargs) + + if not use_device_map: + pipeline = pipeline.to(placement_device) + + if getattr(args, "enable_attention_slicing", False): + if hasattr(pipeline, "enable_attention_slicing"): + pipeline.enable_attention_slicing() + logger.info("Enabled Diffusers attention slicing.") + else: + logger.warning( + "enable_attention_slicing=True was requested, but this pipeline " + "does not expose enable_attention_slicing()." + ) + + return pipeline, generator_device + + def _build_load_kwargs(self, args: DiffusersPipelineArgs) -> Dict[str, Any]: + """Build keyword arguments for ``DiffusionPipeline.from_pretrained``.""" + load_kwargs: Dict[str, Any] = {} + + torch_dtype = self._parse_torch_dtype(getattr(args, "torch_dtype", "auto")) + if torch_dtype is not None: + load_kwargs["torch_dtype"] = torch_dtype + + optional_fields = { + "revision": "revision", + "cache_dir": "cache_dir", + "custom_pipeline": "custom_pipeline", + "variant": "variant", + } + for attr_name, kwarg_name in optional_fields.items(): + value = getattr(args, attr_name, None) + if value: + load_kwargs[kwarg_name] = value + + boolean_fields = { + "local_files_only": "local_files_only", + "trust_remote_code": "trust_remote_code", + } + for attr_name, kwarg_name in boolean_fields.items(): + value = getattr(args, attr_name, False) + if value: + load_kwargs[kwarg_name] = True + + return load_kwargs + + def _resolve_auto_device(self, args: DiffusersPipelineArgs) -> Tuple[str, str, bool]: + """Resolve ``device='auto'`` into placement/generator choices. + + Returns: + ``(placement_device, generator_device, use_device_map)``. + """ + if not self._torch.cuda.is_available(): + logger.info("device='auto': resolved to 'cpu' because CUDA is unavailable.") + return "cpu", "cpu", False + + num_gpus = self._torch.cuda.device_count() + if num_gpus <= 0: + logger.info("device='auto': resolved to 'cpu' because no CUDA devices are visible.") + return "cpu", "cpu", False + + if num_gpus == 1: + logger.info("device='auto': resolved to 'cuda' because one GPU is visible.") + return "cuda", "cuda", False + + # Multiple GPUs visible to this process. This is local model/component + # placement, not distributed Slurm execution. + # + # CPU generators are safest when the model is split across multiple + # devices and also improve reproducibility across CPU/GPU setups. + return "cpu", "cpu", True + + def _parse_torch_dtype(self, dtype: str) -> Optional[torch.dtype]: + """Convert a dtype string into a ``torch.dtype``.""" + if not dtype: + return None + + dtype_key = dtype.lower() + if dtype_key == "auto": + return None + + mapping = { + "float16": self._torch.float16, + "fp16": self._torch.float16, + "bfloat16": self._torch.bfloat16, + "bf16": self._torch.bfloat16, + "float32": self._torch.float32, + "fp32": self._torch.float32, + } + + if dtype_key not in mapping: + raise ValueError( + f"Unsupported torch_dtype={dtype!r}. " + "Use one of: auto, float16, bfloat16, float32." + ) + + return mapping[dtype_key] + + # ------------------------------------------------------------------ + # Generation helpers + # ------------------------------------------------------------------ + + def _build_generators(self, seeds: List[Optional[int]]) -> Optional[Any]: + """Build one or more ``torch.Generator`` objects from seeds. + + Returns: + ``None`` when all seeds are ``None``. + A single generator for one prompt. + A list of generators for batched prompts. + """ + if not seeds or all(seed is None for seed in seeds): + return None + + if any(seed is None for seed in seeds): + raise ValueError( + "Diffusers batched generation requires seeds to be either all set " + "or all None. Received a mixed seed list." + ) + + generators = [] + for seed in seeds: + generator = self._torch.Generator(device=self._generator_device) + generator.manual_seed(int(seed)) + generators.append(generator) + + return generators[0] if len(generators) == 1 else generators + + @staticmethod + def _normalize_negative_prompts( + prompts: List[str], + negative_prompts: Optional[List[Optional[str]]], + ) -> Optional[List[str]]: + """Normalize optional negative prompts for Diffusers.""" + if negative_prompts is None: + return None + + if len(negative_prompts) != len(prompts): + raise ValueError( + f"Expected {len(prompts)} negative prompts, got {len(negative_prompts)}" + ) + + return [negative_prompt or "" for negative_prompt in negative_prompts] + + # ------------------------------------------------------------------ + # Backend interface + # ------------------------------------------------------------------ + + def generate_batch( + self, + prompts: List[str], + negative_prompts: Optional[List[Optional[str]]], + params: Dict[str, Any], + seeds: List[Optional[int]], + ) -> List[Any]: + """Generate images via the in-process Diffusers pipeline. + + The pipeline is invoked once with a list of prompts, allowing Diffusers + to perform batched generation on the selected local device placement. + """ + if not prompts: + return [] + + if seeds and len(seeds) != len(prompts): + raise ValueError(f"Expected {len(prompts)} seeds, got {len(seeds)}") + + call_kwargs: Dict[str, Any] = dict(params) + call_kwargs["prompt"] = prompts + + normalized_negative_prompts = self._normalize_negative_prompts( + prompts, + negative_prompts, + ) + if normalized_negative_prompts is not None: + call_kwargs["negative_prompt"] = normalized_negative_prompts + + generators = self._build_generators(seeds) + if generators is not None: + call_kwargs["generator"] = generators + + output = self._pipeline(**call_kwargs) + images = output.images + + if len(images) != len(prompts): + raise RuntimeError( + f"Expected {len(prompts)} images from Diffusers pipeline, got {len(images)}" + ) + + return images + + def shutdown(self) -> None: + """Release the pipeline reference and free CUDA cache.""" + self._pipeline = None + + if self._torch.cuda.is_available(): + self._torch.cuda.empty_cache() + logger.info("Released Diffusers pipeline and emptied CUDA cache.") \ No newline at end of file diff --git a/src/mmirage/core/process/processors/image_gen/backends/sglang_backend.py b/src/mmirage/core/process/processors/image_gen/backends/sglang_backend.py new file mode 100644 index 0000000..e41a457 --- /dev/null +++ b/src/mmirage/core/process/processors/image_gen/backends/sglang_backend.py @@ -0,0 +1,492 @@ +"""SGLang Diffusion server image generation backend.""" + +from __future__ import annotations + +import base64 +import binascii +import io +import json +import logging +import os +import subprocess +import sys +import time +import urllib.error +import urllib.request +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) + +try: + from PIL import Image as PILImage +except ImportError: # pragma: no cover + PILImage = None # type: ignore + + +class SGLangImageBackend: + """Image generation backend that calls a local SGLang Diffusion server. + + Supports two lifecycle modes controlled by ``SGLangBackendConfig.launch_mode``: + + - ``external``: connects to an already-running server. + - ``managed``: spawns ``python -m sglang.launch_server`` as a subprocess, + polls until the server is ready, and terminates it on :meth:`shutdown`. + """ + + def __init__( + self, + base_url: str, + api_key: str = "EMPTY", + timeout_seconds: int = 900, + request_model: Optional[str] = None, + model_path: Optional[str] = None, + validate_server: bool = True, + max_concurrent_requests: int = 1, + # managed-mode args + _managed_process: Optional[subprocess.Popen] = None, + ) -> None: + """Initialize the SGLang HTTP backend. + + Args: + base_url: Server base URL. Both of these are accepted: + ``http://127.0.0.1:30010`` and + ``http://127.0.0.1:30010/v1``. + api_key: API key sent as ``Authorization: Bearer ...``. + Use ``"EMPTY"`` for local unauthenticated servers. + timeout_seconds: Per-request HTTP timeout. + request_model: Optional model field to include in image requests. + Most single-model SGLang deployments do not need this. + model_path: Backward-compatible alias for ``request_model``. + The actual local model path should normally be supplied when + launching ``sglang serve``, not per request. + validate_server: Whether to check server reachability at init time. + max_concurrent_requests: Number of concurrent HTTP image requests + issued from ``generate_batch``. Defaults to 1 for conservative + server behavior. + """ + if PILImage is None: # pragma: no cover + raise RuntimeError( + "sglang backend requires Pillow. " + "Install with: pip install -e .[image_gen]" + ) + + if not base_url: + raise ValueError("SGLang base_url must be non-empty.") + + self._server_root_url, self._api_base_url = self._normalize_base_url(base_url) + self._api_key = api_key + self._timeout = timeout_seconds + self._request_model = request_model or model_path + self._max_concurrent_requests = max(1, int(max_concurrent_requests)) + self._managed_process = _managed_process + + if validate_server: + self._validate_server() + + # ------------------------------------------------------------------ + # Factory: managed server + # ------------------------------------------------------------------ + + @classmethod + def from_managed_config( + cls, + model_path: str, + port: int = 30010, + num_gpus: int = 1, + dtype: Optional[str] = None, + api_key: str = "EMPTY", + timeout_seconds: int = 900, + startup_timeout_seconds: int = 120, + extra_server_args: Optional[List[str]] = None, + max_concurrent_requests: int = 1, + ) -> SGLangImageBackend: + """Launch a local SGLang Diffusion server and return a connected backend. + + The server is started as a subprocess. This method blocks until the + server responds to health-check requests or ``startup_timeout_seconds`` + elapses. + + Args: + model_path: HuggingFace model ID or local model directory. + port: TCP port the server should listen on. + num_gpus: Tensor-parallelism degree (``--tp``). + dtype: Optional model weight dtype (``--dtype float16`` etc.). + api_key: API key used for subsequent requests. + timeout_seconds: Per-request HTTP timeout for inference calls. + startup_timeout_seconds: Seconds to wait for the server to become ready. + extra_server_args: Extra CLI flags appended to the launch command. + max_concurrent_requests: Concurrent HTTP image requests. + """ + base_url = f"http://127.0.0.1:{port}/v1" + proc = cls._launch_managed_server( + model_path=model_path, + port=port, + num_gpus=num_gpus, + dtype=dtype, + extra_args=extra_server_args or [], + ) + cls._wait_for_server(base_url, api_key, startup_timeout_seconds, proc) + return cls( + base_url=base_url, + api_key=api_key, + timeout_seconds=timeout_seconds, + request_model=model_path, + validate_server=False, # already confirmed ready + max_concurrent_requests=max_concurrent_requests, + _managed_process=proc, + ) + + @staticmethod + def _launch_managed_server( + model_path: str, + port: int, + num_gpus: int, + dtype: Optional[str], + extra_args: List[str], + ) -> subprocess.Popen: + """Spawn ``python -m sglang.launch_server`` and return the process handle.""" + cmd = [ + sys.executable, + "-m", + "sglang.launch_server", + "--model-path", + model_path, + "--port", + str(port), + "--tp", + str(num_gpus), + ] + if dtype: + cmd += ["--dtype", dtype] + cmd += extra_args + + logger.info("Starting managed SGLang server: %s", " ".join(cmd)) + + # Inherit the current environment so HF_HOME, CUDA_VISIBLE_DEVICES, etc. + # are passed through to the server process. + proc = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=os.environ.copy(), + ) + logger.info("SGLang server process started (pid=%d)", proc.pid) + return proc + + @staticmethod + def _wait_for_server( + base_url: str, + api_key: str, + startup_timeout_seconds: int, + proc: subprocess.Popen, + ) -> None: + """Poll the server until it responds or the timeout elapses.""" + server_root = base_url.rstrip("/").removesuffix("/v1") + health_urls = [ + f"{server_root}/health", + f"{base_url}/models", + ] + deadline = time.monotonic() + startup_timeout_seconds + poll_interval = 2.0 + + logger.info( + "Waiting up to %ds for SGLang server to become ready \u2026", + startup_timeout_seconds, + ) + + while time.monotonic() < deadline: + # Abort early if the process already exited. + ret = proc.poll() + if ret is not None: + output = "" + if proc.stdout: + try: + output = proc.stdout.read().decode(errors="replace")[-2000:] + except Exception: + pass + raise RuntimeError( + f"SGLang server process exited unexpectedly with code {ret}.\n" + f"Last output:\n{output}" + ) + + for url in health_urls: + try: + req = urllib.request.Request( + url, + headers={"Authorization": f"Bearer {api_key}"}, + ) + with urllib.request.urlopen(req, timeout=5): + logger.info("SGLang server is ready at %s", base_url) + return + except Exception: + pass + + time.sleep(poll_interval) + + raise RuntimeError( + f"SGLang server did not become ready within {startup_timeout_seconds}s. " + "Check server logs or increase startup_timeout_seconds." + ) + + # ------------------------------------------------------------------ + # URL / HTTP helpers + # ------------------------------------------------------------------ + + @staticmethod + def _normalize_base_url(base_url: str) -> tuple[str, str]: + """Return ``(server_root_url, api_base_url)`` from a user URL.""" + normalized = base_url.rstrip("/") + + if normalized.endswith("/v1"): + server_root_url = normalized[: -len("/v1")].rstrip("/") + api_base_url = normalized + else: + server_root_url = normalized + api_base_url = f"{normalized}/v1" + + return server_root_url, api_base_url + + def _headers(self) -> Dict[str, str]: + return { + "Content-Type": "application/json", + "Authorization": f"Bearer {self._api_key}", + } + + def _read_json(self, url: str, *, data: Optional[bytes] = None, timeout: Optional[int] = None) -> Dict[str, Any]: + """Issue an HTTP request and parse the response as JSON.""" + req = urllib.request.Request( + url, + data=data, + headers=self._headers(), + method="POST" if data is not None else "GET", + ) + + try: + with urllib.request.urlopen(req, timeout=timeout or self._timeout) as resp: + raw = resp.read().decode("utf-8") + except urllib.error.HTTPError as exc: + body_text = exc.read().decode(errors="replace") + raise RuntimeError( + f"SGLang server returned HTTP {exc.code} for {url}: {body_text}" + ) from exc + except urllib.error.URLError as exc: + raise RuntimeError(f"Could not reach SGLang server at {url}: {exc}") from exc + + try: + parsed = json.loads(raw) + except json.JSONDecodeError as exc: + raise RuntimeError( + f"SGLang server returned non-JSON response from {url}: {raw[:500]}" + ) from exc + + if not isinstance(parsed, dict): + raise RuntimeError( + f"SGLang server returned unexpected JSON response from {url}: {parsed!r}" + ) + + return parsed + + # ------------------------------------------------------------------ + # Server connectivity + # ------------------------------------------------------------------ + + def _validate_server(self) -> None: + """Confirm the SGLang server is reachable before processing starts. + + SGLang Diffusion commonly exposes ``GET /models`` at the server root. + Some OpenAI-compatible deployments expose ``GET /v1/models`` instead, + so we try both to make local setups less brittle. + """ + candidate_urls = [ + f"{self._server_root_url}/models", + f"{self._api_base_url}/models", + ] + + errors: List[str] = [] + for url in candidate_urls: + try: + self._read_json(url, timeout=10) + logger.info("SGLang Diffusion server is reachable at %s", url) + return + except Exception as exc: + errors.append(f"{url}: {exc}") + + raise RuntimeError( + "Cannot reach SGLang Diffusion server. Ensure the server is running " + "before starting the pipeline with launch_mode='external'. Tried:\n" + + "\n".join(f" - {err}" for err in errors) + ) + + # ------------------------------------------------------------------ + # Payload / response handling + # ------------------------------------------------------------------ + + def _build_payload( + self, + prompt: str, + negative_prompt: Optional[str], + params: Dict[str, Any], + seed: Optional[int], + ) -> Dict[str, Any]: + """Build an OpenAI-compatible image generation payload.""" + payload: Dict[str, Any] = { + "prompt": prompt, + "response_format": "b64_json", + "n": 1, + } + + if self._request_model: + payload["model"] = self._request_model + + if negative_prompt: + payload["negative_prompt"] = negative_prompt + + if seed is not None: + payload["seed"] = int(seed) + + # OpenAI-compatible SGLang image API expects "size": "WIDTHxHEIGHT". + if "size" in params: + payload["size"] = params["size"] + else: + width = params.get("width") + height = params.get("height") + if width is not None and height is not None: + payload["size"] = f"{int(width)}x{int(height)}" + + # Common diffusion-specific knobs. These are accepted by many SGLang + # diffusion deployments, but any unsupported key will be rejected by the + # server with a clear HTTP error. + for key in ("num_inference_steps", "guidance_scale"): + if key in params and params[key] is not None: + payload[key] = params[key] + + # Forward extra model/pipeline-specific parameters, excluding fields + # that were normalized above. + reserved = { + "width", + "height", + "size", + "num_inference_steps", + "guidance_scale", + "generator", # Diffusers-only; should never be sent to SGLang. + } + for key, value in params.items(): + if key not in reserved and value is not None: + payload[key] = value + + return payload + + @staticmethod + def _prompt_preview(prompt: str, limit: int = 80) -> str: + compact = prompt.replace("\n", " ").strip() + return compact if len(compact) <= limit else compact[: limit - 3] + "..." + + def _decode_image_response(self, result: Dict[str, Any], prompt: str) -> Any: + """Decode the first ``b64_json`` image into a PIL Image.""" + try: + data = result["data"] + if not isinstance(data, list) or not data: + raise KeyError("data[0]") + + b64_data = data[0]["b64_json"] + if not isinstance(b64_data, str): + raise TypeError("b64_json is not a string") + except (KeyError, IndexError, TypeError) as exc: + raise RuntimeError( + "Unexpected SGLang image response for prompt " + f"{self._prompt_preview(prompt)!r}: {result!r}" + ) from exc + + try: + img_bytes = base64.b64decode(b64_data) + except (binascii.Error, ValueError) as exc: + raise RuntimeError( + "SGLang returned invalid base64 image data for prompt " + f"{self._prompt_preview(prompt)!r}" + ) from exc + + try: + with PILImage.open(io.BytesIO(img_bytes)) as img: + return img.convert("RGB") + except Exception as exc: + raise RuntimeError( + "Could not decode SGLang image response for prompt " + f"{self._prompt_preview(prompt)!r}" + ) from exc + + # ------------------------------------------------------------------ + # Single-sample API call + # ------------------------------------------------------------------ + + def _call_api( + self, + prompt: str, + negative_prompt: Optional[str], + params: Dict[str, Any], + seed: Optional[int], + ) -> Any: + """Call ``/v1/images/generations`` and return a PIL Image.""" + payload = self._build_payload(prompt, negative_prompt, params, seed) + body = json.dumps(payload).encode("utf-8") + + url = f"{self._api_base_url}/images/generations" + result = self._read_json(url, data=body, timeout=self._timeout) + return self._decode_image_response(result, prompt) + + # ------------------------------------------------------------------ + # Backend interface + # ------------------------------------------------------------------ + + def generate_batch( + self, + prompts: List[str], + negative_prompts: Optional[List[Optional[str]]], + params: Dict[str, Any], + seeds: List[Optional[int]], + ) -> List[Any]: + """Generate images through SGLang. + + The OpenAI-compatible image endpoint is logically per prompt. This + method preserves the backend interface by issuing one request per prompt. + Requests are sequential by default. Set ``max_concurrent_requests > 1`` + only after confirming that the local SGLang server handles concurrent + image generation reliably. + """ + if negative_prompts is not None and len(negative_prompts) != len(prompts): + raise ValueError( + f"Expected {len(prompts)} negative prompts, got {len(negative_prompts)}" + ) + + if seeds and len(seeds) != len(prompts): + raise ValueError(f"Expected {len(prompts)} seeds, got {len(seeds)}") + + def generate_one(i: int) -> Any: + negative_prompt = negative_prompts[i] if negative_prompts is not None else None + seed = seeds[i] if seeds else None + return self._call_api(prompts[i], negative_prompt, params, seed) + + if self._max_concurrent_requests == 1 or len(prompts) <= 1: + return [generate_one(i) for i in range(len(prompts))] + + with ThreadPoolExecutor(max_workers=self._max_concurrent_requests) as pool: + futures = [pool.submit(generate_one, i) for i in range(len(prompts))] + return [future.result() for future in futures] + + def shutdown(self) -> None: + """Shut down the backend. Terminates the managed server process if running.""" + if self._managed_process is not None: + proc = self._managed_process + self._managed_process = None + logger.info("Stopping managed SGLang server (pid=%d) …", proc.pid) + proc.terminate() + try: + proc.wait(timeout=30) + except subprocess.TimeoutExpired: + logger.warning( + "SGLang server (pid=%d) did not terminate within 30 s; killing.", + proc.pid, + ) + proc.kill() + proc.wait() + logger.info("SGLang server stopped.") \ No newline at end of file diff --git a/src/mmirage/core/process/processors/image_gen/config.py b/src/mmirage/core/process/processors/image_gen/config.py new file mode 100644 index 0000000..33fe844 --- /dev/null +++ b/src/mmirage/core/process/processors/image_gen/config.py @@ -0,0 +1,263 @@ +"""Configuration for image generation processor in MMIRAGE.""" + +from dataclasses import dataclass, field +from enum import Enum + +import logging +import os +from typing import Any, Dict, List, Optional, Sequence +from jinja2 import Environment, meta + +from mmirage.core.process.base import BaseProcessorConfig +from mmirage.core.process.base import ProcessorRegistry +from mmirage.core.process.variables import BaseVar, OutputVar + +logger = logging.getLogger(__name__) +env = Environment() + + +class ImageOutputMode(str, Enum): + """Output representation for generated images.""" + + PATH = "path" + PIL = "pil" + + +@dataclass +class DiffusersPipelineArgs: + """Runtime arguments used to initialize a Diffusers pipeline. + + Attributes: + model_path: Hugging Face model id or local path. + revision: Optional model revision (branch / tag / commit SHA). + variant: Optional file variant, e.g. ``"fp16"`` for ``*.fp16.safetensors``. + torch_dtype: Torch dtype as string. Common values: ``"float16"``, + ``"bfloat16"``, ``"float32"``, ``"auto"``. + device: Target device: ``"auto"``, ``"cuda"``, ``"cpu"``, or an + explicit device string such as ``"cuda:1"``. + ``"auto"`` distributes across all available GPUs when more than + one is present (via ``device_map='auto'``), or falls back to CPU. + enable_attention_slicing: Enable attention slicing to reduce VRAM + usage. Defaults to ``False`` because it can slow down modern + CUDA setups; enable only when VRAM is constrained. + local_files_only: If ``True``, only load from local cache; never + contact the HuggingFace Hub. Useful for air-gapped clusters. + cache_dir: Override the HuggingFace cache directory. + trust_remote_code: Allow custom model code from the Hub repository. + custom_pipeline: Custom pipeline module name forwarded to + ``from_pretrained``. + """ + + model_path: str = "stable-diffusion-v1-5/stable-diffusion-v1-5" + revision: Optional[str] = None + variant: Optional[str] = None + torch_dtype: str = "float16" + device: str = "auto" + enable_attention_slicing: bool = False + local_files_only: bool = False + cache_dir: Optional[str] = None + trust_remote_code: bool = False + custom_pipeline: Optional[str] = None + + +@dataclass +class SGLangBackendConfig: + """Configuration for the SGLang Diffusion server backend. + + Two launch modes are supported: + + - ``external`` — MMIRAGE connects to an already-running SGLang server. + The user is responsible for starting it before the pipeline runs. + - ``managed`` — MMIRAGE spawns a local SGLang server as a subprocess, + waits for it to become ready, and shuts it down when the shard finishes. + This requires ``model_path`` and SGLang to be installed. + + Attributes: + launch_mode: ``"external"`` or ``"managed"``. + base_url: Base URL of the server (``http://host:port/v1``). Ignored + when ``launch_mode='managed'`` and ``port`` is set; inferred + automatically in that case. + api_key: ``Authorization: Bearer`` key. ``"EMPTY"`` for local servers. + timeout_seconds: Per-request HTTP timeout in seconds. + model_path: HuggingFace model ID or local path forwarded to the server. + Required for ``launch_mode='managed'``; optional for ``'external'`` + (sent as the ``model`` field in each request if supplied). + port: Port the managed server should listen on. Defaults to ``30010``. + num_gpus: Tensor-parallelism degree (``--tp``). Defaults to ``1``. + dtype: Model weight dtype forwarded as ``--dtype``. E.g. ``"float16"``. + startup_timeout_seconds: Maximum seconds to wait for the managed server + to become ready before raising an error. + extra_server_args: Additional CLI arguments appended verbatim to the + ``python -m sglang.launch_server`` command, e.g. + ``["--mem-fraction-static", "0.9"]``. + """ + + launch_mode: str = "external" + base_url: str = "http://127.0.0.1:30010/v1" + api_key: str = "EMPTY" + timeout_seconds: int = 900 + model_path: Optional[str] = None + + # managed-mode fields + port: int = 30010 + num_gpus: int = 1 + dtype: Optional[str] = None + startup_timeout_seconds: int = 120 + extra_server_args: List[str] = field(default_factory=list) + + def __post_init__(self) -> None: + if self.launch_mode not in ("external", "managed"): + raise ValueError( + f"Unsupported SGLang launch_mode={self.launch_mode!r}. " + "Choose 'external' (server already running) or " + "'managed' (MMIRAGE starts the server automatically)." + ) + if self.launch_mode == "managed" and not self.model_path: + raise ValueError( + "launch_mode='managed' requires model_path to be set so MMIRAGE " + "knows which model to pass to the SGLang server." + ) + if self.launch_mode == "managed": + # Derive base_url from port so users don't have to repeat it. + self.base_url = f"http://127.0.0.1:{self.port}/v1" + + +@dataclass +class ImageGenConfig(BaseProcessorConfig): + """Configuration for the backend-neutral image generation processor. + + Attributes: + backend: Image generation backend to use. One of ``"diffusers"`` + (in-process Diffusers pipeline) or ``"sglang"`` (local SGLang + Diffusion server). + pipeline_args: Diffusers pipeline arguments (used when + ``backend="diffusers"``). + sglang: SGLang server configuration (used when ``backend="sglang"``). + default_sampling_params: Default generation kwargs forwarded to every + pipeline/server call (e.g. ``num_inference_steps``, + ``guidance_scale``). + parallel_inference: If ``True``, generate a full chunk of prompts in a + single batched pipeline call (Diffusers backend) or concurrent + server calls. Chunks that fail are retried sample-by-sample. + parallel_chunk_size: Maximum number of samples per batched call. + ``None`` means use the full mapper batch size. + output_dir: Directory where generated images are written when + ``output_mode="path"``. Supports ``~`` expansion. + file_format: Image file format for saved outputs (e.g. ``"png"``, + ``"jpg"``). + """ + + backend: str = "diffusers" + pipeline_args: DiffusersPipelineArgs = field(default_factory=DiffusersPipelineArgs) + sglang: Optional[SGLangBackendConfig] = None + default_sampling_params: Dict[str, Any] = field(default_factory=dict) + parallel_inference: bool = True + parallel_chunk_size: Optional[int] = 4 + output_dir: str = "~/.cache/MMIRAGE/generated_images" + file_format: str = "png" + + def __post_init__(self) -> None: + """Validate configuration.""" + if self.backend not in ("diffusers", "sglang"): + raise ValueError( + f"Unsupported image_gen backend={self.backend!r}. " + "Choose 'diffusers' or 'sglang'." + ) + if self.backend == "sglang" and self.sglang is None: + raise ValueError( + "backend='sglang' requires a 'sglang:' configuration block." + ) + if self.parallel_chunk_size is not None and self.parallel_chunk_size <= 0: + raise ValueError( + f"parallel_chunk_size must be a positive integer, got {self.parallel_chunk_size!r}. " + "Set to None to use the full batch size." + ) + if not self.file_format: + logger.warning("file_format is empty; defaulting to 'png'.") + self.file_format = "png" + + def get_output_dir(self) -> str: + """Return the normalised absolute output directory path.""" + return os.path.abspath(os.path.expanduser(self.output_dir)) + + +# --------------------------------------------------------------------------- +# Backward-compatibility alias +# --------------------------------------------------------------------------- +#: ``DiffusersImageGenConfig`` is a legacy alias for :class:`ImageGenConfig`. +DiffusersImageGenConfig = ImageGenConfig + + +@dataclass +class ImageGenOutputVar(OutputVar): + """Output variable generated by image generation processor. + + Attributes: + prompt: Jinja2 template used as positive prompt. + negative_prompt: Optional Jinja2 template used as negative prompt. + output_mode: Output representation: "path" (default) or "pil". + filename_template: Optional Jinja2 template used for saved image filename stem. + Supported internal variables: ``__sample_index`` (shard-local row index, + i.e. the position within this shard's output — combine with ``__shard_id`` + for global uniqueness), ``__output_name``, ``__shard_id``, + ``__source_hash`` (8-char SHA-256 of input values). + All input variables (e.g. ``text``) are also available. + width: Optional image width override. + height: Optional image height override. + num_inference_steps: Optional sampling steps override. + guidance_scale: Optional guidance scale override. + seed: Optional deterministic seed. If set, sample index is added for uniqueness. + """ + + prompt: str = "" + negative_prompt: str = "" + output_mode: ImageOutputMode = ImageOutputMode.PATH + filename_template: str = "generated_{{ __shard_id }}_{{ __sample_index }}_{{ __source_hash }}" + width: Optional[int] = None + height: Optional[int] = None + num_inference_steps: Optional[int] = None + guidance_scale: Optional[float] = None + seed: Optional[int] = None + + def is_computable(self, vars: Sequence[BaseVar]) -> bool: + """Check if all variables referenced in templates are available.""" + reserved = {"__sample_index", "__output_name", "__shard_id", "__source_hash"} + var_names = {v.name for v in vars} + + # Prompt/negative_prompt are rendered from env.to_dict() only — reserved + # vars are not injected there, so treat them as undeclared in those templates. + prompt_templates: List[str] = [self.prompt] + if self.negative_prompt: + prompt_templates.append(self.negative_prompt) + + undeclared: set[str] = set() + for template in prompt_templates: + parsed_content = env.parse(template) + template_vars = meta.find_undeclared_variables(parsed_content) + undeclared |= template_vars - var_names + + # filename_template is rendered with reserved vars injected, so allow them. + if self.filename_template: + parsed_content = env.parse(self.filename_template) + template_vars = meta.find_undeclared_variables(parsed_content) + undeclared |= template_vars - var_names - reserved + + if undeclared: + logger.warning( + f"⚠️ Undeclared variables found for {self.name}: {undeclared}" + ) + return False + + try: + self.output_mode = ImageOutputMode(self.output_mode) + except ValueError: + logger.warning( + f"⚠️ Invalid output_mode for {self.name}: {self.output_mode}. " + f"Expected one of {[m.value for m in ImageOutputMode]}" + ) + return False + + return True + + +ProcessorRegistry.register_types("image_gen", ImageGenConfig, ImageGenOutputVar) diff --git a/src/mmirage/core/process/processors/image_gen/image_gen_processor.py b/src/mmirage/core/process/processors/image_gen/image_gen_processor.py new file mode 100644 index 0000000..391160a --- /dev/null +++ b/src/mmirage/core/process/processors/image_gen/image_gen_processor.py @@ -0,0 +1,425 @@ +"""Image generation processor implementation using pluggable backends.""" + +from __future__ import annotations + +import hashlib +import logging +import os +import re +import socket +import tempfile +import uuid +from typing import Any, Dict, List, Optional + +import jinja2 + +from mmirage.core.process.base import BaseProcessor, ProcessorRegistry +from mmirage.core.process.processors.image_gen.backends.base import ImageGenerationBackend +from mmirage.core.process.processors.image_gen.config import ( + ImageGenConfig, + ImageGenOutputVar, + ImageOutputMode, +) +from mmirage.core.process.variables import VariableEnvironment + +try: + from typing import override # Python 3.12+ +except ImportError: # pragma: no cover + from typing_extensions import override # type: ignore + + +logger = logging.getLogger(__name__) + +_SAFE_FILENAME_RE = re.compile(r"[^A-Za-z0-9._-]+") + + +def _sanitize_filename(filename: str) -> str: + """Return a filesystem-safe filename stem.""" + normalized = _SAFE_FILENAME_RE.sub("_", filename).strip("._") + return normalized or "image" + + +def _create_backend(config: ImageGenConfig) -> ImageGenerationBackend: + """Instantiate the configured image generation backend.""" + if config.backend == "diffusers": + from mmirage.core.process.processors.image_gen.backends.diffusers_backend import ( + DiffusersImageBackend, + ) + return DiffusersImageBackend(config.pipeline_args) + + if config.backend == "sglang": + from mmirage.core.process.processors.image_gen.backends.sglang_backend import ( + SGLangImageBackend, + ) + assert config.sglang is not None # validated in __post_init__ + sglang = config.sglang + if sglang.launch_mode == "managed": + return SGLangImageBackend.from_managed_config( + model_path=sglang.model_path, # type: ignore[arg-type] # validated non-None + port=sglang.port, + num_gpus=sglang.num_gpus, + dtype=sglang.dtype, + api_key=sglang.api_key, + timeout_seconds=sglang.timeout_seconds, + startup_timeout_seconds=sglang.startup_timeout_seconds, + extra_server_args=sglang.extra_server_args, + ) + # launch_mode == "external" + return SGLangImageBackend( + base_url=sglang.base_url, + api_key=sglang.api_key, + timeout_seconds=sglang.timeout_seconds, + model_path=sglang.model_path, + ) + + raise ValueError(f"Unknown image_gen backend={config.backend!r}") + + +@ProcessorRegistry.register("image_gen", ImageGenConfig, ImageGenOutputVar) +class ImageGenProcessor(BaseProcessor[ImageGenOutputVar]): + """Processor that generates images from prompts using a pluggable backend. + + Supported backends: ``diffusers`` (in-process Diffusers pipeline) and + ``sglang`` (local SGLang Diffusion server over HTTP). + + Responsibilities of this processor: + - Render Jinja2 prompt and filename templates. + - Compute deterministic, shard-aware seeds. + - Chunk batches and call the backend's ``generate_batch`` method. + - Fall back to per-sample sequential generation if a batch chunk fails. + - Save images atomically to disk (``output_mode="path"``) or pass PIL + images through directly (``output_mode="pil"``). + """ + + def __init__(self, config: ImageGenConfig, shard_id: int = 0, **kwargs) -> None: + super().__init__(config, shard_id=shard_id, **kwargs) + + self._backend: ImageGenerationBackend = _create_backend(config) + self._default_sampling_params = dict(config.default_sampling_params) + self._parallel_inference = config.parallel_inference + self._parallel_chunk_size = config.parallel_chunk_size + + self._output_dir = config.get_output_dir() + self._file_format = config.file_format.lower() + os.makedirs(self._output_dir, exist_ok=True) + + self._shard_id = shard_id + # Counts the total number of samples processed by this instance. + # Used to derive shard-local sample indices for filenames and seeds. + self._sample_counter = 0 + run_token = uuid.uuid4().hex[:8] + self._run_id = f"{socket.gethostname()}.{os.getpid()}.{run_token}" + + # ------------------------------------------------------------------ + # Seed and param helpers + # ------------------------------------------------------------------ + + def _compute_seeds( + self, + base_seed: int, + batch_offset: int, + count: int, + ) -> List[int]: + """Compute per-sample deterministic seeds that are unique across shards. + + The seed for a sample is: + + ``base_seed + shard_id * 1_000_000_000 + sample_counter + batch_offset + i`` + + This guarantees that different shards with the same ``base_seed`` + produce different images even when their local sample indices overlap. + + Args: + base_seed: The ``seed`` value from the output variable config. + batch_offset: Position of the first sample in this call within the + current mapper batch (0 for the first chunk). + count: Number of seeds to produce. + """ + base = base_seed + self._shard_id * 1_000_000_000 + self._sample_counter + batch_offset + return [base + i for i in range(count)] + + def _build_params(self, output_var: ImageGenOutputVar) -> Dict[str, Any]: + """Build the generation kwargs dict from config defaults and per-var overrides.""" + params = dict(self._default_sampling_params) + if output_var.width is not None: + params["width"] = output_var.width + if output_var.height is not None: + params["height"] = output_var.height + if output_var.num_inference_steps is not None: + params["num_inference_steps"] = output_var.num_inference_steps + if output_var.guidance_scale is not None: + params["guidance_scale"] = output_var.guidance_scale + return params + + # ------------------------------------------------------------------ + # Filename and file saving helpers + # ------------------------------------------------------------------ + + @staticmethod + def _compute_source_hash(env: VariableEnvironment) -> str: + """Return an 8-character SHA-256 hex digest of all input variable values.""" + payload = str(sorted(env.to_dict().items())) + return hashlib.sha256(payload.encode()).hexdigest()[:8] + + def _render_filename( + self, + filename_template: jinja2.Template, + output_var: ImageGenOutputVar, + env: VariableEnvironment, + sample_index: int, + ) -> str: + """Render the output filename stem and return ``stem.ext``. + + ``sample_index`` is the shard-local index (``self._sample_counter`` + + position within the current batch). Combine with ``__shard_id`` in + the template for globally unique filenames. + """ + context = dict(env.to_dict()) + context["__sample_index"] = sample_index + context["__output_name"] = output_var.name + context["__shard_id"] = self._shard_id + context["__source_hash"] = self._compute_source_hash(env) + stem = _sanitize_filename(filename_template.render(**context)) + return f"{stem}.{self._file_format}" + + def _save_image(self, image: Any, filename: str) -> str: + """Persist a PIL image atomically and return the absolute path.""" + stem, ext = os.path.splitext(filename) + path = os.path.join(self._output_dir, filename) + if os.path.exists(path): + path = os.path.join(self._output_dir, f"{stem}.{self._run_id}{ext}") + + tmp_fd, tmp_path = tempfile.mkstemp(dir=self._output_dir, suffix=ext) + try: + os.close(tmp_fd) + image.save(tmp_path) + os.replace(tmp_path, path) + except Exception: + try: + os.unlink(tmp_path) + except OSError as cleanup_err: + logger.warning("Failed to clean up temp file %r: %s", tmp_path, cleanup_err) + raise + return path + + # ------------------------------------------------------------------ + # Chunk-level generation + # ------------------------------------------------------------------ + + def _collect_results( + self, + chunk: List[VariableEnvironment], + images: List[Any], + output_var: ImageGenOutputVar, + filename_template: jinja2.Template, + batch_offset: int, + ) -> List[VariableEnvironment]: + """Map backend images back to updated VariableEnvironments.""" + updated: List[VariableEnvironment] = [] + for i, (env, image) in enumerate(zip(chunk, images)): + sample_index = self._sample_counter + batch_offset + i + if output_var.output_mode == ImageOutputMode.PIL: + value = image + else: + filename = self._render_filename(filename_template, output_var, env, sample_index) + value = self._save_image(image, filename) + updated.append(env.with_variable(output_var.name, value, is_image=True)) + return updated + + def _generate_chunk_batch( + self, + chunk: List[VariableEnvironment], + output_var: ImageGenOutputVar, + prompt_template: jinja2.Template, + negative_prompt_template: Optional[jinja2.Template], + filename_template: jinja2.Template, + batch_offset: int, + ) -> List[VariableEnvironment]: + """Generate an entire chunk with a single batched backend call.""" + prompts = [prompt_template.render(**env.to_dict()) for env in chunk] + neg_prompts: Optional[List[Optional[str]]] = ( + [negative_prompt_template.render(**env.to_dict()) for env in chunk] + if negative_prompt_template is not None + else None + ) + seeds: List[Optional[int]] = ( + self._compute_seeds(int(output_var.seed), batch_offset, len(chunk)) + if output_var.seed is not None + else [None] * len(chunk) + ) + params = self._build_params(output_var) + images = self._backend.generate_batch(prompts, neg_prompts, params, seeds) + return self._collect_results(chunk, images, output_var, filename_template, batch_offset) + + def _generate_chunk_sequential( + self, + chunk: List[VariableEnvironment], + output_var: ImageGenOutputVar, + prompt_template: jinja2.Template, + negative_prompt_template: Optional[jinja2.Template], + filename_template: jinja2.Template, + batch_offset: int, + ) -> List[VariableEnvironment]: + """Generate samples one-by-one, tolerating per-sample failures.""" + updated: List[VariableEnvironment] = [] + params = self._build_params(output_var) + + for i, env in enumerate(chunk): + sample_index = self._sample_counter + batch_offset + i + try: + prompt = prompt_template.render(**env.to_dict()) + neg: Optional[str] = ( + negative_prompt_template.render(**env.to_dict()) + if negative_prompt_template is not None + else None + ) + seed_val: Optional[int] = ( + self._compute_seeds(int(output_var.seed), batch_offset + i, 1)[0] + if output_var.seed is not None + else None + ) + images = self._backend.generate_batch( + [prompt], + [neg] if neg is not None else None, + params, + [seed_val], + ) + if len(images) != 1: + raise RuntimeError( + f"Expected 1 image from backend in sequential mode, got {len(images)}" + ) + image = images[0] + if output_var.output_mode == ImageOutputMode.PIL: + value = image + else: + filename = self._render_filename(filename_template, output_var, env, sample_index) + value = self._save_image(image, filename) + updated.append(env.with_variable(output_var.name, value, is_image=True)) + except Exception as exc: + logger.error( + "Image generation failed for output '%s' at sample %d: %s", + output_var.name, + sample_index, + exc, + ) + updated.append(env.with_variable(output_var.name, None, is_image=True)) + + return updated + + # ------------------------------------------------------------------ + # Batch-level orchestration + # ------------------------------------------------------------------ + + def _batch_process_parallel( + self, + batch: List[VariableEnvironment], + output_var: ImageGenOutputVar, + prompt_template: jinja2.Template, + negative_prompt_template: Optional[jinja2.Template], + filename_template: jinja2.Template, + ) -> List[VariableEnvironment]: + """Process the full mapper batch in chunks with per-chunk fallback. + + For each chunk the processor tries a single batched backend call. + If that call fails only the failing chunk is retried sample-by-sample; + already-successful chunks are never re-generated. + """ + chunk_size = self._parallel_chunk_size or len(batch) + updated: List[VariableEnvironment] = [] + + for batch_offset in range(0, len(batch), chunk_size): + chunk = batch[batch_offset : batch_offset + chunk_size] + try: + updated.extend( + self._generate_chunk_batch( + chunk, + output_var, + prompt_template, + negative_prompt_template, + filename_template, + batch_offset, + ) + ) + except Exception as exc: + logger.warning( + "Batch generation failed for chunk at offset %d " + "(samples %d–%d); falling back to sequential for this chunk. " + "Reason: %s", + batch_offset, + self._sample_counter + batch_offset, + self._sample_counter + batch_offset + len(chunk) - 1, + exc, + ) + updated.extend( + self._generate_chunk_sequential( + chunk, + output_var, + prompt_template, + negative_prompt_template, + filename_template, + batch_offset, + ) + ) + + self._sample_counter += len(batch) + return updated + + def _batch_process_sequential( + self, + batch: List[VariableEnvironment], + output_var: ImageGenOutputVar, + prompt_template: jinja2.Template, + negative_prompt_template: Optional[jinja2.Template], + filename_template: jinja2.Template, + ) -> List[VariableEnvironment]: + """Process all samples one by one (used when parallel_inference=False).""" + result = self._generate_chunk_sequential( + batch, + output_var, + prompt_template, + negative_prompt_template, + filename_template, + batch_offset=0, + ) + self._sample_counter += len(batch) + return result + + # ------------------------------------------------------------------ + # Public processor interface + # ------------------------------------------------------------------ + + @override + def batch_process_sample( + self, batch: List[VariableEnvironment], output_var: ImageGenOutputVar + ) -> List[VariableEnvironment]: + """Generate images for each sample in the batch.""" + prompt_template = jinja2.Template(output_var.prompt) + negative_prompt_template = ( + jinja2.Template(output_var.negative_prompt) + if output_var.negative_prompt + else None + ) + filename_template = jinja2.Template(output_var.filename_template) + + if self._parallel_inference and len(batch) > 1: + return self._batch_process_parallel( + batch, + output_var, + prompt_template, + negative_prompt_template, + filename_template, + ) + + return self._batch_process_sequential( + batch, + output_var, + prompt_template, + negative_prompt_template, + filename_template, + ) + + @override + def shutdown(self) -> None: + """Release backend resources (GPU memory, HTTP connections, …).""" + self._backend.shutdown() + diff --git a/src/mmirage/shard_process.py b/src/mmirage/shard_process.py index 66e8529..098374b 100644 --- a/src/mmirage/shard_process.py +++ b/src/mmirage/shard_process.py @@ -9,10 +9,13 @@ import traceback from typing import Any, Dict, List +from datasets import DatasetDict, Image as HFImage + from mmirage.config.utils import load_mmirage_config from mmirage.core.loader.base import DatasetLike from mmirage.core.loader.utils import load_datasets_from_configs from mmirage.core.process.mapper import MMIRAGEMapper +from mmirage.core.process.variables import OutputVar from mmirage.core.writer.renderer import TemplateRenderer from mmirage.shard_utils import ( _cleanup_old_shard_data, @@ -30,6 +33,59 @@ logger = logging.getLogger(__name__) +def _image_path_schema_cols( + output_vars: List[OutputVar], + output_schema: Dict[str, Any], + renderer: TemplateRenderer, +) -> List[str]: + """Return output-schema column names that map directly to image-path output variables. + + Uses duck typing on ``output_mode`` so no concrete processor import is needed. + """ + image_path_var_names = { + v.name + for v in output_vars + if getattr(v, "output_mode", None) == "path" + } + return [ + key + for key, tmpl in output_schema.items() + if isinstance(tmpl, str) + and renderer.is_single_variable_template(tmpl) in image_path_var_names + ] + + +def _cast_image_columns(ds: DatasetLike, cols: List[str]) -> DatasetLike: + """Cast image-path string columns to the HuggingFace Image feature. + + Empty strings (failure fallbacks) are normalised to ``None`` so that + HuggingFace stores them as missing rather than raising a decode error. + When ``save_to_disk`` is called, HuggingFace reads each path from disk + and embeds the raw bytes in the Arrow file, making the shard portable. + """ + def _normalise_col(batch: Dict[str, Any], col: str) -> Dict[str, Any]: + return {col: [v if v else None for v in batch[col]]} + + if isinstance(ds, DatasetDict): + for col in cols: + for split in list(ds.keys()): + if col in ds[split].column_names: + ds[split] = ds[split].map( + _normalise_col, batched=True, fn_kwargs={"col": col}, desc=f"Normalising {col}", + load_from_cache_file=False, + ) + ds[split] = ds[split].cast_column(col, HFImage()) + else: + for col in cols: + if col in ds.column_names: + ds = ds.map( + _normalise_col, batched=True, fn_kwargs={"col": col}, desc=f"Normalising {col}", + load_from_cache_file=False, + ) + ds = ds.cast_column(col, HFImage()) + return ds + + def rewrite_batch( batch: Dict[str, List[Any]], mapper: MMIRAGEMapper, @@ -112,44 +168,60 @@ def main(): cfg.processors, processing_params.inputs, processing_params.outputs, + shard_id=shard_id, ) renderer = TemplateRenderer(processing_params.output_schema) - ds_processed_all: List[DatasetLike] = [] - for ds_idx, ds_shard in enumerate(ds_all_shard): - ds_config = datasets_config[ds_idx] - if processing_params.remove_columns: - remove_columns = _remove_columns(ds_shard) - else: - remove_columns = [] - - logger.info( - f"Processing dataset {ds_idx} for shard {shard_id}: " - f"path={ds_config.path}, output_dir={ds_config.output_dir}" - ) - - ds_processed = ds_shard.map( - rewrite_batch, - batched=True, - batch_size=loading_params.get_batch_size(), - load_from_cache_file=False, - desc=f"Shard {shard_id}/{last_shard_id} dataset {ds_idx}", - fn_kwargs={ - "mapper": mapper, - "renderer": renderer, - "image_base_path": ds_config.image_base_path, - }, - remove_columns=remove_columns, - ) - ds_processed_all.append(ds_processed) - - for ds_idx, (ds_config, ds_processed) in enumerate(zip(datasets_config, ds_processed_all)): - out_dir = _dataset_out_dir(shard_id, ds_config) - _save_dataset_atomic(ds_processed, out_dir) - logger.info(f"✅ Saved dataset {ds_idx} shard in: {out_dir}") - - _mark_success(state_dir) - logger.info(f"✅ Logical shard {shard_id} completed successfully") + try: + ds_processed_all: List[DatasetLike] = [] + for ds_idx, ds_shard in enumerate(ds_all_shard): + ds_config = datasets_config[ds_idx] + if processing_params.remove_columns: + remove_columns = _remove_columns(ds_shard) + else: + remove_columns = [] + + logger.info( + f"Processing dataset {ds_idx} for shard {shard_id}: " + f"path={ds_config.path}, output_dir={ds_config.output_dir}" + ) + + ds_processed = ds_shard.map( + rewrite_batch, + batched=True, + batch_size=loading_params.get_batch_size(), + load_from_cache_file=False, + desc=f"Shard {shard_id}/{last_shard_id} dataset {ds_idx}", + fn_kwargs={ + "mapper": mapper, + "renderer": renderer, + "image_base_path": ds_config.image_base_path, + }, + remove_columns=remove_columns, + ) + + image_cols = _image_path_schema_cols( + processing_params.outputs, + processing_params.output_schema, + renderer, + ) + if image_cols: + ds_processed = _cast_image_columns(ds_processed, image_cols) + logger.info(f"Cast image column(s) to HF Image feature: {image_cols}") + + ds_processed_all.append(ds_processed) + + for ds_idx, (ds_config, ds_processed) in enumerate(zip(datasets_config, ds_processed_all)): + out_dir = _dataset_out_dir(shard_id, ds_config) + _save_dataset_atomic(ds_processed, out_dir) + logger.info(f"✅ Saved dataset {ds_idx} shard in: {out_dir}") + + _mark_success(state_dir) + logger.info(f"✅ Logical shard {shard_id} completed successfully") + + finally: + mapper.shutdown() + logger.info("Processors shut down.") except Exception as e: error_msg = f"{type(e).__name__}: {str(e)}" diff --git a/tests/mock_data_image_gen/data.jsonl b/tests/mock_data_image_gen/data.jsonl new file mode 100644 index 0000000..ac7a365 --- /dev/null +++ b/tests/mock_data_image_gen/data.jsonl @@ -0,0 +1,10 @@ +{"text": "a cat sitting on a windowsill"} +{"text": "a dog running in a park"} +{"text": "a red fox in a snowy forest"} +{"text": "a colorful parrot on a tree branch"} +{"text": "a panda eating bamboo"} +{"text": "a horse galloping across a field"} +{"text": "a lion resting under an acacia tree"} +{"text": "a dolphin jumping out of the ocean"} +{"text": "a rabbit in a flower garden"} +{"text": "a brown bear near a mountain lake"}