Skip to content

Commit e19f8a7

Browse files
committed
Re-locate functions
1 parent e0e9484 commit e19f8a7

File tree

7 files changed

+327
-312
lines changed

7 files changed

+327
-312
lines changed

flamingo_tools/measurements.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import warnings
77
from concurrent import futures
88
from functools import partial
9+
from multiprocessing import cpu_count
910
from typing import List, Optional, Tuple, Union
1011

1112
import numpy as np
@@ -502,3 +503,102 @@ def _compute_block(block_id):
502503

503504
mask = ResizedVolume(low_res_mask, shape=original_shape, order=0)
504505
return mask
506+
507+
508+
def object_measures_single(
509+
table_path: str,
510+
seg_path: str,
511+
image_paths: List[str],
512+
out_paths: List[str],
513+
force_overwrite: bool = False,
514+
component_list: List[int] = [1],
515+
background_mask: Optional[np.typing.ArrayLike] = None,
516+
resolution: List[float] = [0.38, 0.38, 0.38],
517+
s3: bool = False,
518+
s3_credentials: Optional[str] = None,
519+
s3_bucket_name: Optional[str] = None,
520+
s3_service_endpoint: Optional[str] = None,
521+
**_
522+
):
523+
"""Compute object measures for a single or multiple image channels in respect to a single segmentation channel.
524+
525+
Args:
526+
table_path: File path to segmentationt table.
527+
seg_path: Path to segmentation channel in ome.zarr format.
528+
image_paths: Path(s) to image channel(s) in ome.zarr format.
529+
out_paths: Paths(s) for calculated object measures.
530+
force_overwrite: Forcefully overwrite existing files.
531+
component_list: Only calculate object measures for specific components.
532+
background_mask: Use background mask for calculating object measures.
533+
resolution: Resolution of input in micrometer.
534+
s3: Use S3 file paths.
535+
s3_credentials:
536+
s3_bucket_name:
537+
s3_service_endpoint:
538+
"""
539+
input_key = "s0"
540+
out_paths = [os.path.realpath(o) for o in out_paths]
541+
542+
if not isinstance(resolution, float):
543+
if len(resolution) == 1:
544+
resolution = resolution * 3
545+
assert len(resolution) == 3
546+
resolution = np.array(resolution)[::-1]
547+
else:
548+
resolution = (resolution,) * 3
549+
550+
for (img_path, out_path) in zip(image_paths, out_paths):
551+
n_threads = int(os.environ.get("SLURM_CPUS_ON_NODE", cpu_count()))
552+
553+
# overwrite input file
554+
if os.path.realpath(out_path) == os.path.realpath(table_path) and not s3:
555+
force_overwrite = True
556+
557+
if os.path.isfile(out_path) and not force_overwrite:
558+
print(f"Skipping {out_path}. Table already exists.")
559+
560+
else:
561+
if background_mask is None:
562+
feature_set = "default"
563+
dilation = None
564+
median_only = False
565+
else:
566+
print("Using background mask for calculating object measures.")
567+
feature_set = "default_background_subtract"
568+
dilation = 4
569+
median_only = True
570+
571+
if s3:
572+
img_path, fs = s3_utils.get_s3_path(img_path, bucket_name=s3_bucket_name,
573+
service_endpoint=s3_service_endpoint,
574+
credential_file=s3_credentials)
575+
seg_path, fs = s3_utils.get_s3_path(seg_path, bucket_name=s3_bucket_name,
576+
service_endpoint=s3_service_endpoint,
577+
credential_file=s3_credentials)
578+
579+
mask_cache_path = os.path.join(os.path.dirname(out_path), "bg-mask.zarr")
580+
background_mask = compute_sgn_background_mask(
581+
image_path=img_path,
582+
segmentation_path=seg_path,
583+
image_key=input_key,
584+
segmentation_key=input_key,
585+
n_threads=n_threads,
586+
cache_path=mask_cache_path,
587+
)
588+
589+
compute_object_measures(
590+
image_path=img_path,
591+
segmentation_path=seg_path,
592+
segmentation_table_path=table_path,
593+
output_table_path=out_path,
594+
image_key=input_key,
595+
segmentation_key=input_key,
596+
feature_set=feature_set,
597+
s3_flag=s3,
598+
component_list=component_list,
599+
dilation=dilation,
600+
median_only=median_only,
601+
background_mask=background_mask,
602+
n_threads=n_threads,
603+
resolution=resolution,
604+
)

