55while leveraging bids2table's superior performance.
66"""
77
8- from pathlib import Path
9- from typing import Union , Optional , List , Dict , Any
108import warnings
9+ from pathlib import Path
10+ from typing import Any , Dict , List , Optional , Union
1111
1212import pandas as pd
1313import pyarrow as pa
@@ -54,7 +54,7 @@ def __init__(
5454 cache_path : Optional [Path ] = None ,
5555 database_path : Optional [Path ] = None ,
5656 reset_database : bool = False ,
57- ** kwargs
57+ ** kwargs ,
5858 ):
5959 # Initialize BIDSLayout with dataset indexing.
6060 self .root = Path (root ).absolute ()
@@ -66,12 +66,12 @@ def __init__(
6666 "database_path is deprecated, use cache_path instead. "
6767 "Note: cache uses parquet format, not SQLite." ,
6868 DeprecationWarning ,
69- stacklevel = 2
69+ stacklevel = 2 ,
7070 )
7171
7272 # Set cache path
7373 if cache_path is None :
74- self .cache_path = self .root / ' .bids2table_cache.parquet'
74+ self .cache_path = self .root / " .bids2table_cache.parquet"
7575 else :
7676 self .cache_path = Path (cache_path )
7777
@@ -106,7 +106,7 @@ def _load_or_create_index(self) -> pa.Table:
106106 f"Failed to load cache from { self .cache_path } : { e } . "
107107 "Re-indexing dataset." ,
108108 UserWarning ,
109- stacklevel = 3
109+ stacklevel = 3 ,
110110 )
111111
112112 # Create new index
@@ -121,7 +121,7 @@ def _load_or_create_index(self) -> pa.Table:
121121 warnings .warn (
122122 f"Failed to save cache to { self .cache_path } : { e } " ,
123123 UserWarning ,
124- stacklevel = 3
124+ stacklevel = 3 ,
125125 )
126126
127127 return tab
@@ -145,7 +145,7 @@ def _add_derivatives(self, derivatives: Union[str, Path, List[Union[str, Path]]]
145145 warnings .warn (
146146 f"Derivative path does not exist: { deriv_path } " ,
147147 UserWarning ,
148- stacklevel = 3
148+ stacklevel = 3 ,
149149 )
150150 continue
151151
@@ -156,11 +156,7 @@ def _add_derivatives(self, derivatives: Union[str, Path, List[Union[str, Path]]]
156156 if deriv_tabs :
157157 self ._tab = pa .concat_tables ([self ._tab ] + deriv_tabs )
158158
159- def get (
160- self ,
161- return_type : str = 'file' ,
162- ** entities
163- ) -> List [Union [str , BIDSFile ]]:
159+ def get (self , return_type : str = "file" , ** entities ) -> List [Union [str , BIDSFile ]]:
164160 """
165161 Query files by BIDS entities.
166162
@@ -195,7 +191,7 @@ def get(
195191 warnings .warn (
196192 f"Unknown entity '{ key } ' (not in dataset columns)" ,
197193 UserWarning ,
198- stacklevel = 2
194+ stacklevel = 2 ,
199195 )
200196 continue
201197
@@ -217,14 +213,14 @@ def get(
217213 result_df = result_df [result_df [key ] == value ]
218214
219215 # Return based on return_type
220- if return_type == ' filename' :
221- return result_df [' path' ].tolist ()
222- elif return_type == ' file' :
223- return [BIDSFile (p ) for p in result_df [' path' ].tolist ()]
224- elif return_type == 'id' :
216+ if return_type == " filename" :
217+ return result_df [" path" ].tolist ()
218+ elif return_type == " file" :
219+ return [BIDSFile (p ) for p in result_df [" path" ].tolist ()]
220+ elif return_type == "id" :
225221 return result_df .index .tolist ()
226- elif return_type == ' dir' :
227- dirs = result_df [' path' ].apply (lambda p : str (Path (p ).parent ))
222+ elif return_type == " dir" :
223+ dirs = result_df [" path" ].apply (lambda p : str (Path (p ).parent ))
228224 return sorted (dirs .unique ().tolist ())
229225 else :
230226 raise ValueError (
@@ -244,11 +240,11 @@ def _map_entity_key(self, key: str) -> str:
244240 """
245241 # Common mappings
246242 mapping = {
247- ' subject' : ' sub' ,
248- ' session' : ' ses' ,
249- ' extension' : ' ext' ,
250- ' datatype' : ' datatype' ,
251- ' suffix' : ' suffix' ,
243+ " subject" : " sub" ,
244+ " session" : " ses" ,
245+ " extension" : " ext" ,
246+ " datatype" : " datatype" ,
247+ " suffix" : " suffix" ,
252248 }
253249 return mapping .get (key , key )
254250
@@ -275,9 +271,9 @@ def get_subjects(self, **filters) -> List[str]:
275271 key = self ._map_entity_key (key )
276272 if key in filtered_df .columns :
277273 filtered_df = filtered_df [filtered_df [key ] == value ]
278- subjects = filtered_df [' sub' ].dropna ().unique ()
274+ subjects = filtered_df [" sub" ].dropna ().unique ()
279275 else :
280- subjects = self .df [' sub' ].dropna ().unique ()
276+ subjects = self .df [" sub" ].dropna ().unique ()
281277
282278 return sorted (subjects .tolist ())
283279
@@ -302,15 +298,15 @@ def get_sessions(self, subject: Optional[str] = None, **filters) -> List[str]:
302298
303299 # Filter by subject if provided
304300 if subject is not None :
305- result_df = result_df [result_df [' sub' ] == subject ]
301+ result_df = result_df [result_df [" sub" ] == subject ]
306302
307303 # Apply additional filters
308304 for key , value in filters .items ():
309305 key = self ._map_entity_key (key )
310306 if key in result_df .columns :
311307 result_df = result_df [result_df [key ] == value ]
312308
313- sessions = result_df [' ses' ].dropna ().unique ()
309+ sessions = result_df [" ses" ].dropna ().unique ()
314310 return sorted (sessions .tolist ())
315311
316312 def get_metadata (self , path : str ) -> Dict [str , Any ]:
@@ -384,9 +380,30 @@ def get_entities(self, **filters) -> Dict[str, List[str]]:
384380
385381 # Extract unique values for each entity column
386382 # Standard BIDS entities that might be present
387- entity_cols = ['sub' , 'ses' , 'task' , 'acq' , 'ce' , 'rec' , 'dir' , 'run' ,
388- 'mod' , 'echo' , 'flip' , 'inv' , 'mt' , 'part' , 'recording' ,
389- 'suffix' , 'space' , 'res' , 'den' , 'label' , 'desc' , 'datatype' ]
383+ entity_cols = [
384+ "sub" ,
385+ "ses" ,
386+ "task" ,
387+ "acq" ,
388+ "ce" ,
389+ "rec" ,
390+ "dir" ,
391+ "run" ,
392+ "mod" ,
393+ "echo" ,
394+ "flip" ,
395+ "inv" ,
396+ "mt" ,
397+ "part" ,
398+ "recording" ,
399+ "suffix" ,
400+ "space" ,
401+ "res" ,
402+ "den" ,
403+ "label" ,
404+ "desc" ,
405+ "datatype" ,
406+ ]
390407
391408 entities = {}
392409 for col in entity_cols :
@@ -398,10 +415,7 @@ def get_entities(self, **filters) -> Dict[str, List[str]]:
398415 return entities
399416
400417 def add_custom_entity (
401- self ,
402- name : str ,
403- values : Union [List , Dict , Any ],
404- overwrite : bool = False
418+ self , name : str , values : Union [List , Dict , Any ], overwrite : bool = False
405419 ):
406420 """
407421 Add a custom entity column to the layout.
@@ -446,20 +460,20 @@ def add_custom_entity(
446460 elif isinstance (values , dict ):
447461 # Dict: map from key (assume subject or file path)
448462 # Try to detect if keys are subjects or paths
449- if values and list (values .keys ())[0 ] in self .df [' sub' ].values :
463+ if values and list (values .keys ())[0 ] in self .df [" sub" ].values :
450464 # Keys are subjects
451- self .df [name ] = self .df [' sub' ].map (values )
465+ self .df [name ] = self .df [" sub" ].map (values )
452466 else :
453467 # Keys are file paths or other
454- self .df [name ] = self .df [' path' ].map (values )
468+ self .df [name ] = self .df [" path" ].map (values )
455469 else :
456470 # Scalar or array-like: assign directly
457471 self .df [name ] = values
458472
459473 def __repr__ (self ) -> str :
460474 """String representation of layout."""
461- n_subjects = self .df [' sub' ].nunique ()
462- n_sessions = self .df [' ses' ].nunique ()
475+ n_subjects = self .df [" sub" ].nunique ()
476+ n_sessions = self .df [" ses" ].nunique ()
463477 n_files = len (self .df )
464478
465479 return (
0 commit comments