Skip to content

Commit 3719657

Browse files
committed
Typing fixes
1 parent 015f283 commit 3719657

5 files changed

Lines changed: 42 additions & 39 deletions

File tree

bids2table/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,5 +159,5 @@
159159
get_bids_entity_arrow_schema,
160160
format_bids_path,
161161
)
162-
from ._pathlib import Path, cloudpathlib_is_available
162+
from ._pathlib import cloudpathlib_is_available
163163
from ._version import *

bids2table/__main__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
import pyarrow.parquet as pq
77

88
import bids2table as b2t2
9-
from bids2table import Path
109
from bids2table._logging import setup_logger
10+
from bids2table._pathlib import as_path
1111

1212
_logger = setup_logger(__package__)
1313

@@ -117,7 +117,7 @@ def _index_command(args: argparse.Namespace):
117117
root = []
118118
for path in args.root:
119119
if glob.has_magic(path):
120-
path = Path(path)
120+
path = as_path(path)
121121
paths = list(path.parent.glob(path.name))
122122
root.extend(paths)
123123
else:

bids2table/_indexing.py

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
validate_bids_entities,
2222
)
2323
from ._logging import setup_logger
24-
from ._pathlib import Path
24+
from ._pathlib import PathT, as_path
2525

2626
_BIDS_SUBJECT_DIR_PATTERN = re.compile(r"sub-[a-zA-Z0-9]+")
2727

@@ -109,7 +109,7 @@ def get_arrow_schema() -> pa.Schema:
109109
return schema
110110

111111

112-
def get_column_names() -> enum.EnumType:
112+
def get_column_names() -> enum.StrEnum:
113113
"""Get an enum of the BIDS index columns."""
114114
# TODO: It might be nice if the column names were statically available. One option
115115
# would be to generate a static _schema.py module at install time (similar to how
@@ -127,11 +127,11 @@ def get_column_names() -> enum.EnumType:
127127

128128

129129
def find_bids_datasets(
130-
root: str | Path,
130+
root: str | PathT,
131131
exclude: str | list[str] | None = None,
132132
follow_symlinks: bool = True,
133133
log_frequency: int = 100,
134-
) -> Generator[Path, None, None]:
134+
) -> Generator[PathT, None, None]:
135135
"""Find all BIDS datasets under a root directory.
136136
137137
Args:
@@ -143,8 +143,7 @@ def find_bids_datasets(
143143
Yields:
144144
Root paths of all BIDS datasets under `root`.
145145
"""
146-
if isinstance(root, str):
147-
root = Path(root)
146+
root = as_path(root)
148147

