Skip to content

Commit e0481ae

Browse files
committed
new backends to test on cluster
1 parent 8dfa161 commit e0481ae

12 files changed

Lines changed: 1559 additions & 276 deletions
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# Local Diffusers image generation example
2+
# Runs an in-process Diffusers pipeline from a local model path.
3+
4+
processors:
5+
- type: image_gen
6+
backend: diffusers
7+
pipeline_args:
8+
model_path: stable-diffusion-v1-5/stable-diffusion-v1-5
9+
device: auto # auto: multi-GPU via device_map, single GPU, or CPU
10+
torch_dtype: float16
11+
enable_attention_slicing: false
12+
default_sampling_params:
13+
num_inference_steps: 30
14+
guidance_scale: 4.0
15+
parallel_inference: true
16+
parallel_chunk_size: 4 # samples per batched GPU call
17+
output_dir: /users/qchapp/meditron/MIRAGE/tests/output/image_gen/generated_images
18+
file_format: png
19+
20+
loading_params:
21+
state_dir: /users/qchapp/meditron/MIRAGE/tests/output/image_gen/_pipeline_state
22+
datasets:
23+
- path: /users/qchapp/meditron/MIRAGE/tests/mock_data_image_gen/data.jsonl
24+
type: JSONL
25+
output_dir: /users/qchapp/meditron/MIRAGE/tests/output/image_gen
26+
num_shards: 4 # set by Slurm array job
27+
shard_id: ${SLURM_ARRAY_TASK_ID} # set by Slurm array job
28+
batch_size: 64
29+
30+
processing_params:
31+
inputs:
32+
- name: text
33+
key: caption
34+
35+
outputs:
36+
- name: generated_image
37+
type: image_gen
38+
output_mode: path
39+
# __sample_index is shard-local; combine with __shard_id for global uniqueness
40+
filename_template: "img_{{ __shard_id }}_{{ __sample_index }}_{{ __source_hash }}"
41+
width: 1024
42+
height: 1024
43+
seed: 42 # shard-aware: effective seed = 42 + shard_id * 1_000_000_000 + sample_index
44+
prompt: |
45+
A photorealistic image of: {{ text }}
46+
47+
remove_columns: false
48+
output_schema:
49+
caption: "{{ text }}"
50+
image: "{{ generated_image }}"
51+
52+
execution_params:
53+
# Execution mode: "local" or "slurm"
54+
# - local: Run directly on this machine
55+
# - slurm: Submit jobs to SLURM cluster
56+
mode: slurm
57+
58+
# Whether the canonical `run` command should automatically retry failed shards.
59+
# - false: submit one run only
60+
# - true: submit, wait, and keep retrying failed shards until success or retry budget exhaustion
61+
retry: false
62+
63+
# Maximum number of times to retry a failed shard (default: 3)
64+
max_retries: 3
65+
66+
# ==========================================================================
67+
# SLURM CONFIGURATION (only used when mode: slurm)
68+
# ==========================================================================
69+
70+
# HPC account/partition to charge jobs to (REQUIRED for SLURM mode)
71+
account: a127
72+
73+
# SLURM job name (default: "mmirage-sharded")
74+
job_name: mmirage-sharded
75+
76+
# Optional SLURM reservation name (leave blank or omit to run without a reservation)
77+
# reservation: "sai-a127"
78+
79+
# Number of nodes (default: 1)
80+
nodes: 1
81+
82+
# Number of tasks per node (default: 1)
83+
ntasks_per_node: 1
84+
85+
# Number of GPUs per node (default: 4)
86+
gpus: 4
87+
88+
# Number of CPUs per task (default: 288)
89+
cpus_per_task: 288
90+
91+
# Job time limit in HH:MM:SS format (default: "11:59:59")
92+
time_limit: "11:59:59"
93+
94+
# ==========================================================================
95+
# PATH CONFIGURATION
96+
# ==========================================================================
97+
# These support environment variables ($VAR or ${VAR}) and home directory (~)
98+
99+
# Project root directory (used as base for relative paths)
100+
# If not set, uses current working directory
101+
# project_root: "/path/to/project"
102+
103+
# Directory for SLURM output and error files (default: ~/reports)
104+
report_dir: "/users/${USER}/reports"
105+
106+
# HuggingFace cache directory (default: ~/hf)
107+
hf_home: "/capstor/store/cscs/swissai/a127/homes/${USER}/hf"
108+
109+
# EDF environment file path for cluster-specific setup
110+
edf_env: "/users/${USER}/.edf/sglang.toml"
111+
112+
# ==========================================================================
113+
# JOB MONITORING (for "submit" and retry orchestration)
114+
# ==========================================================================
115+
116+
# Seconds to wait between checking job status (default: 30)
117+
poll_interval_seconds: 30
118+
119+
# Seconds to wait after job completes before checking results (default: 60)
120+
# This allows the filesystem to settle on distributed systems
121+
settle_time_seconds: 60
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
# SGLang Diffusion server image generation example
2+
#
3+
# Two launch modes are available:
4+
#
5+
# launch_mode: managed (recommended)
6+
# MMIRAGE automatically starts an SGLang server on each worker node,
7+
# waits until it is ready, runs the pipeline, and shuts it down afterwards.
8+
# No manual server management is needed.
9+
#
10+
# launch_mode: external
11+
# You are responsible for starting the SGLang server before the pipeline
12+
# runs. Use this if you want to reuse a long-running server or need
13+
# fine-grained control over server startup.
14+
#
15+
# MMIRAGE handles all dataset sharding, prompt rendering, filename rendering,
16+
# and result saving. The SGLang server is only responsible for generating
17+
# image pixels.
18+
19+
processors:
20+
- type: image_gen
21+
backend: sglang
22+
sglang:
23+
launch_mode: managed # MMIRAGE starts/stops the server automatically
24+
model_path: stable-diffusion-v1-5/stable-diffusion-v1-5
25+
port: 30010
26+
num_gpus: 4 # passed as --tp to sglang.launch_server
27+
# dtype: float16 # optional: --dtype flag
28+
startup_timeout_seconds: 180 # seconds to wait for the server to become ready
29+
# extra_server_args: # any additional --flags for sglang.launch_server
30+
# - "--mem-fraction-static"
31+
# - "0.9"
32+
api_key: EMPTY # unauthenticated local server
33+
timeout_seconds: 900
34+
default_sampling_params:
35+
num_inference_steps: 30
36+
guidance_scale: 4.0
37+
parallel_inference: true
38+
parallel_chunk_size: 4 # concurrent requests per chunk (sequential per sample inside the backend)
39+
output_dir: /users/qchapp/meditron/MIRAGE/tests/output/image_gen/generated_images
40+
file_format: png
41+
42+
loading_params:
43+
state_dir: /users/qchapp/meditron/MIRAGE/tests/output/image_gen/_pipeline_state
44+
datasets:
45+
- path: /users/qchapp/meditron/MIRAGE/tests/mock_data_image_gen/data.jsonl
46+
type: JSONL
47+
output_dir: /users/qchapp/meditron/MIRAGE/tests/output/image_gen
48+
num_shards: 4 # each Slurm task starts its own server on localhost
49+
shard_id: ${SLURM_ARRAY_TASK_ID}
50+
batch_size: 64
51+
52+
processing_params:
53+
inputs:
54+
- name: text
55+
key: caption
56+
57+
outputs:
58+
- name: generated_image
59+
type: image_gen
60+
output_mode: path
61+
filename_template: "img_{{ __shard_id }}_{{ __sample_index }}_{{ __source_hash }}"
62+
width: 1024
63+
height: 1024
64+
seed: 42 # shard-aware: effective seed = 42 + shard_id * 1_000_000_000 + sample_index
65+
prompt: |
66+
A photorealistic image of: {{ text }}
67+
68+
remove_columns: false
69+
output_schema:
70+
caption: "{{ text }}"
71+
image: "{{ generated_image }}"
72+
73+
execution_params:
74+
# Execution mode: "local" or "slurm"
75+
# - local: Run directly on this machine
76+
# - slurm: Submit jobs to SLURM cluster
77+
mode: slurm
78+
79+
# Whether the canonical `run` command should automatically retry failed shards.
80+
# - false: submit one run only
81+
# - true: submit, wait, and keep retrying failed shards until success or retry budget exhaustion
82+
retry: false
83+
84+
# Maximum number of times to retry a failed shard (default: 3)
85+
max_retries: 3
86+
87+
# ==========================================================================
88+
# SLURM CONFIGURATION (only used when mode: slurm)
89+
# ==========================================================================
90+
91+
# HPC account/partition to charge jobs to (REQUIRED for SLURM mode)
92+
account: a127
93+
94+
# SLURM job name (default: "mmirage-sharded")
95+
job_name: mmirage-sharded
96+
97+
# Optional SLURM reservation name (leave blank or omit to run without a reservation)
98+
# reservation: "sai-a127"
99+
100+
# Number of nodes (default: 1)
101+
nodes: 1
102+
103+
# Number of tasks per node (default: 1)
104+
ntasks_per_node: 1
105+
106+
# Number of GPUs per node (default: 4)
107+
gpus: 4
108+
109+
# Number of CPUs per task (default: 288)
110+
cpus_per_task: 288
111+
112+
# Job time limit in HH:MM:SS format (default: "11:59:59")
113+
time_limit: "11:59:59"
114+
115+
# ==========================================================================
116+
# PATH CONFIGURATION
117+
# ==========================================================================
118+
# These support environment variables ($VAR or ${VAR}) and home directory (~)
119+
120+
# Project root directory (used as base for relative paths)
121+
# If not set, uses current working directory
122+
# project_root: "/path/to/project"
123+
124+
# Directory for SLURM output and error files (default: ~/reports)
125+
report_dir: "/users/${USER}/reports"
126+
127+
# HuggingFace cache directory (default: ~/hf)
128+
hf_home: "/capstor/store/cscs/swissai/a127/homes/${USER}/hf"
129+
130+
# EDF environment file path for cluster-specific setup
131+
edf_env: "/users/${USER}/.edf/sglang.toml"
132+
133+
# ==========================================================================
134+
# JOB MONITORING (for "submit" and retry orchestration)
135+
# ==========================================================================
136+
137+
# Seconds to wait between checking job status (default: 30)
138+
poll_interval_seconds: 30
139+
140+
# Seconds to wait after job completes before checking results (default: 60)
141+
# This allows the filesystem to settle on distributed systems
142+
settle_time_seconds: 60
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
processors:
22
- type: image_gen
3+
backend: diffusers
34
pipeline_args:
45
model_path: stable-diffusion-v1-5/stable-diffusion-v1-5
56
torch_dtype: float16
67
device: auto
7-
enable_attention_slicing: true
8+
enable_attention_slicing: false
89
default_sampling_params:
910
num_inference_steps: 20
1011
guidance_scale: 7.5

