Skip to content

Commit b9e4f9f

Browse files
committed
add bids sort to bids.models from bids.utils
1 parent 9a54461 commit b9e4f9f

File tree

3 files changed

+60
-4
lines changed

3 files changed

+60
-4
lines changed

src/bids/layout/models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
except ImportError: # sqlalchemy < 1.4
2626
from sqlalchemy.ext.declarative import declarative_base
2727

28-
from ..utils import listify
28+
from ..utils import listify, bids_sort
2929
from .writing import build_path, write_to_file
3030
from ..config import get_option
3131
from .utils import BIDSMetadata, PaddedInt
@@ -541,8 +541,8 @@ def get_entities(self, metadata=False, values='tags'):
541541

542542
results = query.all()
543543
if values.startswith('obj'):
544-
return {t.entity_name: t.entity for t in results}
545-
return {t.entity_name: t.value for t in results}
544+
return bids_sort({t.entity_name: t.entity for t in results})
545+
return bids_sort({t.entity_name: t.value for t in results})
546546

547547
def copy(self, path_patterns, symbolic_link=False, root=None,
548548
conflicts='fail'):

src/bids/layout/tests/test_layout.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from bids.layout.models import Config
1616
from bids.layout.index import BIDSLayoutIndexer, _check_path_matches_patterns, _regexfy
1717
from bids.layout.utils import PaddedInt
18-
from bids.utils import natural_sort
18+
from bids.utils import natural_sort, bids_sort
1919

2020
from bids.exceptions import (
2121
BIDSChildDatasetError,
@@ -1268,3 +1268,34 @@ def test_get_return_type_dir_with_legacy_config_no_template():
12681268
# because task entity has no directory template
12691269
with pytest.raises(ValueError, match='Return type set to directory'):
12701270
layout.get(target='task', return_type='dir')
1271+
1272+
def test_bids_sort(layout_7t_trt):
1273+
files = layout_7t_trt.get(task='rest', extension='.nii.gz')
1274+
assert len(files) > 0
1275+
from bidsschematools.schema import load_schema
1276+
import random
1277+
import copy
1278+
# we apply bids_sort at the model level, but just to be extra sure
1279+
# we sort here one more time
1280+
first_file_ents_sorted = bids_sort(files[0].get_entities())
1281+
schema_order = list(load_schema().rules.entities) + ['suffix', 'extension', 'datatype']
1282+
1283+
# collect keys from file entity then unsort them
1284+
sorted_keys = list(first_file_ents_sorted.keys())
1285+
unsorted_keys = copy.copy(sorted_keys)
1286+
while unsorted_keys == sorted_keys:
1287+
random.shuffle(unsorted_keys)
1288+
1289+
first_file_ents_unsorted = {}
1290+
for key in unsorted_keys:
1291+
first_file_ents_unsorted[key] = first_file_ents_sorted[key]
1292+
1293+
# check order of sorted entities against schema
1294+
for i, entity in enumerate(sorted_keys):
1295+
for j in sorted_keys[i + 1:]:
1296+
if entity in schema_order and j in schema_order:
1297+
assert schema_order.index(entity) < schema_order.index(j)
1298+
1299+
assert list(first_file_ents_unsorted.keys()) != sorted_keys
1300+
assert list(bids_sort(first_file_ents_unsorted).keys()) == sorted_keys
1301+

src/bids/utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,3 +187,28 @@ def validate_multiple(val, retval=None):
187187
if len(val) == 1:
188188
return val[0]
189189
return val
190+
191+
def bids_sort(unsorted: dict):
192+
f"""
193+
Sorts filename entity dictionaries according to their order as defined in
194+
schema.rules.entities as well as suffix, extension. Lastly, appends datatype
195+
to the end of the sort to accomodate pybids datastructures.
196+
197+
Parameters
198+
----------
199+
unsorted: dict
200+
A dictionary containing bids file entities and their values.
201+
202+
Returns
203+
-------
204+
sorted_bids: dict
205+
206+
"""
207+
from bidsschematools.schema import load_schema
208+
from bidsschematools.types.namespace import Namespace
209+
_schema = load_schema()
210+
entity_order = list(_schema.rules.entities) + ['suffix', 'extension', 'datatype']
211+
212+
sorted_bids = {k: unsorted[k] for k in sorted(unsorted, key=lambda k: entity_order.index(k) if k in entity_order else len(entity_order))}
213+
214+
return sorted_bids

0 commit comments

Comments
 (0)