Skip to content
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 62 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ For testing and scripts that make use of the library, it is advised to create a
- Image inputs for multimodal processing
- Parallelizable with multi-node support
- The training pipeline uses distributed inference with sharding
- Support a variety of LLMs and VLMs (Vision-Language Models)
- Support a variety of LLMs, VLMs (Vision-Language Models), and image generation models
- Support any dataset schemas (configurable with the YAML format)
- The ability to either output a JSON (or any other structured format) or plain text
- Modular architecture with pluggable processors, loaders, and writers
Expand Down Expand Up @@ -123,7 +123,7 @@ execution_params:

Configuration explanation:

- `processors`: List of processor configurations. Currently supports `llm` type for LLM-based generation.
- `processors`: List of processor configurations. Supports `llm` (text/VLM generation) and `image_gen` (text-to-image generation).
- `loading_params`: Parameters for loading and sharding datasets.
- `state_dir`: Optional shared directory for shard status/retry state. Defaults to `~/.cache/MMIRAGE/state_dir`.
- `datasets`: List of dataset configurations with path, type, and output directory.
Expand Down Expand Up @@ -191,6 +191,66 @@ execution_params:
retry: false
```

### Image generation: Text-to-image pipeline

MMIRAGE also supports image generation with Diffusers models:

```yaml
processors:
- type: image_gen
pipeline_args:
model_path: stable-diffusion-v1-5/stable-diffusion-v1-5
torch_dtype: float16
device: auto
enable_attention_slicing: true
default_sampling_params:
num_inference_steps: 20
guidance_scale: 7.5
output_dir: /path/to/generated/images
file_format: png

loading_params:
state_dir: /path/to/state/dir
datasets:
- path: /path/to/prompts.jsonl
type: JSONL
output_dir: /path/to/output/shards
num_shards: 1
shard_id: 0
batch_size: 8

processing_params:
inputs:
- name: prompt_text
key: text

outputs:
- name: generated_image
type: image_gen
output_mode: path # "path" or "pil"
filename_template: "generated_{{ __shard_id }}_{{ __sample_index }}_{{ __source_hash }}"
width: 512
height: 512
prompt: |
Create an illustration of:
{{ prompt_text }}

remove_columns: false
output_schema:
text: "{{ prompt_text }}"
image: "{{ generated_image }}"

