Skip to content

Commit c5e75a2

Browse files
committed
Suggestion by Copilot, mostly changing the CLI argument to accept a list of files
1 parent c7d4bac commit c5e75a2

4 files changed

Lines changed: 69 additions & 45 deletions

File tree

src/mmirage/core/process/batch/collector.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,26 @@
1313
from mmirage.core.process.batch.registry import BatchAdapterFactory
1414

1515

16-
def _read_metadata_records(metadata_output_path: str) -> List[Dict[str, Any]]:
16+
def _normalize_metadata_paths(metadata_paths: str | Sequence[str]) -> List[str]:
17+
if isinstance(metadata_paths, str):
18+
return [metadata_paths]
19+
return [str(path) for path in metadata_paths]
20+
21+
22+
def _read_metadata_records(metadata_output_paths: str | Sequence[str]) -> List[Dict[str, Any]]:
1723
records: List[Dict[str, Any]] = []
18-
with open(metadata_output_path, "r", encoding="utf-8") as f:
19-
for line in f:
20-
raw = line.strip()
21-
if not raw:
22-
continue
23-
try:
24-
parsed = json.loads(raw)
25-
except json.JSONDecodeError:
26-
continue
27-
if isinstance(parsed, dict):
28-
records.append(parsed)
24+
for metadata_output_path in _normalize_metadata_paths(metadata_output_paths):
25+
with open(metadata_output_path, "r", encoding="utf-8") as f:
26+
for line in f:
27+
raw = line.strip()
28+
if not raw:
29+
continue
30+
try:
31+
parsed = json.loads(raw)
32+
except json.JSONDecodeError:
33+
continue
34+
if isinstance(parsed, dict):
35+
records.append(parsed)
2936
return records
3037

3138