flamingo_tools/postprocessing/cochlea_mapping.py

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import math
2+
import os
23
from typing import List, Optional, Tuple
34

45
import networkx as nx
@@ -8,6 +9,7 @@
89
from scipy.interpolate import interp1d
910

1011
from flamingo_tools.postprocessing.label_components import downscaled_centroids
12+
from flamingo_tools.s3_utils import get_s3_path
1113

1214

1315
def find_most_distant_nodes(G: nx.classes.graph.Graph, weight: str = 'weight') -> Tuple[float, float]:
@@ -750,8 +752,8 @@ def tonotopic_mapping(
750752
apex_higher: bool = True,
751753
otof: bool = False,
752754
) -> pd.DataFrame:
753-
"""Tonotopic mapping of IHCs by supplying a table with component labels.
754-
The mapping assigns a tonotopic label to each IHC according to the position along the length of the cochlea.
755+
"""Tonotopic mapping of SGNs or IHCs by supplying a table with component labels.
756+
The mapping assigns a tonotopic label to each instance according to the position along the length of the cochlea.
755757
756758
Args:
757759
table: Dataframe of segmentation table.
@@ -816,3 +818,74 @@ def tonotopic_mapping(
816818
table = map_frequency(table, animal=animal, otof=otof)
817819

818820
return table
821+
822+
823+
def tonotopic_mapping_single(
824+
table_path: str,
825+
out_path: str,
826+
force_overwrite: bool = False,
827+
cell_type: str = "sgn",
828+
animal: str = "mouse",
829+
otof: bool = False,
830+
apex_position: str = "apex_higher",
831+
component_list: List[int] = [1],
832+
component_mapping: Optional[List[int]] = None,
833+
max_edge_distance: float = 30,
834+
s3: bool = False,
835+
s3_credentials: Optional[str] = None,
836+
s3_bucket_name: Optional[str] = None,
837+
s3_service_endpoint: Optional[str] = None,
838+
**_
839+
):
840+
"""Tonotopic mapping of a single cochlea.
841+
Each segmentation instance within a given component list is assigned a frequency[kHz], a run length and an offset.
842+
The components used for the mapping itself can be a subset of the component list to adapt to broken components
843+
along the Rosenthal's canal.
844+
If the cochlea is broken in the direction of the Rosenthal's canal, the components have to be provided in a
845+
continuous order which reflects the positioning within 3D.
846+
The frequency is calculated using the Greenwood function using animal specific parameters.
847+
The orientation of the mapping can be reversed using the apex position in reference to the y-coordinate.
848+
849+
Args:
850+
table_path: File path to segmentation table.
851+
out_path: Output path to segmentation table with new column "component_labels".
852+
force_overwrite: Forcefully overwrite existing output path.
853+
cell_type: Cell type of the segmentation. Currently supports "sgn" and "ihc".
854+
animal: Animal for species specific frequency mapping. Either "mouse" or "gerbil".
855+
otof: Use mapping by *Mueller, Hearing Research 202 (2005) 63-73* for OTOF cochleae.
856+
apex_position: Identify position of apex and base. Apex is set to node with higher y-value per default.
857+
component_list: List of components. Can be passed to obtain the number of instances within the component list.
858+
components_mapping: Components to use for tonotopic mapping. Ignore components torn parallel to main canal.
859+
max_edge_distance: Maximal edge distance between graph nodes to create an edge between nodes.
860+
s3: Use S3 bucket.
861+
s3_credentials:
862+
s3_bucket_name:
863+
s3_service_endpoint:
864+
"""
865+
if os.path.isdir(out_path):
866+
raise ValueError(f"Output path {out_path} is a directory. Provide a path to a single output file.")
867+
868+
if s3:
869+
tsv_path, fs = get_s3_path(table_path, bucket_name=s3_bucket_name,
870+
service_endpoint=s3_service_endpoint, credential_file=s3_credentials)
871+
with fs.open(tsv_path, "r") as f:
872+
table = pd.read_csv(f, sep="\t")
873+
else:
874+
table = pd.read_csv(table_path, sep="\t")
875+
876+
apex_higher = (apex_position == "apex_higher")
877+
878+
# overwrite input file
879+
if os.path.realpath(out_path) == os.path.realpath(table_path) and not s3:
880+
force_overwrite = True
881+
882+
if os.path.isfile(out_path) and not force_overwrite:
883+
print(f"Skipping {out_path}. Table already exists.")
884+
885+
else:
886+
table = tonotopic_mapping(table, component_label=component_list, animal=animal,
887+
cell_type=cell_type, component_mapping=component_mapping,
888+
apex_higher=apex_higher, max_edge_distance=max_edge_distance,
889+
otof=otof)
890+
891+
table.to_csv(out_path, sep="\t", index=False)

flamingo_tools/postprocessing/label_components.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import math
22
import multiprocessing as mp
3+
import os
34
from concurrent import futures
45
from typing import Callable, List, Optional, Tuple
56

@@ -10,6 +11,7 @@
1011
import pandas as pd
1112

1213
from elf.io import open_file
14+
from flamingo_tools.s3_utils import get_s3_path
1315
from scipy.ndimage import distance_transform_edt, binary_dilation, binary_closing
1416
from scipy.sparse import csr_matrix
1517
from scipy.spatial import distance
@@ -673,3 +675,148 @@ def filter_cochlea_volume(
673675
combined_dilated[combined_dilated > 0] = 1
674676

675677
return combined_dilated
678+
679+
680+
def label_custom_components(tsv_table, custom_dict):
681+
"""Label IHC components using multiple post-processing configurations and combine the
682+
results into final components.
683+
The function applies successive post-processing steps defined in a `custom_dic`
684+
configuration. Each entry under `label_dicts` specifies:
685+
- `label_params`: a list of parameter sets. The segmentation is processed once for
686+
each parameter set (e.g., {"min_size": 500, "max_edge_distance": 65, "min_component_length": 5}).
687+
- `components`: lists of label IDs to extract from each corresponding post-processing run.
688+
Label IDs collected from all runs are merged to form the final component (e.g., key "1").
689+
Global filtering is applied using `min_size_global`, and any `missing_ids`
690+
(e.g., 4800 or 4832) are added explicitly to the final component.
691+
Example `custom_dic` structure:
692+
{
693+
"min_size_global": 500,
694+
"missing_ids": [4800, 4832],
695+
"label_dicts": {
696+
"1": {
697+
"label_params": [
698+
{"min_size": 500, "max_edge_distance": 65, "min_component_length": 5},
699+
{"min_size": 400, "max_edge_distance": 45, "min_component_length": 5}
700+
],
701+
"components": [[18, 22], [1, 45, 83]]
702+
}
703+
}
704+
}
705+
706+
Args:
707+
tsv_table: Pandas dataframe of the MoBIE segmentation table.
708+
custom_dict: Custom dictionary featuring post-processing parameters.
709+
710+
Returns:
711+
Pandas dataframe featuring labeled components.
712+
"""
713+
min_size = custom_dict["min_size_global"]
714+
component_labels = [0 for _ in range(len(tsv_table))]
715+
tsv_table.loc[:, "component_labels"] = component_labels
716+
for custom_comp, label_dict in custom_dict["label_dicts"].items():
717+
label_params = label_dict["label_params"]
718+
label_components = label_dict["components"]
719+
720+
combined_label_ids = []
721+
for comp, other_kwargs in zip(label_components, label_params):
722+
tsv_table_tmp = label_components_ihc(tsv_table.copy(), **other_kwargs)
723+
label_ids = list(tsv_table_tmp.loc[tsv_table_tmp["component_labels"].isin(comp), "label_id"])
724+
combined_label_ids.extend(label_ids)
725+
print(f"{comp}", len(combined_label_ids))
726+
727+
combined_label_ids = list(set(combined_label_ids))
728+
729+
tsv_table.loc[tsv_table["label_id"].isin(combined_label_ids), "component_labels"] = int(custom_comp)
730+
731+
tsv_table.loc[tsv_table["n_pixels"] < min_size, "component_labels"] = 0
732+
if "missing_ids" in list(custom_dict.keys()):
733+
for m in custom_dict["missing_ids"]:
734+
tsv_table.loc[tsv_table["label_id"] == m, "component_labels"] = 1
735+
736+
return tsv_table
737+
738+
739+
def label_components_single(
740+
table_path: str,
741+
out_path: str,
742+
force_overwrite: bool = False,
743+
cell_type: str = "sgn",
744+
component_list: List[int] = [1],
745+
max_edge_distance: float = 30,
746+
min_component_length: int = 50,
747+
min_size: int = 1000,
748+
s3: bool = False,
749+
s3_credentials: Optional[str] = None,
750+
s3_bucket_name: Optional[str] = None,
751+
s3_service_endpoint: Optional[str] = None,
752+
custom_dic: Optional[dict] = None,
753+
**_
754+
):
755+
"""Process a single cochlea using one set of parameters or a custom dictionary.
756+
The cochlea is analyzed using graph-connected components
757+
to label segmentation instances that are closer than a given maximal edge distance.
758+
This process acts on an input segmentation table to which a "component_labels" column is added.
759+
Each entry in this column refers to the index of a connected component.
760+
The largest connected component has an index of 1; the others follow in decreasing order.
761+
762+
Args:
763+
table_path: File path to segmentation table.
764+
out_path: Output path to segmentation table with new column "component_labels".
765+
force_overwrite: Forcefully overwrite existing output path.
766+
cell_type: Cell type of the segmentation. Currently supports "sgn" and "ihc".
767+
component_list: List of components. Can be passed to obtain the number of instances within the component list.
768+
max_edge_distance: Maximal edge distance between graph nodes to create an edge between nodes.
769+
min_component_length: Minimal length of nodes of connected component. Filtered out if lower.
770+
min_size: Minimal number of pixels for filtering small instances.
771+
s3: Use S3 bucket.
772+
s3_credentials:
773+
s3_bucket_name:
774+
s3_service_endpoint:
775+
custom_dic: Custom dictionary which allows multiple post-processing configurations and combines the
776+
results into final components.
777+
"""
778+
if os.path.isdir(out_path):
779+
raise ValueError(f"Output path {out_path} is a directory. Provide a path to a single output file.")
780+
781+
if s3:
782+
tsv_path, fs = get_s3_path(table_path, bucket_name=s3_bucket_name,
783+
service_endpoint=s3_service_endpoint, credential_file=s3_credentials)
784+
with fs.open(tsv_path, "r") as f:
785+
table = pd.read_csv(f, sep="\t")
786+
else:
787+
table = pd.read_csv(table_path, sep="\t")
788+
789+
# overwrite input file
790+
if os.path.realpath(out_path) == os.path.realpath(table_path) and not s3:
791+
force_overwrite = True
792+
793+
if os.path.isfile(out_path) and not force_overwrite:
794+
print(f"Skipping {out_path}. Table already exists.")
795+
796+
else:
797+
if custom_dic is not None:
798+
# use multiple post-processing configurations
799+
tsv_table = label_custom_components(table, custom_dic)
800+
else:
801+
if cell_type == "sgn":
802+
tsv_table = label_components_sgn(table, min_size=min_size,
803+
min_component_length=min_component_length,
804+
max_edge_distance=max_edge_distance)
805+
elif cell_type == "ihc":
806+
tsv_table = label_components_ihc(table, min_size=min_size,
807+
min_component_length=min_component_length,
808+
max_edge_distance=max_edge_distance)
809+
else:
810+
raise ValueError("Choose a supported cell type. Either 'sgn' or 'ihc'.")
811+
812+
custom_comp = len(tsv_table[tsv_table["component_labels"].isin(component_list)])
813+
print(f"Total {cell_type.upper()}s: {len(tsv_table)}")
814+
if component_list == [1]:
815+
print(f"Largest component has {custom_comp} {cell_type.upper()}s.")
816+
else:
817+
for comp in component_list:
818+
num_instances = len(tsv_table[tsv_table["component_labels"] == comp])
819+
print(f"Component {comp} has {num_instances} instances.")
820+
print(f"Custom component(s) have {custom_comp} {cell_type.upper()}s.")
821+
822+
tsv_table.to_csv(out_path, sep="\t", index=False)

flamingo_tools/postprocessing/synapse_per_ihc_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,4 +47,4 @@
4747
"component_list": [2, 1, 3]},
4848
"M_AMD_N97_R": {"synapse_table_name": "synapse_v3_ihc_v4b", "ihc_table_name": "IHC_v4b",
4949
"component_list": [2, 5]},
50-
}
50+
}

0 commit comments

Comments
 (0)