Skip to content

Commit 5e98d72

Browse files
committed
rf(indexing): thread schema through batch_index_dataset
1 parent 7381067 commit 5e98d72

2 files changed

Lines changed: 24 additions & 5 deletions

File tree

bids2table/_indexing.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -278,25 +278,32 @@ def batch_index_dataset(
278278
max_workers: int | None = 0,
279279
executor_cls: type[Executor] = ProcessPoolExecutor,
280280
show_progress: bool = False,
281+
schema: BIDSSchema | pa.Schema | Namespace | str | Path | None = None,
281282
) -> Generator[pa.Table, None, None]:
282283
"""Index a batch of BIDS datasets.
283284
284285
Args:
285286
roots: List of BIDS dataset root directories.
286287
max_workers: Number of indexing processes to run in parallel. Setting
287288
`max_workers=0` (the default) uses the main process only. Setting
288-
`max_workers=None` starts as many workers as there are available CPUs. See
289-
`concurrent.futures.ProcessPoolExecutor` for details.
289+
`max_workers=None` starts as many workers as there are available CPUs.
290+
See `concurrent.futures.ProcessPoolExecutor` for details.
290291
executor_cls: Executor class to use for parallel indexing.
291292
show_progress: Show progress bar.
293+
schema: Schema to index with: a `BIDSSchema`, `pa.Schema`, `Namespace`, or a
294+
path/URL (`str` or `Path`). `None` uses the module-level default.
292295
293296
Yields:
294297
An Arrow table index for each BIDS dataset.
295298
"""
299+
bids_schema = _resolve(schema)
300+
entity_arrow_schema = bids_schema.arrow_schema
301+
func = partial(_batch_index_func, schema=entity_arrow_schema)
302+
296303
file_count = 0
297304
for dataset, table in (
298305
pbar := tqdm(
299-
_pmap(_batch_index_func, roots, max_workers, executor_cls=executor_cls),
306+
_pmap(func, roots, max_workers, executor_cls=executor_cls),
300307
total=len(roots) if isinstance(roots, Sequence) else None,
301308
disable=show_progress not in {True, "dataset"},
302309
)
@@ -306,9 +313,12 @@ def batch_index_dataset(
306313
yield table
307314

308315

309-
def _batch_index_func(root: str | PathT) -> tuple[str | None, pa.Table]:
316+
def _batch_index_func(
317+
root: str | PathT,
318+
schema: pa.Schema | None = None,
319+
) -> tuple[str | None, pa.Table]:
310320
dataset, _ = _get_bids_dataset(root)
311-
table = index_dataset(root, max_workers=0, show_progress=False)
321+
table = index_dataset(root, max_workers=0, show_progress=False, schema=schema)
312322
return dataset, table
313323

314324

tests/test_indexing.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,3 +293,12 @@ def test_index_dataset_workers_honor_explicit_schema():
293293
BIDS_EXAMPLES / "ds102", schema=tagged, max_workers=2
294294
)
295295
assert table.schema.metadata[b"test_marker"] == b"tagged"
296+
297+
298+
def test_batch_index_dataset_with_explicit_schema():
299+
s = BIDSSchema.from_path(None)
300+
roots = [p.parent for p in BIDS_EXAMPLES.glob("*/dataset_description.json")][:2]
301+
tables = list(indexing.batch_index_dataset(roots, schema=s))
302+
assert len(tables) == len(roots)
303+
for t in tables:
304+
assert "sub" in t.schema.names

0 commit comments

Comments
 (0)