-
Notifications
You must be signed in to change notification settings - Fork 3
Feature/generation model #41
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 9 commits
09d96bf
9bc4ada
48a3e39
e374be4
fb6454b
e3b1a82
c3f85ae
b840515
3368040
534d8c6
89a81a3
c46a01d
763f756
544ca0c
8dfa161
e0481ae
3205ece
691f7d7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,53 @@ | ||
| processors: | ||
| - type: image_gen | ||
| pipeline_args: | ||
| model_path: stable-diffusion-v1-5/stable-diffusion-v1-5 | ||
| torch_dtype: float16 | ||
| device: auto | ||
| enable_attention_slicing: true | ||
| default_sampling_params: | ||
| num_inference_steps: 20 | ||
| guidance_scale: 7.5 | ||
| parallel_inference: true | ||
| parallel_chunk_size: 4 | ||
| output_dir: tests/output/image_gen/generated_images | ||
| file_format: png | ||
|
|
||
| loading_params: | ||
| state_dir: tests/output/image_gen/_pipeline_state | ||
| datasets: | ||
| - path: tests/mock_data_image_gen/data.jsonl | ||
| type: JSONL | ||
| output_dir: tests/output/image_gen | ||
|
|
||
| num_shards: 1 | ||
| shard_id: 0 | ||
| batch_size: 4 | ||
|
|
||
| processing_params: | ||
| inputs: | ||
| - name: text | ||
| key: text | ||
|
|
||
| outputs: | ||
| - name: generated_image | ||
| type: image_gen | ||
| output_mode: path | ||
| filename_template: "generated_{{ __shard_id }}_{{ __sample_index }}_{{ __source_hash }}" | ||
| width: 512 | ||
| height: 512 | ||
| prompt: | | ||
| Create a clean and detailed illustration for: | ||
| {{ text }} | ||
|
|
||
| remove_columns: false | ||
| output_schema: | ||
| prompt_source: "{{ text }}" | ||
| image: "{{ generated_image }}" | ||
|
|
||
| execution_params: | ||
| mode: local | ||
| retry: false | ||
| merge: false | ||
| report_dir: ~/reports | ||
| hf_home: ~/hf | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,13 +29,15 @@ def __init__( | |
| processor_configs: List[BaseProcessorConfig], | ||
| input_vars: List[InputVar], | ||
| output_vars: List[OutputVar], | ||
| shard_id: int = 0, | ||
| ) -> None: | ||
| """Initialize the MMIRAGE mapper. | ||
|
|
||
| Args: | ||
| processor_configs: List of processor configurations. | ||
| input_vars: List of input variable definitions. | ||
| output_vars: List of output variable definitions. | ||
| shard_id: Shard index for this worker, forwarded to processors. | ||
| """ | ||
| self.processors: Dict[str, BaseProcessor] = dict() | ||
| self.input_vars = input_vars | ||
|
|
@@ -45,7 +47,7 @@ def __init__( | |
| processor_cls = AutoProcessor.from_name(config.type) | ||
| logger.info(f"✅ Successfully loaded processor of type {config.type}") | ||
|
|
||
| self.processors[config.type] = processor_cls(config) | ||
| self.processors[config.type] = processor_cls(config, shard_id=shard_id) | ||
|
qchapp marked this conversation as resolved.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The shard_id is currently ignored by LLMProcessor; maybe make it use it as well? It seems to be used only for computing the render filename. |
||
|
|
||
| def validate_vars(self) -> bool: | ||
| """Validate that all output variables are computable. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| """Image generation processor implementation. | ||
|
|
||
| This module provides a dedicated processor for text-to-image generation tasks | ||
| using Diffusers pipelines. It can emit either saved image paths or in-memory | ||
| PIL images. | ||
|
Comment on lines
+4
to
+5
|
||
| """ | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,131 @@ | ||
| """Configuration for image generation processor in MMIRAGE.""" | ||
|
|
||
| from dataclasses import dataclass, field | ||
|
|
||
| import logging | ||
| import os | ||
| from typing import Any, Dict, List, Literal, Optional, Sequence, TypeAlias | ||
| from jinja2 import Environment, meta | ||
|
|
||
| from mmirage.core.process.base import BaseProcessorConfig | ||
| from mmirage.core.process.base import ProcessorRegistry | ||
| from mmirage.core.process.variables import BaseVar, OutputVar | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
| env = Environment() | ||
|
|
||
| ImageOutputMode: TypeAlias = Literal["path", "pil"] | ||
|
|
||
|
|
||
@dataclass
class DiffusersPipelineArgs:
    """Settings used when constructing a Diffusers pipeline.

    Attributes:
        model_path: Hugging Face model id, or a path to a local model.
        revision: Model revision to load; ``None`` leaves it unspecified.
        torch_dtype: Name of the torch dtype, given as a string. Typical
            values are "float16", "bfloat16", "float32" and "auto".
        device: Where to run the pipeline: "auto", "cuda", "cpu", or any
            explicit device string.
        enable_attention_slicing: Turn on attention slicing (when the
            pipeline supports it) to lower VRAM usage.
    """

    # Field order is part of the positional constructor signature; keep it.
    model_path: str = "stable-diffusion-v1-5/stable-diffusion-v1-5"
    revision: Optional[str] = None
    torch_dtype: str = "float16"
    device: str = "auto"
    enable_attention_slicing: bool = True
|
|
||
|
|
||
@dataclass
class DiffusersImageGenConfig(BaseProcessorConfig):
    """Configuration for the image generation processor.

    Attributes:
        pipeline_args: Arguments used to initialize the Diffusers pipeline.
        default_sampling_params: Default generation kwargs passed to pipeline
            calls.
        parallel_inference: If True, process batch samples in parallel via a
            single batched pipeline call.
        parallel_chunk_size: Optional chunk size for batched calls. If None
            or <= 0, the full mapper batch size is used. A non-positive value
            is normalized to None (with a warning) at construction time.
        output_dir: Directory where generated images are written when
            output_mode is "path".
        file_format: Image file format for saved outputs.
    """

    pipeline_args: DiffusersPipelineArgs = field(default_factory=DiffusersPipelineArgs)
    default_sampling_params: Dict[str, Any] = field(default_factory=dict)
    parallel_inference: bool = True
    parallel_chunk_size: Optional[int] = 4
    # The default is a relative path; get_output_dir() resolves it against the
    # current working directory.
    # NOTE(review): a reviewer asked whether this creates a new ``.mmirage``
    # folder at the root of the local repository -- confirm the intended
    # default location.
    output_dir: str = ".mmirage/generated_images"
    file_format: str = "png"

    def __post_init__(self) -> None:
        """Normalize optional parallelism settings.

        A non-positive ``parallel_chunk_size`` is replaced by ``None``. The
        fallback is kept for backward compatibility, but it is now logged so a
        misconfigured value does not pass unnoticed.
        """
        # NOTE(review): a reviewer suggested raising an error here instead of
        # interpreting a non-positive value as None. Warning keeps existing
        # configs loadable while still surfacing the problem.
        if self.parallel_chunk_size is not None and self.parallel_chunk_size <= 0:
            logger.warning(
                "parallel_chunk_size=%s is not positive; "
                "falling back to the full mapper batch size",
                self.parallel_chunk_size,
            )
            self.parallel_chunk_size = None

    def get_output_dir(self) -> str:
        """Get normalized absolute output directory path.

        Returns:
            ``output_dir`` with a leading ``~`` expanded and made absolute.
        """
        # NOTE(review): a reviewer asked why outputs are not placed in the
        # cache folder instead -- open question, behavior unchanged here.
        return os.path.abspath(os.path.expanduser(self.output_dir))
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why not in the cache folder?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. like |
||
|
|
||
|
|
||
@dataclass
class ImageGenOutputVar(OutputVar):
    """Output variable produced by the image generation processor.

    Attributes:
        prompt: Jinja2 template rendered as the positive prompt.
        negative_prompt: Optional Jinja2 template rendered as the negative
            prompt.
        output_mode: How the result is represented: "path" (default) or
            "pil".
        filename_template: Optional Jinja2 template for the filename stem of
            a saved image. Internal variables available to it:
            __sample_index (shard-global row index), __output_name,
            __shard_id, __source_hash (8-char SHA-256 of input values).
            Every input variable (e.g. ``text``) can be used as well.
        width: Optional override of the image width.
        height: Optional override of the image height.
        num_inference_steps: Optional override of the sampling step count.
        guidance_scale: Optional override of the guidance scale.
        seed: Optional deterministic seed. If set, sample index is added for
            uniqueness.
    """

    prompt: str = ""
    negative_prompt: str = ""
    output_mode: ImageOutputMode = "path"
    filename_template: str = "generated_{{ __shard_id }}_{{ __sample_index }}_{{ __source_hash }}"
    width: Optional[int] = None
    height: Optional[int] = None
    num_inference_steps: Optional[int] = None
    guidance_scale: Optional[float] = None
    seed: Optional[int] = None

    def is_computable(self, vars: Sequence[BaseVar]) -> bool:
        """Tell whether every template variable can be resolved.

        Args:
            vars: Variables available when this output is computed.

        Returns:
            False (after logging a warning) when a template references a name
            that is neither in ``vars`` nor one of the reserved internal
            names, or when ``output_mode`` is not "path" or "pil"; True
            otherwise.
        """
        reserved = {"__sample_index", "__output_name", "__shard_id", "__source_hash"}
        var_names = {v.name for v in vars}

        # The prompt is always inspected; the two optional templates only
        # when they are non-empty.
        optional_templates = (self.negative_prompt, self.filename_template)
        templates: List[str] = [self.prompt] + [t for t in optional_templates if t]

        undeclared: set[str] = set()
        for source in templates:
            referenced = meta.find_undeclared_variables(env.parse(source))
            undeclared |= referenced - var_names - reserved

        if undeclared:
            logger.warning(
                f"⚠️ Undeclared variables found for {self.name}: {undeclared}"
            )
            return False

        # The Literal annotation is not enforced at runtime, so verify here.
        if self.output_mode not in {"path", "pil"}:
            logger.warning(
                f"⚠️ Invalid output_mode for {self.name}: {self.output_mode}. Expected one of ['path', 'pil']"
            )
            return False

        return True
|
|
||
|
|
||
| ProcessorRegistry.register_types("image_gen", DiffusersImageGenConfig, ImageGenOutputVar) | ||
Uh oh!
There was an error while loading. Please reload this page.