Skip to content

Commit 716f521

Browse files
authored
Merge pull request #79 from childmindresearch/b2t/pybids-compat
Update pybids compatibility flattening
2 parents efce21c + af262b3 commit 716f521

2 files changed

Lines changed: 52 additions & 25 deletions

File tree

bids2table/pybids/_bidsfile.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
BIDS entities from file paths.
66
"""
77

8-
from typing import Any, Dict, Optional
8+
from typing import Any
99

1010
from .._entities import parse_bids_entities
1111

@@ -33,9 +33,9 @@ def __init__(self, path: str):
3333
path: Path to BIDS file (absolute or relative)
3434
"""
3535
self.path = str(path)
36-
self._entities: Optional[Dict[str, Any]] = None
36+
self._entities: dict[str, Any] | None = None
3737

38-
def get_entities(self) -> Dict[str, Any]:
38+
def get_entities(self) -> dict[str, Any]:
3939
"""
4040
Parse and return BIDS entities from filename.
4141

bids2table/pybids/_layout.py

Lines changed: 49 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import warnings
99
from pathlib import Path
10-
from typing import Any, Dict, List, Optional, Union
10+
from typing import Any
1111

1212
import pandas as pd
1313
import pyarrow as pa
@@ -34,7 +34,6 @@ class BIDSLayout:
3434
3535
Args:
3636
root: Path to BIDS dataset root
37-
validate: Whether to validate BIDS compliance (currently ignored)
3837
derivatives: Path(s) to derivative datasets to include
3938
cache_path: Path to parquet cache file (default: {root}/.bids2table_cache.parquet)
4039
database_path: Legacy parameter (ignored, use cache_path instead)
@@ -48,11 +47,10 @@ class BIDSLayout:
4847

4948
def __init__(
5049
self,
51-
root: Union[str, Path],
52-
validate: bool = True,
53-
derivatives: Optional[Union[str, Path, List[Union[str, Path]]]] = None,
54-
cache_path: Optional[Path] = None,
55-
database_path: Optional[Path] = None,
50+
root: str | Path,
51+
derivatives: str | Path | list[str | Path] | None = None,
52+
cache_path: Path | None = None,
53+
database_path: Path | None = None,
5654
reset_database: bool = False,
5755
**kwargs,
5856
):
@@ -75,10 +73,8 @@ def __init__(
7573
else:
7674
self.cache_path = Path(cache_path)
7775

78-
# b2t always validates, so we can contiue
79-
8076
# Load or create index
81-
self._tab = self._load_or_create_index().flatten()
77+
self._tab = self._load_or_create_index()
8278

8379
# Handle derivatives
8480
if derivatives is not None:
@@ -88,17 +84,45 @@ def __init__(
8884
entity_schema = self._tab.schema
8985
self._entity_map = {}
9086
for entity in entity_schema:
91-
# Pull the name (uesd by B2T) and entity (used by pyBIDS) labels
87+
# Pull the name (used by B2T) and entity (used by pyBIDS) labels
9288
name = entity.metadata[b"name"]
9389
dname = entity.metadata.get(b"entity", name)
9490
# Decode them from bytestrings into real strings, and store
9591
# so that either the entity or shortname will return appropriately
9692
self._entity_map[dname.decode("utf-8")] = name.decode("utf-8")
9793
self._entity_map[name.decode("utf-8")] = name.decode("utf-8")
9894

95+
# Flatten extra entities after
96+
self._flatten_extra_entities()
97+
9998
# Convert to pandas DataFrame for querying
10099
self.df = self._tab.to_pandas(types_mapper=pd.ArrowDtype)
101100

101+
def _flatten_extra_entities(self) -> None:
102+
"""Flatten extra entities in the table."""
103+
if "extra_entities" not in self._tab.column_names:
104+
return
105+
106+
idx = self._tab.schema.get_field_index("extra_entities")
107+
dicts = [
108+
dict(r) if r else {} for r in self._tab.column("extra_entities").to_pylist()
109+
]
110+
all_keys = set().union(*dicts)
111+
112+
self._tab = self._tab.remove_column(idx)
113+
if all_keys:
114+
for k in all_keys:
115+
self._tab = self._tab.append_column(
116+
pa.field(k, pa.string()), pa.array([d.get(k) for d in dicts])
117+
)
118+
self._entity_map[k] = k
119+
120+
cols = [c for c in self._tab.column_names if c not in ("root", "path")] + [
121+
"root",
122+
"path",
123+
]
124+
self._tab = self._tab.select(cols)
125+
102126
def _load_or_create_index(self) -> pa.Table:
103127
"""
104128
Load cached index or create new one.
@@ -138,7 +162,7 @@ def _load_or_create_index(self) -> pa.Table:
138162