src/mmirage/core/process/base.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,13 @@ def batch_process_sample(
7878
"""
7979
raise NotImplementedError()
8080

81+
def shutdown(self) -> None:
82+
"""Release any resources held by this processor.
83+
84+
Override in subclasses that hold GPU memory, open file handles, or
85+
network connections. The default implementation is a no-op.
86+
"""
87+
8188

8289
class ProcessorRegistry:
8390
"""Registry for managing and accessing available processors.

src/mmirage/core/process/mapper.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,8 @@ def rewrite_batch(
105105
)
106106

107107
return batch_environment
108+
109+
def shutdown(self) -> None:
110+
"""Shut down all processors and release their resources."""
111+
for processor in self.processors.values():
112+
processor.shutdown()
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Image generation backends for MMIRAGE."""
2+
3+
from mmirage.core.process.processors.image_gen.backends.base import ImageGenerationBackend
4+
5+
__all__ = ["ImageGenerationBackend"]
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""Image generation backend protocol for MMIRAGE."""
2+
3+
from __future__ import annotations
4+
5+
from typing import Any, Dict, List, Optional
6+
7+
try:
8+
from typing import Protocol, runtime_checkable
9+
except ImportError: # pragma: no cover
10+
from typing_extensions import Protocol, runtime_checkable # type: ignore
11+
12+
13+
@runtime_checkable
14+
class ImageGenerationBackend(Protocol):
15+
"""Protocol for pluggable image generation backends.
16+
17+
All backends receive pre-rendered prompts and pre-computed per-sample seeds
18+
from the processor. The processor handles all Jinja template rendering,
19+
filename generation, and result bookkeeping; the backend is responsible
20+
only for turning prompts + params into PIL images.
21+
"""
22+
23+
def generate_batch(
24+
self,
25+
prompts: List[str],
26+
negative_prompts: Optional[List[Optional[str]]],
27+
params: Dict[str, Any],
28+
seeds: List[Optional[int]],
29+
) -> List[Any]:
30+
"""Generate one image per prompt.
31+
32+
Args:
33+
prompts: Positive prompt strings, one per sample.
34+
negative_prompts: Optional list of negative prompts aligned with
35+
``prompts``. ``None`` means no negative prompts at all;
36+
individual ``None`` elements mean no negative prompt for that
37+
sample.
38+
params: Shared generation kwargs (width, height,
39+
num_inference_steps, guidance_scale, …).
40+
seeds: Per-sample integer seeds for deterministic generation, or
41+
``None`` elements for unseeded samples. The list is always
42+
the same length as ``prompts``.
43+
44+
Returns:
45+
List of ``PIL.Image.Image`` objects, one per prompt, in the same order.
46+
"""
47+
...
48+
49+
def shutdown(self) -> None:
50+
"""Release any resources held by the backend."""
51+
...

0 commit comments

Comments
 (0)