execution_params:
mode: local
retry: false
```

Install optional image generation dependencies before running this config:

```bash
pip install -e .[image_gen]
```

Key multimodal features:
- `chat_template`: Specify the VLM chat template (e.g., `qwen2-vl`)
- `type: image`: Mark input variables as images
Expand Down
53 changes: 53 additions & 0 deletions configs/config_mock_image_gen.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
processors:
- type: image_gen
pipeline_args:
model_path: stable-diffusion-v1-5/stable-diffusion-v1-5
torch_dtype: float16
device: auto
enable_attention_slicing: true
default_sampling_params:
num_inference_steps: 20
guidance_scale: 7.5
parallel_inference: true
parallel_chunk_size: 4
output_dir: tests/output/image_gen/generated_images
file_format: png

loading_params:
state_dir: tests/output/image_gen/_pipeline_state
datasets:
- path: tests/mock_data_image_gen/data.jsonl
type: JSONL
output_dir: tests/output/image_gen

num_shards: 1
shard_id: 0
batch_size: 4

processing_params:
inputs:
- name: text
key: text

outputs:
- name: generated_image
type: image_gen
output_mode: path
filename_template: "generated_{{ __shard_id }}_{{ __sample_index }}_{{ __source_hash }}"
width: 512
height: 512
prompt: |
Create a clean and detailed illustration for:
{{ text }}

remove_columns: false
output_schema:
prompt_source: "{{ text }}"
image: "{{ generated_image }}"

execution_params:
mode: local
retry: false
merge: false
report_dir: ~/reports
hf_home: ~/hf
Comment thread
qchapp marked this conversation as resolved.
Outdated
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ dev = [
"ipykernel",
"pytest",
]
image_gen = [
"diffusers>=0.31.0",
"accelerate>=0.33.0",
"safetensors>=0.4.4",
]

[project.scripts]
mmirage = "mmirage.cli:main"
Expand Down
1 change: 1 addition & 0 deletions src/mmirage/config/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# to construct config/output-var objects from YAML without importing heavy
# processor implementations (e.g. torch/transformers).
import mmirage.core.process.processors.llm.config # noqa: F401
import mmirage.core.process.processors.image_gen.config # noqa: F401
import mmirage.core.loader.jsonl # noqa: F401
import mmirage.core.loader.local_hf # noqa: F401

Expand Down
9 changes: 7 additions & 2 deletions src/mmirage/core/process/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,13 @@ class BaseProcessor(abc.ABC, Generic[C]):
config: Configuration object for this processor.
"""

def __init__(self, config: BaseProcessorConfig) -> None:
def __init__(self, config: BaseProcessorConfig, **kwargs) -> None:
"""Initialize the processor with configuration.

Args:
config: Configuration object for this processor.
**kwargs: Ignored; allows subclasses to forward unknown keyword
arguments (e.g. ``shard_id``) without raising TypeError.
"""
super().__init__()
self.config = config
Comment thread
qchapp marked this conversation as resolved.
Outdated
Expand Down Expand Up @@ -84,7 +86,10 @@ class ProcessorRegistry:
# Import processor implementations lazily because they may depend on heavy
# libraries (torch/transformers). Config/output-var types are registered via
# mmirage.config.utils importing the relevant config modules.
_lazy_processor_imports = {"llm": "mmirage.core.process.processors.llm.llm_processor"}
_lazy_processor_imports = {
"llm": "mmirage.core.process.processors.llm.llm_processor",
"image_gen": "mmirage.core.process.processors.image_gen.image_gen_processor",
}

@classmethod
def register_types(
Expand Down
4 changes: 3 additions & 1 deletion src/mmirage/core/process/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,15 @@ def __init__(
processor_configs: List[BaseProcessorConfig],
input_vars: List[InputVar],
output_vars: List[OutputVar],
shard_id: int = 0,
) -> None:
"""Initialize the MMIRAGE mapper.

Args:
processor_configs: List of processor configurations.
input_vars: List of input variable definitions.
output_vars: List of output variable definitions.
shard_id: Shard index for this worker, forwarded to processors.
"""
self.processors: Dict[str, BaseProcessor] = dict()
self.input_vars = input_vars
Expand All @@ -45,7 +47,7 @@ def __init__(
processor_cls = AutoProcessor.from_name(config.type)
logger.info(f"✅ Successfully loaded processor of type {config.type}")

self.processors[config.type] = processor_cls(config)
self.processors[config.type] = processor_cls(config, shard_id=shard_id)
Comment thread
qchapp marked this conversation as resolved.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the shard_id is currently ignored by LLMProcessor, maybe make it use it as well? it seems to be used only for computing the render filename


def validate_vars(self) -> bool:
"""Validate that all output variables are computable.
Expand Down
6 changes: 6 additions & 0 deletions src/mmirage/core/process/processors/image_gen/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Image generation processor implementation.

This module provides a dedicated processor for text-to-image generation tasks
using Diffusers pipelines. It can emit either saved image paths or in-memory
PIL images.
Comment on lines +4 to +5
"""
137 changes: 137 additions & 0 deletions src/mmirage/core/process/processors/image_gen/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
"""Configuration for image generation processor in MMIRAGE."""

from dataclasses import dataclass, field

import logging
import os
from typing import Any, Dict, List, Literal, Optional, Sequence, TypeAlias
from jinja2 import Environment, meta

from mmirage.core.process.base import BaseProcessorConfig
from mmirage.core.process.base import ProcessorRegistry
from mmirage.core.process.variables import BaseVar, OutputVar

logger = logging.getLogger(__name__)
env = Environment()

ImageOutputMode: TypeAlias = Literal["path", "pil"]


@dataclass
class DiffusersPipelineArgs:
"""Runtime arguments used to initialize a Diffusers pipeline.

Attributes:
model_path: Hugging Face model id or local path.
revision: Optional model revision.
torch_dtype: Torch dtype as string. Common values: "float16", "bfloat16", "float32", "auto".
device: Target device: "auto", "cuda", "cpu", or explicit device string.
enable_attention_slicing: Enable attention slicing when available to reduce VRAM usage.
"""

model_path: str = "stable-diffusion-v1-5/stable-diffusion-v1-5"
revision: Optional[str] = None
torch_dtype: str = "float16"
device: str = "auto"
enable_attention_slicing: bool = True


@dataclass
class DiffusersImageGenConfig(BaseProcessorConfig):
"""Configuration for image generation processor.

Attributes:
pipeline_args: Arguments used to initialize the Diffusers pipeline.
default_sampling_params: Default generation kwargs passed to pipeline calls.
parallel_inference: If True, process batch samples in parallel via a single batched pipeline call.
parallel_chunk_size: Optional chunk size for batched calls. If None or <= 0,
the full mapper batch size is used.
output_dir: Directory where generated images are written when output_mode is "path".
file_format: Image file format for saved outputs.
"""

pipeline_args: DiffusersPipelineArgs = field(default_factory=DiffusersPipelineArgs)
default_sampling_params: Dict[str, Any] = field(default_factory=dict)
parallel_inference: bool = True
parallel_chunk_size: Optional[int] = 4
output_dir: str = ".mmirage/generated_images"
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes a new folder .mmirage at the root of the local repository?

file_format: str = "png"

def __post_init__(self) -> None:
"""Validate optional parallelism settings."""
if self.parallel_chunk_size is not None and self.parallel_chunk_size <= 0:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it sounds better to raise an error here, it should not be silently interpreted as None when a value is nonpositive

self.parallel_chunk_size = None

def get_output_dir(self) -> str:
"""Get normalized absolute output directory path."""
return os.path.abspath(os.path.expanduser(self.output_dir))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not in the cache folder?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

like DEFAULT_STATE_DIR = "~/.cache/MMIRAGE/state_dir" in src/mmore/config/loading.py



@dataclass
class ImageGenOutputVar(OutputVar):
"""Output variable generated by image generation processor.

Attributes:
prompt: Jinja2 template used as positive prompt.
negative_prompt: Optional Jinja2 template used as negative prompt.
output_mode: Output representation: "path" (default) or "pil".
filename_template: Optional Jinja2 template used for saved image filename stem.
Supported internal variables: __sample_index (shard-global row index),
__output_name, __shard_id, __source_hash (8-char SHA-256 of input values).
All input variables (e.g. ``text``) are also available.
width: Optional image width override.
height: Optional image height override.
num_inference_steps: Optional sampling steps override.
guidance_scale: Optional guidance scale override.
seed: Optional deterministic seed. If set, sample index is added for uniqueness.
"""

prompt: str = ""
negative_prompt: str = ""
output_mode: ImageOutputMode = "path"
filename_template: str = "generated_{{ __shard_id }}_{{ __sample_index }}_{{ __source_hash }}"
width: Optional[int] = None
height: Optional[int] = None
num_inference_steps: Optional[int] = None
guidance_scale: Optional[float] = None
seed: Optional[int] = None

def is_computable(self, vars: Sequence[BaseVar]) -> bool:
"""Check if all variables referenced in templates are available."""
reserved = {"__sample_index", "__output_name", "__shard_id", "__source_hash"}
var_names = {v.name for v in vars}

# Prompt/negative_prompt are rendered from env.to_dict() only — reserved
# vars are not injected there, so treat them as undeclared in those templates.
prompt_templates: List[str] = [self.prompt]
if self.negative_prompt:
prompt_templates.append(self.negative_prompt)

undeclared: set[str] = set()
for template in prompt_templates:
parsed_content = env.parse(template)
template_vars = meta.find_undeclared_variables(parsed_content)
undeclared |= template_vars - var_names

# filename_template is rendered with reserved vars injected, so allow them.
if self.filename_template:
parsed_content = env.parse(self.filename_template)
template_vars = meta.find_undeclared_variables(parsed_content)
undeclared |= template_vars - var_names - reserved

Comment thread
qchapp marked this conversation as resolved.
if undeclared:
logger.warning(
f"⚠️ Undeclared variables found for {self.name}: {undeclared}"
)
return False

if self.output_mode not in {"path", "pil"}:
logger.warning(
f"⚠️ Invalid output_mode for {self.name}: {self.output_mode}. Expected one of ['path', 'pil']"
)
return False

return True


ProcessorRegistry.register_types("image_gen", DiffusersImageGenConfig, ImageGenOutputVar)
Loading
Loading