139163
return tab
140164

141-
def _add_derivatives(self, derivatives: Union[str, Path, List[Union[str, Path]]]):
165+
def _add_derivatives(self, derivatives: str | Path | list[str | Path]):
142166
"""
143167
Add derivative datasets to the index.
144168
@@ -168,7 +192,7 @@ def _add_derivatives(self, derivatives: Union[str, Path, List[Union[str, Path]]]
168192
if deriv_tabs:
169193
self._tab = pa.concat_tables([self._tab] + deriv_tabs)
170194

171-
def get(self, return_type: str = "file", **entities) -> List[Union[str, BIDSFile]]:
195+
def get(self, return_type: str = "file", **entities) -> list[str | BIDSFile]:
172196
"""
173197
Query files by BIDS entities.
174198
@@ -241,7 +265,7 @@ def get(self, return_type: str = "file", **entities) -> List[Union[str, BIDSFile
241265
"Valid options: 'file', 'filename', 'id', 'dir'"
242266
)
243267

244-
def get_subjects(self, **filters) -> List[str]:
268+
def get_subjects(self, **filters) -> list[str]:
245269
"""
246270
Get list of unique subject IDs.
247271
@@ -270,7 +294,7 @@ def get_subjects(self, **filters) -> List[str]:
270294

271295
return sorted(subjects.tolist())
272296

273-
def get_sessions(self, subject: Optional[str] = None, **filters) -> List[str]:
297+
def get_sessions(self, subject: str | None = None, **filters) -> list[str]:
274298
"""
275299
Get list of unique session IDs.
276300
@@ -302,7 +326,7 @@ def get_sessions(self, subject: Optional[str] = None, **filters) -> List[str]:
302326
sessions = result_df["ses"].dropna().unique()
303327
return sorted(sessions.tolist())
304328

305-
def get_metadata(self, path: str) -> Dict[str, Any]:
329+
def get_metadata(self, path: str) -> dict[str, Any]:
306330
"""
307331
Load metadata from JSON sidecar(s) for a given file.
308332
@@ -342,7 +366,7 @@ def get_file(self, path: str) -> BIDSFile:
342366
"""
343367
return BIDSFile(path)
344368

345-
def get_entities(self, **filters) -> Dict[str, List[str]]:
369+
def get_entities(self, **filters) -> dict[str, list[str]]:
346370
"""
347371
Get dictionary of all entities and their unique values.
348372
@@ -374,10 +398,6 @@ def get_entities(self, **filters) -> Dict[str, List[str]]:
374398
# Extract unique values for each entity column
375399
entities = {}
376400
for ekey, evalue in self._entity_map.items():
377-
# TODO: think about if we want to handle extra entities here?
378-
# Disabled for now since unique and lists don't play nice
379-
if evalue == "extra_entities":
380-
continue
381401
if evalue in filtered_df.columns:
382402
unique_vals = filtered_df[evalue].dropna().unique().tolist()
383403
if unique_vals: # Only include if not empty
@@ -386,7 +406,10 @@ def get_entities(self, **filters) -> Dict[str, List[str]]:
386406
return entities
387407

388408
def add_custom_entity(
389-
self, name: str, values: Union[List, Dict, Any], overwrite: bool = False
409+
self,
410+
name: str,
411+
values: list[Any] | dict[str, Any] | Any,
412+
overwrite: bool = False,
390413
):
391414
"""
392415
Add a custom entity column to the layout.
@@ -453,3 +476,7 @@ def __repr__(self) -> str:
453476
f"BIDSLayout(root='{self.root}', "
454477
f"subjects={n_subjects}, sessions={n_sessions}, files={n_files})"
455478
)
479+
480+
def to_df(self) -> pd.DataFrame:
481+
"""Explicit method to return converted dataframe, mirroring pybids."""
482+
return self.df

0 commit comments

Comments
 (0)