@@ -47,15 +47,15 @@ def get_ccf_volume(left_hemisphere=True, right_hemisphere=False) -> npt.NDArray:
4747 f"{ path .suffix } files not supported, must be one of { supported } "
4848 )
4949 if path .protocol : # cloud path - download it
50- tempdir = tempfile .mkdtemp ()
51- temp_path = upath .UPath (tempdir ) / path .name
52- logger .info (f"Downloading CCF volume to temporary file { temp_path .as_posix ()} " )
53- temp_path .write_bytes (path .read_bytes ())
54- path = temp_path
55- logger . info ( f"Using CCF volume from { path . as_posix () } " )
56-
57- logger .info (f"Loading CCF volume from { path .as_posix ()} " )
58- volume , _ = nrrd .read (path , index_order = "C" ) # ml, dv, ap
50+ with tempfile .TemporaryDirectory () as tempdir :
51+ temp_path = upath .UPath (tempdir ) / path .name
52+ logger .warning (f"Downloading CCF volume to temporary file { temp_path .as_posix ()} " )
53+ temp_path .write_bytes (path .read_bytes ())
54+ path = temp_path
55+ volume , _ = nrrd . read ( path , index_order = "C" ) # ml, dv, ap
56+ else :
57+ logger .info (f"Using CCF volume from { path .as_posix ()} " )
58+ volume , _ = nrrd .read (path , index_order = "C" ) # ml, dv, ap
5959 ml_dim = AXIS_TO_DIM ["ml" ]
6060 dims = [
6161 (slice (0 , volume .shape [ml_dim ] // 2 ) if dim == ml_dim else slice (None ))
@@ -77,6 +77,7 @@ def get_ccf_volume(left_hemisphere=True, right_hemisphere=False) -> npt.NDArray:
7777 return volume
7878
7979
80+ @functools .cache
8081def get_midline_ccf_ml () -> float :
8182 return (
8283 RESOLUTION_UM
@@ -88,15 +89,21 @@ def get_midline_ccf_ml() -> float:
8889 )
8990
9091
92+ def ccf_to_volume_index (coord : float ) -> int :
93+ return round (coord / RESOLUTION_UM )
94+
95+
9196@functools .cache
9297def get_ccf_structure_tree_df () -> pl .DataFrame :
98+ t0 = time .time ()
9399 path = "https://raw.githubusercontent.com/cortex-lab/allenCCF/master/structure_tree_safe_2017.csv"
94100 logging .info (f"Using CCF structure tree from { path } " )
95- return (
96- pl .read_csv (path )
97- .with_columns (
101+ df = pl .read_csv (path )
102+ len_0 = len (df )
103+ df = (
104+ df .with_columns (
98105 color_hex_int = pl .col ("color_hex_triplet" ).str .to_integer (base = 16 ),
99- color_hex_str = pl .lit ("0x " ) + pl .col ("color_hex_triplet" ),
106+ color_hex_str = pl .lit ("# " ) + pl .col ("color_hex_triplet" ),
100107 )
101108 .with_columns (
102109 r = pl .col ("color_hex_triplet" )
@@ -116,13 +123,52 @@ def get_ccf_structure_tree_df() -> pl.DataFrame:
116123 color_rgb = pl .concat_list ("r" , "g" , "b" ),
117124 )
118125 .drop ("r" , "g" , "b" )
126+ .with_columns (
127+ parent_ids = pl .col ("structure_id_path" )
128+ .str .split ("/" )
129+ .cast (pl .List (int ))
130+ .list .drop_nulls ()
131+ .list .slice (offset = 0 , length = pl .col ("depth" )),
132+ )
133+ )
134+ df = df .join (
135+ other = (
136+ df .explode (pl .col ("parent_ids" ))
137+ .group_by (pl .col ("parent_ids" ).alias ("id" ), maintain_order = True )
138+ .agg (pl .col ("id" ).alias ("child_ids" ))
139+ ),
140+ on = "id" ,
141+ how = "left" ,
142+ ).with_columns (
143+ pl .col ("child_ids" ).fill_null ([]),
144+ is_deepest = ~ pl .col ("id" ).is_in (df ["parent_structure_id" ]),
145+ )
146+ assert not any (df .filter (pl .col ("is_deepest" ))["child_ids" ].to_list ())
147+ # add list of deepest children for each area
148+ df = df .join (
149+ other = (
150+ df .explode ("child_ids" )
151+ .filter (pl .col ("child_ids" ).is_in (df .filter (pl .col ("is_deepest" ))["id" ]))
152+ .group_by ("id" , maintain_order = True )
153+ .agg (
154+ pl .all ().exclude ("child_ids" ).first (),
155+ pl .col ("child_ids" ).alias ("deepest_child_ids" ),
156+ )
157+ ),
158+ on = "id" ,
159+ how = "left" ,
160+ ).with_columns (
161+ pl .col ("deepest_child_ids" ).fill_null ([]),
119162 )
163+ assert len (df ) == len_0
164+ logger .info (f"CCF structure tree loaded in { time .time () - t0 :.2f} s" )
165+ return df
120166
121167
122168def get_ccf_structure_info (ccf_acronym_or_id : str | int ) -> dict :
123169 """
124170 >>> get_ccf_structure_info('MOs')
125- {'id': 993, 'atlas_id': 831, 'name': 'Secondary motor area', 'acronym': 'MOs', 'st_level': None, 'ontology_id': 1, 'hemisphere_id': 3, 'weight': 8690, 'parent_structure_id': 500, 'depth': 7, 'graph_id': 1, 'graph_order': 24, 'structure_id_path': '/997/8/567/688/695/315/500/993/', 'color_hex_triplet': '1F9D5A', 'neuro_name_structure_id': None, 'neuro_name_structure_id_path': None, 'failed': 'f', 'sphinx_id': 25, 'structure_name_facet': 1043755260, 'failed_facet': 734881840, 'safe_name': 'Secondary motor area', 'color_hex_int': 2071898, 'color_hex_str': '0x1F9D5A ', 'color_rgb': [0.12156862745098039, 0.615686274509804, 0.3529411764705882]}
171+ {'id': 993, 'atlas_id': 831, 'name': 'Secondary motor area', 'acronym': 'MOs', 'st_level': None, 'ontology_id': 1, 'hemisphere_id': 3, 'weight': 8690, 'parent_structure_id': 500, 'depth': 7, 'graph_id': 1, 'graph_order': 24, 'structure_id_path': '/997/8/567/688/695/315/500/993/', 'color_hex_triplet': '1F9D5A', 'neuro_name_structure_id': None, 'neuro_name_structure_id_path': None, 'failed': 'f', 'sphinx_id': 25, 'structure_name_facet': 1043755260, 'failed_facet': 734881840, 'safe_name': 'Secondary motor area', 'color_hex_int': 2071898, 'color_hex_str': '#1F9D5A ', 'color_rgb': [0.12156862745098039, 0.615686274509804, 0.3529411764705882]}
126172 """
127173 if not isinstance (ccf_acronym_or_id , int ):
128174 ccf_id : int = convert_ccf_acronyms_or_ids (ccf_acronym_or_id )
@@ -140,7 +186,53 @@ def get_ccf_structure_info(ccf_acronym_or_id: str | int) -> dict:
140186 return results [0 ].limit (1 ).to_dicts ()[0 ]
141187
142188
189+ def get_all_parents (ccf_acronym_or_id : str | int ) -> list [str ]:
190+ """
191+ >>> get_all_parents('MOs2/3')
192+ ['root', 'grey', 'CH', 'CTX', 'CTXpl', 'Isocortex', 'MO', 'MOs']
193+ """
194+ info = get_ccf_structure_info (ccf_acronym_or_id )
195+ parent_ids = [int (id_ ) for id_ in info ["structure_id_path" ].split ("/" )[1 :- 2 ]]
196+ parent_acronyms = (
197+ get_ccf_structure_tree_df ()
198+ .filter (
199+ pl .col ("id" ).is_in (parent_ids ),
200+ )["acronym" ]
201+ .to_list ()
202+ )
203+ assert info ["id" ] not in parent_acronyms
204+ return parent_acronyms
205+
206+
207+ def get_all_children (ccf_acronym_or_id : str | int ) -> list [str ]:
208+ """
209+ >>> get_all_children('MOs')
210+ ['MOs1', 'MOs2/3', 'MOs5', 'MOs6a', 'MOs6b']
211+ """
212+ if not isinstance (ccf_acronym_or_id , int ):
213+ ccf_id : int = convert_ccf_acronyms_or_ids (ccf_acronym_or_id )
214+ else :
215+ ccf_id = ccf_acronym_or_id
216+ children = (
217+ get_ccf_structure_tree_df ()
218+ .filter (
219+ pl .col ("structure_id_path" ).str .contains (f"/{ ccf_id } /" ),
220+ ~ pl .col ("structure_id_path" ).str .ends_with (f"/{ ccf_id } /" ),
221+ )["acronym" ]
222+ .to_list ()
223+ )
224+ assert str (ccf_acronym_or_id ) not in children
225+ return children
226+
227+
143228def get_deepest_children (ccf_acronym_or_id : str | int ) -> list [str ]:
229+ """
230+ >>> get_deepest_children('MOs')
231+ ['MOs1', 'MOs2/3', 'MOs5', 'MOs6a', 'MOs6b']
232+ >>> get_deepest_children('MOs1')
233+ ['MOs1']
234+ >>> assert 'VISpor' not in get_deepest_children('VIS')
235+ """
144236 if not isinstance (ccf_acronym_or_id , int ):
145237 try :
146238 ccf_id : int = convert_ccf_acronyms_or_ids (ccf_acronym_or_id )
@@ -160,6 +252,38 @@ def get_deepest_children(ccf_acronym_or_id: str | int) -> list[str]:
160252 ].to_list ()
161253
162254
255+ def group_child_labels_in_slice (
256+ slice_array : npt .NDArray [np .uint32 ],
257+ acronyms_or_ids : Iterable [str | int ],
258+ ) -> npt .NDArray [np .uint32 ]:
259+ """
260+ For a given slice and CCF areas (acronyms or IDs), return a new slice with the labels grouped,
261+ so that all areas in the same group have the same label. For example, passing ["MOS"] would
262+ change the label index of all child areas in the MOs tree to have the MOs value.
263+
264+ >>> mos_id = get_ccf_structure_info('MOs')['id']
265+ >>> slice_array = get_ccf_volume()[:, :, 100]
266+ >>> assert mos_id not in slice_array
267+ >>> new_slice = group_child_labels_in_slice(slice_array, ['MOs'])
268+ >>> assert mos_id in new_slice
269+ """
270+ slice_array = slice_array .copy ()
271+ for ccf_acronym_or_id in acronyms_or_ids :
272+ if not isinstance (ccf_acronym_or_id , int ):
273+ ccf_id : int = convert_ccf_acronyms_or_ids (ccf_acronym_or_id )
274+ else :
275+ ccf_id = ccf_acronym_or_id
276+ children = get_ccf_immediate_children_ids (ccf_id )
277+ if ccf_id in children :
278+ children .remove (ccf_id )
279+ logger .debug (
280+ f"Grouping { children } under { convert_ccf_acronyms_or_ids (ccf_acronym_or_id )} "
281+ )
282+ for child in children :
283+ slice_array [slice_array == child ] = ccf_id
284+ return slice_array
285+
286+
163287def get_ccf_immediate_children_ids (ccf_acronym_or_id : str | int ) -> set [int ]:
164288 """
165289 >>> ids = get_ccf_immediate_children_ids('MOs')
@@ -476,13 +600,58 @@ def get_scatter_image(
476600 return image
477601
478602
603+ def project_first_nonzero_labels (
604+ volume : npt .NDArray ,
605+ axis : int = AXIS_TO_DIM ["dv" ],
606+ ) -> npt .NDArray :
607+ """
608+ Project the first non-zero label encountered from one side of the 3D volume.
609+
610+ Parameters:
611+ volume (np.ndarray): 3D array containing non-zero area labels.
612+ axis (int): Axis along which to project (0, 1, or 2).
613+
614+ Returns:
615+ np.ndarray: 2D array with the projected labels.
616+ """
617+ if volume .ndim != 3 :
618+ raise ValueError (f"Volume must be 3D: { volume .shape = } " )
619+ dims = tuple (range (volume .ndim ))
620+ if axis not in dims :
621+ raise ValueError ("Axis must be 0, 1, or 2." )
622+ plane_dims = [d for d in dims if d != axis ]
623+ mask = volume > 0
624+ idx_along_projection_axis = np .argmax (mask , axis = axis )
625+ idx_in_plane_axes = [np .arange (volume .shape [d ]) for d in plane_dims ]
626+ if axis == 0 :
627+ projection = volume [
628+ idx_along_projection_axis ,
629+ idx_in_plane_axes [0 ][:, None ],
630+ idx_in_plane_axes [1 ],
631+ ]
632+ elif axis == 1 :
633+ projection = volume [
634+ idx_in_plane_axes [0 ][:, None ],
635+ idx_along_projection_axis ,
636+ idx_in_plane_axes [1 ],
637+ ]
638+ elif axis == 2 :
639+ projection = volume [
640+ idx_in_plane_axes [0 ][:, None ],
641+ idx_in_plane_axes [1 ],
642+ idx_along_projection_axis ,
643+ ]
644+ projection = projection .astype (float )
645+ projection [projection == 0 ] = np .nan
646+ return projection
647+
648+
479649if __name__ == "__main__" :
480650 logging .basicConfig (
481- level = logging .DEBUG ,
651+ level = logging .WARNING ,
482652 format = "%(asctime)s | %(name)s | %(levelname)s | %(funcName)s | %(message)s" ,
483653 datefmt = "%d-%b-%y %H:%M:%S" ,
484654 )
485- logging .getLogger ().setLevel (logging .DEBUG )
486655
487656 import doctest
488657
0 commit comments