-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathbase.py
More file actions
228 lines (172 loc) · 7.1 KB
/
base.py
File metadata and controls
228 lines (172 loc) · 7.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
"""Base classes and registry for processors in MMIRAGE."""
import abc
from importlib import import_module
from dataclasses import dataclass
from typing import Callable, Generic, List, Type, TypeVar
from mmirage.core.process.variables import VariableEnvironment, OutputVar
@dataclass
class BaseProcessorConfig:
    """Common base for every processor configuration.

    Concrete processor configuration dataclasses derive from this class;
    ``ProcessorRegistry`` stores configuration classes typed as
    ``Type[BaseProcessorConfig]``.

    Attributes:
        type: Identifier naming the processor kind (for example ``"llm"``).
            Defaults to the empty string.
    """

    type: str = ""
C = TypeVar("C", bound=OutputVar)


class BaseProcessor(abc.ABC, Generic[C]):
    """Abstract interface shared by all data processors.

    A processor derives new output variables from the variables already
    present in each ``VariableEnvironment`` of a batch.

    Type Parameters:
        C: The ``OutputVar`` subtype this processor generates.

    Attributes:
        config: Configuration object supplied at construction.
        shard_id: Shard identifier supplied at construction (defaults to 0).
    """

    def __init__(self, config: BaseProcessorConfig, shard_id: int = 0, **kwargs) -> None:
        """Store the configuration and shard identifier.

        Args:
            config: Configuration object for this processor.
            shard_id: Optional shard identifier accepted for compatibility
                with callers that forward it during processor construction.
            **kwargs: Catch-all for extra keyword arguments; any provided
                will cause a ``TypeError``.

        Raises:
            TypeError: If any extra keyword arguments are passed.
        """
        # Reject stray keywords explicitly (names listed alphabetically) so a
        # misspelled option fails loudly instead of being silently ignored.
        leftover = sorted(kwargs)
        if leftover:
            raise TypeError(
                "Unexpected keyword argument(s) for "
                f"{self.__class__.__name__}: {', '.join(leftover)}"
            )
        super().__init__()
        self.config = config
        self.shard_id = shard_id

    @abc.abstractmethod
    def batch_process_sample(
        self, batch: List[VariableEnvironment], output_var: C
    ) -> List[VariableEnvironment]:
        """Generate ``output_var`` for every environment in ``batch``.

        Args:
            batch: Variable environments to process.
            output_var: Definition of the output variable to produce.

        Returns:
            The updated variable environments, each carrying the new
            output variable.

        Raises:
            NotImplementedError: When a subclass does not override this.
        """
        raise NotImplementedError
class ProcessorRegistry:
    """Registry for managing and accessing available processors.

    Provides a centralized registry for processor classes, their
    configuration classes, and their output variable classes.

    Attributes:
        _registry: Mapping from processor name to registered processor class.
        _config_registry: Mapping from processor name to its configuration class.
        _output_var_registry: Mapping from processor name to its output variable class.
        _lazy_processor_imports: Mapping from processor name to the module
            that ``get_processor`` imports on first lookup of that name.
    """

    _registry = dict()
    _config_registry = dict()
    _output_var_registry = dict()

    # Import processor implementations lazily because they may depend on heavy
    # libraries (torch/transformers). Config/output-var types are registered via
    # mmirage.config.utils importing the relevant config modules.
    _lazy_processor_imports = {
        "llm": "mmirage.core.process.processors.llm.llm_processor",
        "image_gen": "mmirage.core.process.processors.image_gen.image_gen_processor",
    }

    @classmethod
    def register_types(
        cls,
        name: str,
        config_cls: Type[BaseProcessorConfig],
        output_var_cls: Type[OutputVar],
    ) -> None:
        """Register config/output-var types without importing processor implementations.

        Args:
            name: String identifier for the processor.
            config_cls: Configuration class associated with this processor.
            output_var_cls: Output variable class associated with this processor.
        """
        cls._config_registry[name] = config_cls
        cls._output_var_registry[name] = output_var_cls

    @classmethod
    def _maybe_import_processor(cls, name: str) -> None:
        """Import the implementation module mapped to ``name``, if any.

        Names absent from ``_lazy_processor_imports`` are ignored. Import
        errors from the target module propagate to the caller.
        """
        module = cls._lazy_processor_imports.get(name)
        if module:
            import_module(module)

    @classmethod
    def register(
        cls,
        name: str,
        config_cls: Type[BaseProcessorConfig],
        output_var_cls: Type[OutputVar],
    ) -> Callable:
        """Register a processor class with its associated classes.

        Intended for use as a class decorator::

            @ProcessorRegistry.register("llm", LLMConfig, LLMOutputVar)
            class LLMProcessor(BaseProcessor): ...

        Args:
            name: String identifier for the processor.
            config_cls: Configuration class associated with this processor.
            output_var_cls: Output variable class associated with this processor.

        Returns:
            Decorator that records the processor class under ``name`` and
            returns the class unchanged.
        """

        def inner_register(clazz):
            cls._registry[name] = clazz
            cls.register_types(name, config_cls, output_var_cls)
            # A class decorator must hand the class back; returning None
            # would rebind the decorated class name to None.
            return clazz

        return inner_register

    @classmethod
    def get_processor(cls, name: str) -> Type[BaseProcessor]:
        """Get a registered processor class by name.

        If ``name`` is not yet registered, its implementation module (when
        listed in ``_lazy_processor_imports``) is imported first and the
        registry is checked again.

        Args:
            name: String identifier of the processor.

        Returns:
            The registered processor class.

        Raises:
            ValueError: If no processor is registered under the given name.
        """
        if name not in cls._registry:
            cls._maybe_import_processor(name)
        if name not in cls._registry:
            raise ValueError(
                f"Processor {name} not registered. Available processors are {list(cls._registry.keys())}"
            )
        return cls._registry[name]

    @classmethod
    def get_config_cls(cls, name: str) -> Type[BaseProcessorConfig]:
        """Get a registered configuration class by processor name.

        Unlike ``get_processor``, this does not trigger a lazy import.

        Args:
            name: String identifier of the processor.

        Returns:
            The registered configuration class.

        Raises:
            ValueError: If no processor is registered under the given name.
        """
        if name not in cls._config_registry:
            raise ValueError(
                f"Processor {name} not registered. Available processors are {list(cls._config_registry.keys())}"
            )
        return cls._config_registry[name]

    @classmethod
    def get_output_var_cls(cls, name: str) -> Type[OutputVar]:
        """Get a registered output variable class by processor name.

        Unlike ``get_processor``, this does not trigger a lazy import.

        Args:
            name: String identifier of the processor.

        Returns:
            The registered output variable class.

        Raises:
            ValueError: If no processor is registered under the given name.
        """
        if name not in cls._output_var_registry:
            raise ValueError(
                f"Processor {name} not registered. Available processors are {list(cls._output_var_registry.keys())}"
            )
        return cls._output_var_registry[name]
class AutoProcessor:
    """Factory that resolves processor classes from their registry names."""

    @classmethod
    def from_name(cls, name: str) -> Type[BaseProcessor]:
        """Look up the processor class registered under ``name``.

        Thin delegation to ``ProcessorRegistry.get_processor``.

        Args:
            name: The registry name of the processor.

        Returns:
            The processor class registered under ``name``.

        Raises:
            ValueError: If nothing is registered under ``name``.
        """
        processor_cls = ProcessorRegistry.get_processor(name)
        return processor_cls