@@ -56,7 +63,7 @@ def _aggregate_batch_mappings(
5663

5764

5865
def collect_and_merge(
59-
metadata_output_path: str,
66+
metadata_output_path: str | Sequence[str],
6067
provider_configs: Mapping[str, BatchProviderConfig],
6168
output_path: str,
6269
) -> List[Dict[str, Any]]:
@@ -161,8 +168,9 @@ def _build_arg_parser() -> argparse.ArgumentParser:
161168
)
162169
parser.add_argument(
163170
"--metadata-path",
171+
nargs="+",
164172
required=True,
165-
help="Path to metadata JSONL receipt file.",
173+
help="Path(s) to metadata JSONL receipt file(s). Supports multiple files.",
166174
)
167175
parser.add_argument(
168176
"--output-path",

src/mmirage/core/process/batch/status_checker.py

Lines changed: 38 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,42 +14,55 @@
1414
from mmirage.core.process.batch.registry import BatchAdapterFactory
1515

1616

17-
def extract_unique_provider_batches(metadata_output_path: str) -> List[Tuple[str, str]]:
17+
def _normalize_metadata_paths(metadata_paths: str | Sequence[str]) -> List[str]:
18+
if isinstance(metadata_paths, str):
19+
return [metadata_paths]
20+
return [str(path) for path in metadata_paths]
21+
22+
23+
def _read_metadata_records(metadata_output_paths: str | Sequence[str]) -> List[Dict[str, str]]:
24+
records: List[Dict[str, str]] = []
25+
for metadata_output_path in _normalize_metadata_paths(metadata_output_paths):
26+
with open(metadata_output_path, "r", encoding="utf-8") as f:
27+
for line in f:
28+
raw = line.strip()
29+
if not raw:
30+
continue
31+
try:
32+
record = json.loads(raw)
33+
except json.JSONDecodeError:
34+
continue
35+
if isinstance(record, dict):
36+
records.append(record)
37+
return records
38+
39+
40+
def extract_unique_provider_batches(metadata_output_path: str | Sequence[str]) -> List[Tuple[str, str]]:
1841
"""Parse metadata JSONL and return unique ``(provider, provider_batch_id)`` pairs.
1942
2043
Malformed lines and records missing required keys are skipped safely.
2144
"""
2245
unique_pairs: List[Tuple[str, str]] = []
2346
seen = set()
2447

25-
with open(metadata_output_path, "r", encoding="utf-8") as f:
26-
for line in f:
27-
raw = line.strip()
28-
if not raw:
29-
continue
48+
for record in _read_metadata_records(metadata_output_path):
49+
provider = str(record.get("provider", "")).strip().lower()
50+
provider_batch_id = str(record.get("provider_batch_id", "")).strip()
3051

31-
try:
32-
record = json.loads(raw)
33-
except json.JSONDecodeError:
34-
continue
35-
36-
provider = str(record.get("provider", "")).strip().lower()
37-
provider_batch_id = str(record.get("provider_batch_id", "")).strip()
38-
39-
if not provider or not provider_batch_id:
40-
continue
52+
if not provider or not provider_batch_id:
53+
continue
4154

42-
pair = (provider, provider_batch_id)
43-
if pair in seen:
44-
continue
45-
seen.add(pair)
46-
unique_pairs.append(pair)
55+
pair = (provider, provider_batch_id)
56+
if pair in seen:
57+
continue
58+
seen.add(pair)
59+
unique_pairs.append(pair)
4760

4861
return unique_pairs
4962

5063

5164
def run_status_checker(
52-
metadata_output_path: str,
65+
metadata_output_path: str | Sequence[str],
5366
provider_configs: Mapping[str, BatchProviderConfig],
5467
output: TextIO = sys.stdout,
5568
) -> List[BatchSubmissionResult]:
@@ -84,7 +97,7 @@ def run_status_checker(
8497

8598

8699
def _build_provider_configs_from_metadata(
87-
metadata_output_path: str,
100+
metadata_output_path: str | Sequence[str],
88101
) -> Dict[str, BatchProviderConfig]:
89102
provider_names = {provider for provider, _ in extract_unique_provider_batches(metadata_output_path)}
90103
configs: Dict[str, BatchProviderConfig] = {}
@@ -104,8 +117,9 @@ def _build_arg_parser() -> argparse.ArgumentParser:
104117
parser = argparse.ArgumentParser(description="Check provider batch statuses from metadata receipts.")
105118
parser.add_argument(
106119
"--metadata-path",
120+
nargs="+",
107121
required=True,
108-
help="Path to metadata JSONL receipt file.",
122+
help="Path(s) to metadata JSONL receipt file(s). Supports multiple files.",
109123
)
110124
return parser
111125

src/mmirage/core/process/processors/llm/config.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,19 @@
1818

1919
def _parse_tp_size_from_env() -> int:
2020
"""Parse tensor parallelism size from SLURM_GPUS_ON_NODE environment variable.
21-
21+
2222
Defensively parses the environment variable, handling invalid values:
2323
- Returns 1 if the variable is None or empty
2424
- Strips whitespace before parsing
2525
- Returns 1 for non-integer values
2626
- Returns 1 for values <= 0
27-
27+
2828
Returns:
2929
Tensor parallelism size (>= 1), defaults to 1 on any parsing error.
3030
"""
3131
env_value = os.environ.get("SLURM_GPUS_ON_NODE")
3232
if not env_value:
3333
return 1
34-
3534
try:
3635
tp_size = int(env_value.strip())
3736
# Ensure tp_size is positive (must be >= 1)

src/mmirage/core/process/processors/llm/llm_processor.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def __init__(self, engine_args: SGLangLLMConfig, **kwargs) -> None:
6363
"""
6464
super().__init__(engine_args, **kwargs)
6565
provider_cfg_raw = dict(getattr(engine_args, "batch_provider", {}) or {})
66-
batch_mode_requested = bool(provider_cfg_raw.get("enabled", False))
66+
batch_mode_requested = bool(provider_cfg_raw.get("enabled", True))
6767

6868
# In provider-batch mode we only build payloads/metadata and should not
6969
# initialize GPU-backed SGLang runtime.
@@ -82,14 +82,15 @@ def __init__(self, engine_args: SGLangLLMConfig, **kwargs) -> None:
8282
self._text_orchestrator: Optional[BatchSubmissionOrchestrator] = None
8383
self._multimodal_orchestrator: Optional[BatchSubmissionOrchestrator] = None
8484
self._batch_request_counter = 0
85+
self._global_row_offset = 0
8586
self._setup_batch_runtime()
8687

8788
def _setup_batch_runtime(self) -> None:
8889
provider_cfg_raw = dict(getattr(self.config, "batch_provider", {}) or {})
8990
if not provider_cfg_raw:
9091
return
9192

92-
if not provider_cfg_raw.get("enabled", False):
93+
if not provider_cfg_raw.get("enabled", True):
9394
return
9495

9596
provider = str(provider_cfg_raw.get("provider", "openai")).strip().lower()
@@ -390,7 +391,7 @@ def _batch_process_sample(
390391
config=self._batch_provider_config,
391392
)
392393
requests.append(dict(request))
393-
source_indices.append(self._batch_request_counter)
394+
source_indices.append(self._global_row_offset + global_i)
394395

395396
self._text_orchestrator.add_requests(
396397
requests=requests,
@@ -440,7 +441,7 @@ def _batch_process_sample(
440441
config=self._batch_provider_config,
441442
)
442443
requests.append(dict(request))
443-
source_indices.append(self._batch_request_counter)
444+
source_indices.append(self._global_row_offset + global_i)
444445

445446
self._multimodal_orchestrator.add_requests(
446447
requests=requests,
@@ -458,6 +459,8 @@ def _batch_process_sample(
458459
placeholder = f"__BATCH_SUBMITTED__:{unique_id}"
459460
placeholders.append(batch[i].with_variable(output_var.name, placeholder))
460461

462+
self._global_row_offset += nb_samples
463+
461464
return placeholders
462465

463466
def finalize(self) -> None:

0 commit comments

Comments
 (0)