Skip to content

Commit 1d9e3a8

Browse files
committed
Update brain heatmap plotting function
1 parent 990faee commit 1d9e3a8

File tree

3 files changed

+1914
-762
lines changed

3 files changed

+1914
-762
lines changed

pdm.lock

Lines changed: 16 additions & 16 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/dynamic_routing_analysis/ccf_utils.py

Lines changed: 185 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,15 @@ def get_ccf_volume(left_hemisphere=True, right_hemisphere=False) -> npt.NDArray:
4747
f"{path.suffix} files not supported, must be one of {supported}"
4848
)
4949
if path.protocol: # cloud path - download it
50-
tempdir = tempfile.mkdtemp()
51-
temp_path = upath.UPath(tempdir) / path.name
52-
logger.info(f"Downloading CCF volume to temporary file {temp_path.as_posix()}")
53-
temp_path.write_bytes(path.read_bytes())
54-
path = temp_path
55-
logger.info(f"Using CCF volume from {path.as_posix()}")
56-
57-
logger.info(f"Loading CCF volume from {path.as_posix()}")
58-
volume, _ = nrrd.read(path, index_order="C") # ml, dv, ap
50+
with tempfile.TemporaryDirectory() as tempdir:
51+
temp_path = upath.UPath(tempdir) / path.name
52+
logger.warning(f"Downloading CCF volume to temporary file {temp_path.as_posix()}")
53+
temp_path.write_bytes(path.read_bytes())
54+
path = temp_path
55+
volume, _ = nrrd.read(path, index_order="C") # ml, dv, ap
56+
else:
57+
logger.info(f"Using CCF volume from {path.as_posix()}")
58+
volume, _ = nrrd.read(path, index_order="C") # ml, dv, ap
5959
ml_dim = AXIS_TO_DIM["ml"]
6060
dims = [
6161
(slice(0, volume.shape[ml_dim] // 2) if dim == ml_dim else slice(None))
@@ -77,6 +77,7 @@ def get_ccf_volume(left_hemisphere=True, right_hemisphere=False) -> npt.NDArray:
7777
return volume
7878

7979

80+
@functools.cache
8081
def get_midline_ccf_ml() -> float:
8182
return (
8283
RESOLUTION_UM
@@ -88,15 +89,21 @@ def get_midline_ccf_ml() -> float:
8889
)
8990

9091

92+
def ccf_to_volume_index(coord: float) -> int:
93+
return round(coord / RESOLUTION_UM)
94+
95+
9196
@functools.cache
9297
def get_ccf_structure_tree_df() -> pl.DataFrame:
98+
t0 = time.time()
9399
path = "https://raw.githubusercontent.com/cortex-lab/allenCCF/master/structure_tree_safe_2017.csv"
94100
logging.info(f"Using CCF structure tree from {path}")
95-
return (
96-
pl.read_csv(path)
97-
.with_columns(
101+
df = pl.read_csv(path)
102+
len_0 = len(df)
103+
df = (
104+
df.with_columns(
98105
color_hex_int=pl.col("color_hex_triplet").str.to_integer(base=16),
99-
color_hex_str=pl.lit("0x") + pl.col("color_hex_triplet"),
106+
color_hex_str=pl.lit("#") + pl.col("color_hex_triplet"),
100107
)
101108
.with_columns(
102109
r=pl.col("color_hex_triplet")
@@ -116,13 +123,52 @@ def get_ccf_structure_tree_df() -> pl.DataFrame:
116123
color_rgb=pl.concat_list("r", "g", "b"),
117124
)
118125
.drop("r", "g", "b")
126+
.with_columns(
127+
parent_ids=pl.col("structure_id_path")
128+
.str.split("/")
129+
.cast(pl.List(int))
130+
.list.drop_nulls()
131+
.list.slice(offset=0, length=pl.col("depth")),
132+
)
133+
)
134+
df = df.join(
135+
other=(
136+
df.explode(pl.col("parent_ids"))
137+
.group_by(pl.col("parent_ids").alias("id"), maintain_order=True)
138+
.agg(pl.col("id").alias("child_ids"))
139+
),
140+
on="id",
141+
how="left",
142+
).with_columns(
143+
pl.col("child_ids").fill_null([]),
144+
is_deepest=~pl.col("id").is_in(df["parent_structure_id"]),
145+
)
146+
assert not any(df.filter(pl.col("is_deepest"))["child_ids"].to_list())
147+
# add list of deepest children for each area
148+
df = df.join(
149+
other=(
150+
df.explode("child_ids")
151+
.filter(pl.col("child_ids").is_in(df.filter(pl.col("is_deepest"))["id"]))
152+
.group_by("id", maintain_order=True)
153+
.agg(
154+
pl.all().exclude("child_ids").first(),
155+
pl.col("child_ids").alias("deepest_child_ids"),
156+
)
157+
),
158+
on="id",
159+
how="left",
160+
).with_columns(
161+
pl.col("deepest_child_ids").fill_null([]),
119162
)
163+
assert len(df) == len_0
164+
logger.info(f"CCF structure tree loaded in {time.time() - t0:.2f}s")
165+
return df
120166

121167

122168
def get_ccf_structure_info(ccf_acronym_or_id: str | int) -> dict:
123169
"""
124170
>>> get_ccf_structure_info('MOs')
125-
{'id': 993, 'atlas_id': 831, 'name': 'Secondary motor area', 'acronym': 'MOs', 'st_level': None, 'ontology_id': 1, 'hemisphere_id': 3, 'weight': 8690, 'parent_structure_id': 500, 'depth': 7, 'graph_id': 1, 'graph_order': 24, 'structure_id_path': '/997/8/567/688/695/315/500/993/', 'color_hex_triplet': '1F9D5A', 'neuro_name_structure_id': None, 'neuro_name_structure_id_path': None, 'failed': 'f', 'sphinx_id': 25, 'structure_name_facet': 1043755260, 'failed_facet': 734881840, 'safe_name': 'Secondary motor area', 'color_hex_int': 2071898, 'color_hex_str': '0x1F9D5A', 'color_rgb': [0.12156862745098039, 0.615686274509804, 0.3529411764705882]}
171+
{'id': 993, 'atlas_id': 831, 'name': 'Secondary motor area', 'acronym': 'MOs', 'st_level': None, 'ontology_id': 1, 'hemisphere_id': 3, 'weight': 8690, 'parent_structure_id': 500, 'depth': 7, 'graph_id': 1, 'graph_order': 24, 'structure_id_path': '/997/8/567/688/695/315/500/993/', 'color_hex_triplet': '1F9D5A', 'neuro_name_structure_id': None, 'neuro_name_structure_id_path': None, 'failed': 'f', 'sphinx_id': 25, 'structure_name_facet': 1043755260, 'failed_facet': 734881840, 'safe_name': 'Secondary motor area', 'color_hex_int': 2071898, 'color_hex_str': '#1F9D5A', 'color_rgb': [0.12156862745098039, 0.615686274509804, 0.3529411764705882]}
126172
"""
127173
if not isinstance(ccf_acronym_or_id, int):
128174
ccf_id: int = convert_ccf_acronyms_or_ids(ccf_acronym_or_id)
@@ -140,7 +186,53 @@ def get_ccf_structure_info(ccf_acronym_or_id: str | int) -> dict:
140186
return results[0].limit(1).to_dicts()[0]
141187

142188

189+
def get_all_parents(ccf_acronym_or_id: str | int) -> list[str]:
190+
"""
191+
>>> get_all_parents('MOs2/3')
192+
['root', 'grey', 'CH', 'CTX', 'CTXpl', 'Isocortex', 'MO', 'MOs']
193+
"""
194+
info = get_ccf_structure_info(ccf_acronym_or_id)
195+
parent_ids = [int(id_) for id_ in info["structure_id_path"].split("/")[1:-2]]
196+
parent_acronyms = (
197+
get_ccf_structure_tree_df()
198+
.filter(
199+
pl.col("id").is_in(parent_ids),
200+
)["acronym"]
201+
.to_list()
202+
)
203+
assert info["id"] not in parent_acronyms
204+
return parent_acronyms
205+
206+
207+
def get_all_children(ccf_acronym_or_id: str | int) -> list[str]:
208+
"""
209+
>>> get_all_children('MOs')
210+
['MOs1', 'MOs2/3', 'MOs5', 'MOs6a', 'MOs6b']
211+
"""
212+
if not isinstance(ccf_acronym_or_id, int):
213+
ccf_id: int = convert_ccf_acronyms_or_ids(ccf_acronym_or_id)
214+
else:
215+
ccf_id = ccf_acronym_or_id
216+
children = (
217+
get_ccf_structure_tree_df()
218+
.filter(
219+
pl.col("structure_id_path").str.contains(f"/{ccf_id}/"),
220+
~pl.col("structure_id_path").str.ends_with(f"/{ccf_id}/"),
221+
)["acronym"]
222+
.to_list()
223+
)
224+
assert str(ccf_acronym_or_id) not in children
225+
return children
226+
227+
143228
def get_deepest_children(ccf_acronym_or_id: str | int) -> list[str]:
229+
"""
230+
>>> get_deepest_children('MOs')
231+
['MOs1', 'MOs2/3', 'MOs5', 'MOs6a', 'MOs6b']
232+
>>> get_deepest_children('MOs1')
233+
['MOs1']
234+
>>> assert 'VISpor' not in get_deepest_children('VIS')
235+
"""
144236
if not isinstance(ccf_acronym_or_id, int):
145237
try:
146238
ccf_id: int = convert_ccf_acronyms_or_ids(ccf_acronym_or_id)
@@ -160,6 +252,38 @@ def get_deepest_children(ccf_acronym_or_id: str | int) -> list[str]:
160252
].to_list()
161253

162254

255+
def group_child_labels_in_slice(
256+
slice_array: npt.NDArray[np.uint32],
257+
acronyms_or_ids: Iterable[str | int],
258+
) -> npt.NDArray[np.uint32]:
259+
"""
260+
For a given slice and CCF areas (acronyms or IDs), return a new slice with the labels grouped,
261+
so that all areas in the same group have the same label. For example, passing ["MOS"] would
262+
change the label index of all child areas in the MOs tree to have the MOs value.
263+
264+
>>> mos_id = get_ccf_structure_info('MOs')['id']
265+
>>> slice_array = get_ccf_volume()[:, :, 100]
266+
>>> assert mos_id not in slice_array
267+
>>> new_slice = group_child_labels_in_slice(slice_array, ['MOs'])
268+
>>> assert mos_id in new_slice
269+
"""
270+
slice_array = slice_array.copy()
271+
for ccf_acronym_or_id in acronyms_or_ids:
272+
if not isinstance(ccf_acronym_or_id, int):
273+
ccf_id: int = convert_ccf_acronyms_or_ids(ccf_acronym_or_id)
274+
else:
275+
ccf_id = ccf_acronym_or_id
276+
children = get_ccf_immediate_children_ids(ccf_id)
277+
if ccf_id in children:
278+
children.remove(ccf_id)
279+
logger.debug(
280+
f"Grouping {children} under {convert_ccf_acronyms_or_ids(ccf_acronym_or_id)}"
281+
)
282+
for child in children:
283+
slice_array[slice_array == child] = ccf_id
284+
return slice_array
285+
286+
163287
def get_ccf_immediate_children_ids(ccf_acronym_or_id: str | int) -> set[int]:
164288
"""
165289
>>> ids = get_ccf_immediate_children_ids('MOs')
@@ -476,13 +600,58 @@ def get_scatter_image(
476600
return image
477601

478602

603+
def project_first_nonzero_labels(
604+
volume: npt.NDArray,
605+
axis: int = AXIS_TO_DIM["dv"],
606+
) -> npt.NDArray:
607+
"""
608+
Project the first non-zero label encountered from one side of the 3D volume.
609+
610+
Parameters:
611+
volume (np.ndarray): 3D array containing non-zero area labels.
612+
axis (int): Axis along which to project (0, 1, or 2).
613+
614+
Returns:
615+
np.ndarray: 2D array with the projected labels.
616+
"""
617+
if volume.ndim != 3:
618+
raise ValueError(f"Volume must be 3D: {volume.shape=}")
619+
dims = tuple(range(volume.ndim))
620+
if axis not in dims:
621+
raise ValueError("Axis must be 0, 1, or 2.")
622+
plane_dims = [d for d in dims if d != axis]
623+
mask = volume > 0
624+
idx_along_projection_axis = np.argmax(mask, axis=axis)
625+
idx_in_plane_axes = [np.arange(volume.shape[d]) for d in plane_dims]
626+
if axis == 0:
627+
projection = volume[
628+
idx_along_projection_axis,
629+
idx_in_plane_axes[0][:, None],
630+
idx_in_plane_axes[1],
631+
]
632+
elif axis == 1:
633+
projection = volume[
634+
idx_in_plane_axes[0][:, None],
635+
idx_along_projection_axis,
636+
idx_in_plane_axes[1],
637+
]
638+
elif axis == 2:
639+
projection = volume[
640+
idx_in_plane_axes[0][:, None],
641+
idx_in_plane_axes[1],
642+
idx_along_projection_axis,
643+
]
644+
projection = projection.astype(float)
645+
projection[projection == 0] = np.nan
646+
return projection
647+
648+
479649
if __name__ == "__main__":
480650
logging.basicConfig(
481-
level=logging.DEBUG,
651+
level=logging.WARNING,
482652
format="%(asctime)s | %(name)s | %(levelname)s | %(funcName)s | %(message)s",
483653
datefmt="%d-%b-%y %H:%M:%S",
484654
)
485-
logging.getLogger().setLevel(logging.DEBUG)
486655

487656
import doctest
488657

0 commit comments

Comments
 (0)