Skip to content

Commit fdb0d42

Browse files
committed
fixed various typing and logic errors
1 parent c962618 commit fdb0d42

5 files changed

Lines changed: 26 additions & 10 deletions

File tree

src/mmirage/config/loading.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,7 @@ def is_unresolved_env_var(s: str) -> bool:
4747
if self.num_shards < 1:
4848
raise ValueError()
4949
except (ValueError, TypeError):
50-
if is_unresolved_env_var(self.num_shards):
50+
if isinstance(self.num_shards, str) and is_unresolved_env_var(self.num_shards):
5151
self.num_shards = 1
5252
else:
5353
raise ValueError(f"Invalid value for num_shards: {self.num_shards!r}")
@@ -56,7 +56,7 @@ def is_unresolved_env_var(s: str) -> bool:
5656
try:
5757
self.shard_id = int(self.shard_id)
5858
except (ValueError, TypeError):
59-
if is_unresolved_env_var(self.shard_id):
59+
if isinstance(self.shard_id, str) and is_unresolved_env_var(self.shard_id):
6060
self.shard_id = 0
6161
else:
6262
raise ValueError(f"Invalid value for shard_id: {self.shard_id!r}")

src/mmirage/core/process/base.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def batch_process_sample(
7272
"""
7373
raise NotImplementedError()
7474

75-
@abstract
75+
@abc.abstractmethod
7676
def get_token_counts(self) -> TokenCounts:
7777
"""Get cumulative token counts from this processor.
7878
@@ -84,6 +84,18 @@ def get_token_counts(self) -> TokenCounts:
8484
"""
8585
raise NotImplementedError()
8686

87+
@abc.abstractmethod
88+
def get_load_time(self) -> float:
89+
"""Get the time taken to load any necessary resources (e.g., models).
90+
91+
Returns:
92+
Time in seconds taken to load resources.
93+
94+
Raises:
95+
NotImplementedError: If not implemented by subclass.
96+
"""
97+
raise NotImplementedError()
98+
8799

88100
class ProcessorRegistry:
89101
"""Registry for managing and accessing available processors.

src/mmirage/core/process/mapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def validate_vars(self) -> bool:
7575
def rewrite_batch(
7676
self,
7777
batch: Dict[str, List[Any]],
78-
image_base_path: str = None,
78+
image_base_path: Optional[str] = None,
7979
) -> List[VariableEnvironment]:
8080
"""Transform a batch of samples by computing output variables.
8181

src/mmirage/merge_shards.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def _merge_datasetdict(shard_dsets: List[DatasetDict]) -> DatasetDict:
4949
merged[str(split)] = concatenate_datasets(split_dsets)
5050
if not merged:
5151
raise RuntimeError("All splits were empty after merging.")
52-
return DatasetDict(merged)
52+
return DatasetDict(**merged)
5353

5454

5555
def _merge_shards(shard_dsets: List[DatasetLike]) -> DatasetLike:

src/mmirage/shard_process.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def rewrite_batch(
3737
batch: Dict[str, List[Any]],
3838
mapper: MMIRAGEMapper,
3939
renderer: TemplateRenderer,
40-
image_base_path: str = None,
40+
image_base_path: Optional[str] = None,
4141
) -> Dict[str, List[Any]]:
4242
"""Rewrite a batch of samples by applying transformations.
4343
Args:
@@ -91,6 +91,8 @@ def main():
9191

9292
state_dir = shard_state_dir(shard_id, loading_params.get_state_root())
9393

94+
gpu_poller: Optional[GpuUtilizationPoller] = None
95+
9496
collect_stats = os.environ.get("MMIRAGE_COLLECT_STATS", "") == "1"
9597
if collect_stats:
9698
# Determine which physical GPU indices SGLang will use so the poller
@@ -112,9 +114,11 @@ def main():
112114
gpu_indices_for_polling: List[str] = all_visible[:tp_size] if all_visible else [str(i) for i in range(tp_size)]
113115
else:
114116
gpu_indices_for_polling = [str(i) for i in range(tp_size)]
115-
gpu_poller: GpuUtilizationPoller = GpuUtilizationPoller(
117+
118+
gpu_poller = GpuUtilizationPoller(
116119
interval_seconds=5.0, gpu_indices=gpu_indices_for_polling
117120
)
121+
118122
try:
119123
retry_count = _mark_running(state_dir, shard_id, datasets_config)
120124
logger.info(f"Starting shard {shard_id}/{last_shard_id} (attempt #{retry_count})")
@@ -144,7 +148,7 @@ def main():
144148

145149
# Start GPU polling after model loading so utilisation samples reflect
146150
# inference only, not weight transfers during sgl.Engine() init.
147-
if collect_stats:
151+
if collect_stats and gpu_poller is not None:
148152
gpu_poller.start()
149153

150154
ds_processed_all: List[DatasetLike] = []
@@ -180,7 +184,7 @@ def main():
180184
_save_dataset_atomic(ds_processed, out_dir)
181185
logger.info(f"✅ Saved dataset {ds_idx} shard in: {out_dir}")
182186

183-
gpu_info = gpu_poller.stop() if collect_stats else {"mean": None, "min": None, "max": None, "samples": 0}
187+
gpu_info = gpu_poller.stop() if collect_stats and gpu_poller is not None else {"mean": None, "min": None, "max": None, "samples": 0}
184188

185189
# Collect token counts accumulated by LLM processor(s).
186190
token_counts = mapper.get_token_counts()
@@ -214,7 +218,7 @@ def main():
214218
error_msg = f"{type(e).__name__}: {str(e)}"
215219
logger.error(f"❌ Shard {shard_id} failed: {error_msg}")
216220
logger.error(traceback.format_exc())
217-
if collect_stats:
221+
if collect_stats and gpu_poller is not None:
218222
gpu_poller.stop()
219223
_mark_failure(state_dir, error_msg)
220224
sys.exit(1)

0 commit comments

Comments
 (0)