77
88import warnings
99from pathlib import Path
10- from typing import Any , Dict , List , Optional , Union
10+ from typing import Any
1111
1212import pandas as pd
1313import pyarrow as pa
@@ -34,7 +34,6 @@ class BIDSLayout:
3434
3535 Args:
3636 root: Path to BIDS dataset root
37- validate: Whether to validate BIDS compliance (currently ignored)
3837 derivatives: Path(s) to derivative datasets to include
3938 cache_path: Path to parquet cache file (default: {root}/.bids2table_cache.parquet)
4039 database_path: Legacy parameter (ignored, use cache_path instead)
@@ -48,11 +47,10 @@ class BIDSLayout:
4847
4948 def __init__ (
5049 self ,
51- root : Union [str , Path ],
52- validate : bool = True ,
53- derivatives : Optional [Union [str , Path , List [Union [str , Path ]]]] = None ,
54- cache_path : Optional [Path ] = None ,
55- database_path : Optional [Path ] = None ,
50+ root : str | Path ,
51+ derivatives : str | Path | list [str | Path ] | None = None ,
52+ cache_path : Path | None = None ,
53+ database_path : Path | None = None ,
5654 reset_database : bool = False ,
5755 ** kwargs ,
5856 ):
@@ -75,10 +73,8 @@ def __init__(
7573 else :
7674 self .cache_path = Path (cache_path )
7775
78- # b2t always validates, so we can contiue
79-
8076 # Load or create index
81- self ._tab = self ._load_or_create_index (). flatten ()
77+ self ._tab = self ._load_or_create_index ()
8278
8379 # Handle derivatives
8480 if derivatives is not None :
@@ -88,17 +84,45 @@ def __init__(
8884 entity_schema = self ._tab .schema
8985 self ._entity_map = {}
9086 for entity in entity_schema :
91- # Pull the name (uesd by B2T) and entity (used by pyBIDS) labels
87+ # Pull the name (used by B2T) and entity (used by pyBIDS) labels
9288 name = entity .metadata [b"name" ]
9389 dname = entity .metadata .get (b"entity" , name )
9490 # Decode them from bytestrings into real strings, and store
9591 # so that either the entity or shortname will return appropriately
9692 self ._entity_map [dname .decode ("utf-8" )] = name .decode ("utf-8" )
9793 self ._entity_map [name .decode ("utf-8" )] = name .decode ("utf-8" )
9894
95+ # Flatten extra entities after
96+ self ._flatten_extra_entities ()
97+
9998 # Convert to pandas DataFrame for querying
10099 self .df = self ._tab .to_pandas (types_mapper = pd .ArrowDtype )
101100
101+ def _flatten_extra_entities (self ) -> None :
102+ """Flatten extra entities in the table."""
103+ if "extra_entities" not in self ._tab .column_names :
104+ return
105+
106+ idx = self ._tab .schema .get_field_index ("extra_entities" )
107+ dicts = [
108+ dict (r ) if r else {} for r in self ._tab .column ("extra_entities" ).to_pylist ()
109+ ]
110+ all_keys = set ().union (* dicts )
111+
112+ self ._tab = self ._tab .remove_column (idx )
113+ if all_keys :
114+ for k in all_keys :
115+ self ._tab = self ._tab .append_column (
116+ pa .field (k , pa .string ()), pa .array ([d .get (k ) for d in dicts ])
117+ )
118+ self ._entity_map [k ] = k
119+
120+ cols = [c for c in self ._tab .column_names if c not in ("root" , "path" )] + [
121+ "root" ,
122+ "path" ,
123+ ]
124+ self ._tab = self ._tab .select (cols )
125+
102126 def _load_or_create_index (self ) -> pa .Table :
103127 """
104128 Load cached index or create new one.
@@ -138,7 +162,7 @@ def _load_or_create_index(self) -> pa.Table:
138162
139163 return tab
140164
141- def _add_derivatives (self , derivatives : Union [ str , Path , List [ Union [ str , Path ]] ]):
165+ def _add_derivatives (self , derivatives : str | Path | list [ str | Path ]):
142166 """
143167 Add derivative datasets to the index.
144168
@@ -168,7 +192,7 @@ def _add_derivatives(self, derivatives: Union[str, Path, List[Union[str, Path]]]
168192 if deriv_tabs :
169193 self ._tab = pa .concat_tables ([self ._tab ] + deriv_tabs )
170194
171- def get (self , return_type : str = "file" , ** entities ) -> List [ Union [ str , BIDSFile ] ]:
195+ def get (self , return_type : str = "file" , ** entities ) -> list [ str | BIDSFile ]:
172196 """
173197 Query files by BIDS entities.
174198
@@ -241,7 +265,7 @@ def get(self, return_type: str = "file", **entities) -> List[Union[str, BIDSFile
241265 "Valid options: 'file', 'filename', 'id', 'dir'"
242266 )
243267
244- def get_subjects (self , ** filters ) -> List [str ]:
268+ def get_subjects (self , ** filters ) -> list [str ]:
245269 """
246270 Get list of unique subject IDs.
247271
@@ -270,7 +294,7 @@ def get_subjects(self, **filters) -> List[str]:
270294
271295 return sorted (subjects .tolist ())
272296
273- def get_sessions (self , subject : Optional [ str ] = None , ** filters ) -> List [str ]:
297+ def get_sessions (self , subject : str | None = None , ** filters ) -> list [str ]:
274298 """
275299 Get list of unique session IDs.
276300
@@ -302,7 +326,7 @@ def get_sessions(self, subject: Optional[str] = None, **filters) -> List[str]:
302326 sessions = result_df ["ses" ].dropna ().unique ()
303327 return sorted (sessions .tolist ())
304328
305- def get_metadata (self , path : str ) -> Dict [str , Any ]:
329+ def get_metadata (self , path : str ) -> dict [str , Any ]:
306330 """
307331 Load metadata from JSON sidecar(s) for a given file.
308332
@@ -342,7 +366,7 @@ def get_file(self, path: str) -> BIDSFile:
342366 """
343367 return BIDSFile (path )
344368
345- def get_entities (self , ** filters ) -> Dict [str , List [str ]]:
369+ def get_entities (self , ** filters ) -> dict [str , list [str ]]:
346370 """
347371 Get dictionary of all entities and their unique values.
348372
@@ -374,10 +398,6 @@ def get_entities(self, **filters) -> Dict[str, List[str]]:
374398 # Extract unique values for each entity column
375399 entities = {}
376400 for ekey , evalue in self ._entity_map .items ():
377- # TODO: think about if we want to handle extra entities here?
378- # Disabled for now since unique and lists don't play nice
379- if evalue == "extra_entities" :
380- continue
381401 if evalue in filtered_df .columns :
382402 unique_vals = filtered_df [evalue ].dropna ().unique ().tolist ()
383403 if unique_vals : # Only include if not empty
@@ -386,7 +406,10 @@ def get_entities(self, **filters) -> Dict[str, List[str]]:
386406 return entities
387407
388408 def add_custom_entity (
389- self , name : str , values : Union [List , Dict , Any ], overwrite : bool = False
409+ self ,
410+ name : str ,
411+ values : list [Any ] | dict [str , Any ] | Any ,
412+ overwrite : bool = False ,
390413 ):
391414 """
392415 Add a custom entity column to the layout.
@@ -453,3 +476,7 @@ def __repr__(self) -> str:
453476 f"BIDSLayout(root='{ self .root } ', "
454477 f"subjects={ n_subjects } , sessions={ n_sessions } , files={ n_files } )"
455478 )
479+
480+ def to_df (self ) -> pd .DataFrame :
481+ """Explicit method to return converted dataframe, mirroring pybids."""
482+ return self .df
0 commit comments