149148
dir_count = 0
150149
ds_count = 0
@@ -178,7 +177,7 @@ def find_bids_datasets(
178177

179178

180179
def index_dataset(
181-
root: str | Path,
180+
root: str | PathT,
182181
include_subjects: str | list[str] | None = None,
183182
max_workers: int | None = 0,
184183
chunksize: int = 32,
@@ -203,8 +202,7 @@ def index_dataset(
203202
Returns:
204203
An Arrow table index of the BIDS dataset.
205204
"""
206-
if isinstance(root, str):
207-
root = Path(root)
205+
root = as_path(root)
208206

209207
schema = get_arrow_schema()
210208

@@ -243,7 +241,7 @@ def index_dataset(
243241

244242

245243
def batch_index_dataset(
246-
roots: list[str | Path],
244+
roots: list[str | PathT],
247245
max_workers: int | None = 0,
248246
executor_cls: type[Executor] = ProcessPoolExecutor,
249247
show_progress: bool = False,
@@ -275,13 +273,13 @@ def batch_index_dataset(
275273
yield table
276274

277275

278-
def _batch_index_func(root: str | Path) -> tuple[str, pa.Table]:
276+
def _batch_index_func(root: str | PathT) -> tuple[str | None, pa.Table]:
279277
dataset, _ = _get_bids_dataset(root)
280278
table = index_dataset(root, max_workers=0, show_progress=False)
281279
return dataset, table
282280

283281

284-
def _get_bids_dataset(path: str | Path) -> tuple[str | None, Path | None]:
282+
def _get_bids_dataset(path: str | PathT) -> tuple[str | None, PathT | None]:
285283
"""Get the BIDS dataset that the path belongs to, if any.
286284
287285
Return the dataset directory name and the full dataset path. For nested derivatives
@@ -290,13 +288,10 @@ def _get_bids_dataset(path: str | Path) -> tuple[str | None, Path | None]:
290288
291289
Note that the name is extracted from the path, not the dataset description JSON.
292290
"""
293-
if isinstance(path, str):
294-
path = Path(path)
295-
296-
parent = path
291+
parent = as_path(path)
297292
parts: list[str] = []
298293
scanning = False
299-
top_idx = None
294+
top_idx = 0
300295
root = None
301296

302297
while parent.name:
@@ -319,24 +314,24 @@ def _get_bids_dataset(path: str | Path) -> tuple[str | None, Path | None]:
319314
return dataset, root
320315

321316

322-
def _is_bids_dataset(path: Path) -> bool:
317+
def _is_bids_dataset(path: PathT) -> bool:
323318
"""Test if path is a BIDS dataset root directory."""
324319
# Check if contains a dataset_description.json or any subject directories. Note,
325320
# it's common for ppl to forget the dataset description, so let's not be too strict.
326321
description_exists = (path / "dataset_description.json").exists()
327322
return description_exists or _contains_bids_subject_dirs(path)
328323

329324

330-
def _contains_bids_subject_dirs(root: Path) -> bool:
325+
def _contains_bids_subject_dirs(root: PathT) -> bool:
331326
"""Check if a path contains one or more BIDS subject dirs."""
332327
# Nb, this will return on the first matching path thanks to the generator.
333328
return any(_is_bids_subject_dir(path) for path in root.glob("sub-*"))
334329

335330

336331
def _find_bids_subject_dirs(
337-
root: Path,
332+
root: PathT,
338333
include_subjects: str | list[str] | None = None,
339-
) -> list[Path]:
334+
) -> list[PathT]:
340335
"""Find all BIDS subject dirs contained in a root directory.
341336
342337
Note, only looks one level down. Does not find nested subject directories, e.g. in
@@ -352,7 +347,7 @@ def _find_bids_subject_dirs(
352347
return paths
353348

354349

355-
def _is_bids_subject_dir(path: Path) -> bool:
350+
def _is_bids_subject_dir(path: PathT) -> bool:
356351
"""Check if a path is a BIDS subject directory."""
357352
# NOTE: not checking if the path is in fact a directory.
358353
# This is a slow op, especially on cloud. Can assume that there are no files
@@ -362,7 +357,7 @@ def _is_bids_subject_dir(path: Path) -> bool:
362357

363358

364359
def _index_bids_subject_dir(
365-
path: Path,
360+
path: PathT,
366361
schema: pa.Schema | None = None,
367362
dataset: str | None = None,
368363
) -> tuple[str, pa.Table]:
@@ -394,7 +389,7 @@ def _index_bids_subject_dir(
394389
return subject, table
395390

396391

397-
def _is_bids_file(path: Path) -> bool:
392+
def _is_bids_file(path: PathT) -> bool:
398393
"""Check if file is a BIDS file.
399394
400395
Not very exact, but hopefully good enough.
@@ -423,7 +418,7 @@ def _is_bids_file(path: Path) -> bool:
423418
return True
424419

425420

426-
def _is_bids_json_sidecar(path: Path) -> bool:
421+
def _is_bids_json_sidecar(path: PathT) -> bool:
427422
"""Quick check if a file is a JSON sidecar."""
428423
# Quick check if path suffix is not json.
429424
if path.suffix != ".json":

bids2table/_pathlib.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,31 @@
11
from pathlib import Path
22

33
try:
4-
# Overshadow pathlib Path.
5-
from cloudpathlib import AnyPath as Path
4+
import cloudpathlib
5+
from cloudpathlib import AnyPath, CloudPath
66

77
_CLOUDPATHLIB_AVAILABLE = True
8+
9+
# Set unsigned client as default for s3:// paths
10+
cloudpathlib.S3Client(no_sign_request=True).set_as_default_client()
11+
812
except ImportError:
13+
AnyPath = CloudPath = Path
14+
915
_CLOUDPATHLIB_AVAILABLE = False
1016

11-
__all__ = ["Path", "cloudpathlib_is_available"]
17+
__all__ = ["PathT", "as_path", "cloudpathlib_is_available"]
1218

19+
PathT = Path | CloudPath
1320

14-
def cloudpathlib_is_available() -> bool:
15-
"""Check if cloudpathlib is available."""
16-
return _CLOUDPATHLIB_AVAILABLE
1721

22+
def as_path(path: str | PathT) -> PathT:
23+
"""Cast input to a `Path` type."""
24+
if isinstance(path, str):
25+
return AnyPath(path)
26+
return path
1827

19-
if _CLOUDPATHLIB_AVAILABLE:
20-
# Set unsigned client as default for s3:// paths
21-
from cloudpathlib import S3Client
2228

23-
client = S3Client(no_sign_request=True)
24-
client.set_as_default_client()
29+
def cloudpathlib_is_available() -> bool:
30+
"""Check if cloudpathlib is available."""
31+
return _CLOUDPATHLIB_AVAILABLE

tests/test_indexing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def test_batch_index_dataset(max_workers: int):
110110
def test_get_bids_dataset(path: str, expected_name: str):
111111
name, dataset_path = indexing._get_bids_dataset(BIDS_EXAMPLES / path)
112112
assert name == expected_name
113+
assert dataset_path is not None
113114
assert indexing._contains_bids_subject_dirs(dataset_path)
114115

115116

0 commit comments

Comments
 (0)