Skip to content

Commit 7de50ca

Browse files
JoOkuma and TeunHuijben authored
Updating ground-truth matching and adding default classifier (#202)
* added save_config function * added volume attribute to node * close_tracks_gaps now works for segments being a dask_array * removed node._volume method, since precomputed area already exists * added UltrackArray class to utils/array.py * ran pre-commit * added path to database as optional input parameter. If provided, the database is taken from the provided path. If not provided, the database is loaded from the path stored in config * hierarchy widget, work in progress * added non-linear slider for uniform addition of volumes' * added documentation to ultrack-array and HierarchyVizWidget * Update ultrack/widgets/ultrackwidget/ultrackwidget.py Co-authored-by: Jordão Bragantini <[email protected]> * Update ultrack/widgets/ultrackwidget/ultrackwidget.py Co-authored-by: Jordão Bragantini <[email protected]> * Update ultrack/utils/array.py Co-authored-by: Jordão Bragantini <[email protected]> * implementing Jordaos revisions * revising PR * renaming functions in ultrack-array * reverting dirty history changes * redid the changes lost in merge * added indexing test for ultrack-array * added test for hierarchy widget * added option to predict node probabilities * fixed spelling 'persistense -> persistence' * adding match ground-truth CLI * fixed ground-truth match test * bypassing edge case where there isn't any competing segmentation * replacing xgboost with catboost * added additional shape sanity check * fixing right subset for fit * Update ultrack/cli/_test/test_cli.py * added batch index option to gt matcher * adding option to use ground for pinning ILP variables * added kwargs option to main function * Update ultrack/ml/classification.py * Update ultrack/utils/array.py Co-authored-by: Jordão Bragantini <[email protected]> * move UltrackArray to separate file to prevent circular inputs * implemented jordaos review comments * precommit * WIP adding new link features classifier * added ml to pixi * adding link classifications modules * add additional checking * 
adding link classifier debugging code * remove overlap bug fixes * updating docs and typing * splitting testing * fixing widget testing * adding constant to layer name * Removing constant 5 * refactoring ultrack array to dynamics load any column * added testing to variable column attribute * updated UI and added offset * minor test and combobox changes * renaming vol to num_pix * add scale from ultrack metadata * changing link df update probability query * added missing default value to empty match gt alignment * adding caching * bug fix * fix sqlalchemy query binding * adding ultrack array gtlink column support * fixed multi time point loading * changed matching policy to use intersection * adding template matching * fixing ILP * ILP with template matching * ignoring catboost files * fixing deprecations, logging and typing --------- Co-authored-by: TeunHuijben <[email protected]> Co-authored-by: Teun Huijben <[email protected]> Co-authored-by: TeunHuijben <[email protected]>
1 parent 45b9744 commit 7de50ca

25 files changed

+1254
-76
lines changed

Diff for: .gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,6 @@ data.db
181181
# pixi environments
182182
.pixi
183183
*.egg-info
184+
185+
# catboost files
186+
catboost_info

Diff for: pyproject.toml

+12-1
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,10 @@ test = [
7272
"napari[testing] >0.4.18",
7373
"pyqt5 >=5.15.4",
7474
]
75+
ml = [
76+
"catboost >=1.2.7,<1.3",
77+
"scikit-learn >=1.6.0,<1.7",
78+
]
7579

7680
[project.scripts]
7781
ultrack = "ultrack.cli.main:main"
@@ -120,6 +124,8 @@ uvicorn = ">=0.27.0.post1"
120124
websocket = ">=0.2.1"
121125
websockets = ">=12.0"
122126
zarr = ">=2.15.0,<3.0.0"
127+
scikit-learn = ">=1.6.0,<1.7"
128+
catboost = ">=1.2.7,<1.3"
123129
pyarrow = ">=16.1.0,<20"
124130

125131
[tool.pixi.feature.cuda]
@@ -142,6 +148,10 @@ pytest-qt = ">=4.4.0,<5"
142148
pyqt = ">=5.15.9,<6"
143149
pytest-cov = ">=6.0.0,<7"
144150

151+
[tool.pixi.feature.ml.dependencies]
152+
catboost = ">=1.2.7,<1.3"
153+
scikit-learn = ">=1.6.0,<1.7"
154+
145155
[tool.pytest.ini_options]
146156
filterwarnings = [
147157
"ignore::DeprecationWarning:pkg_resources.*:",
@@ -154,7 +164,8 @@ ultrack = { path = ".", editable = true }
154164
default = { solve-group = "default" }
155165
cuda = { features = ["cuda"] }
156166
# docs = { features = ["docs"]} # Current dependencies aren't compatible with pixi
157-
test = { features = ["test"], solve-group = "default" }
167+
ml = { features = ["ml"], solve-group = "default" }
168+
test = { features = ["test", "ml"], solve-group = "default" }
158169

159170
[tool.pixi.feature.test.tasks]
160171
test = "pytest -v --color=yes --cov=ultrack --cov-report=html --durations=15 ."

Diff for: ultrack/cli/_test/test_cli.py

+51-1
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,12 @@ def test_segment(
6565
"foreground",
6666
"-cl",
6767
"contours",
68+
"-il",
69+
"labels",
70+
"-il",
71+
"contours",
72+
"-p",
73+
"intensity_mean",
6874
]
6975
+ zarr_dataset_paths
7076
)
@@ -86,6 +92,50 @@ def test_link_with_images(
8692
["link", "-cfg", str(instance_config_path), "-ow"] + zarr_dataset_paths[:2]
8793
)
8894

95+
def test_fit_and_add_probs(
    self, instance_config_path: str, tmp_path: Path, zarr_dataset_paths: List[str]
) -> None:
    """Run `match_gt` with model and config outputs, then `add_probs` per variable."""
    # match_gt with a model output needs both optional ML packages
    for module_name in ("catboost", "sklearn"):
        pytest.importorskip(module_name)

    model_path = tmp_path / "model.pkl"
    new_cfg_path = tmp_path / "new_config.toml"

    match_gt_args = [
        "match_gt",
        "-cfg",
        instance_config_path,
        "-gl",
        "labels",
        "-om",
        str(model_path),
        "-oc",
        str(new_cfg_path),
        "--is-tracking",
        "--is-segmentation",
        "--persistence",
    ]
    _run_command(match_gt_args + zarr_dataset_paths)

    # the config written by match_gt must be loadable
    load_config(new_cfg_path)

    for var in ("nodes", "links"):
        add_probs_args = [
            "add_probs",
            str(model_path),
            "-cfg",
            instance_config_path,
            "--persistence",
            "--var",
            var,
        ]
        _run_command(add_probs_args)
138+
89139
def test_solve(self, instance_config_path: str) -> None:
90140
with pytest.warns(UserWarning):
91141
# batch index with overwrite should trigger warning
@@ -155,7 +205,7 @@ def test_zarr_napari_export(
155205
]
156206
)
157207

158-
@pytest.mark.parametrize("mode", ["solutions", "links", "all"])
208+
@pytest.mark.parametrize("mode", ["gt", "solutions", "links", "all"])
159209
def test_clear_database(self, instance_config_path: str, mode: str) -> None:
160210
_run_command(
161211
[

Diff for: ultrack/cli/clear_database.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
from ultrack.config.config import MainConfig
55
from ultrack.core.database import clear_all_data
66
from ultrack.core.linking.utils import clear_linking_data
7+
from ultrack.core.match_gt import clear_ground_truths
78
from ultrack.core.solve.sqltracking import SQLTracking
89

910

1011
@click.command("clear_database")
11-
@click.argument("mode", type=click.Choice(["all", "links", "solutions"]))
12+
@click.argument("mode", type=click.Choice(["all", "links", "solutions", "gt"]))
1213
@config_option()
1314
def clear_database_cli(mode: str, config: MainConfig) -> None:
1415
"""Cleans database content."""
@@ -20,5 +21,7 @@ def clear_database_cli(mode: str, config: MainConfig) -> None:
2021
clear_linking_data(database_path)
2122
elif mode == "solutions":
2223
SQLTracking.clear_solution_from_database(database_path)
24+
elif mode == "gt":
25+
clear_ground_truths(database_path)
2326
else:
2427
raise NotImplementedError(f"Clear database mode {mode} not implemented.")

Diff for: ultrack/cli/main.py

+4
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from ultrack.cli.flow import add_flow_cli
1010
from ultrack.cli.labels_to_edges import labels_to_contours_cli
1111
from ultrack.cli.link import link_cli
12+
from ultrack.cli.match_gt import match_gt_cli
13+
from ultrack.cli.predict import add_probs_cli
1214
from ultrack.cli.segment import segmentation_cli
1315
from ultrack.cli.server import server_cli
1416
from ultrack.cli.solve import solve_cli
@@ -29,6 +31,8 @@ def main():
2931
main.add_command(export_cli)
3032
main.add_command(labels_to_contours_cli)
3133
main.add_command(link_cli)
34+
main.add_command(match_gt_cli)
35+
main.add_command(add_probs_cli)
3236
main.add_command(segmentation_cli)
3337
main.add_command(solve_cli)
3438
main.add_command(server_cli)

Diff for: ultrack/cli/match_gt.py

+176
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
import logging
2+
from pathlib import Path
3+
from typing import Optional, Sequence
4+
5+
import click
6+
import cloudpickle
7+
import toml
8+
from napari.plugins import _initialize_plugins
9+
from napari.viewer import ViewerModel
10+
from rich.logging import RichHandler
11+
12+
from ultrack.cli.segment import _get_layer_data
13+
from ultrack.cli.utils import (
14+
batch_index_option,
15+
config_option,
16+
napari_reader_option,
17+
overwrite_option,
18+
paths_argument,
19+
persistence_option,
20+
)
21+
from ultrack.config import MainConfig
22+
from ultrack.core.match_gt import match_to_ground_truth
23+
from ultrack.ml.classification import fit_links_prob, fit_nodes_prob
24+
25+
LOG = logging.getLogger(__name__)
26+
LOG.setLevel(logging.INFO)
27+
LOG.addHandler(RichHandler())
28+
29+
30+
def _raise_if_exists(path: Optional[Path], description: str, overwrite: bool) -> None:
    """Raise ``FileExistsError`` if `path` exists and overwriting was not requested."""
    if path is not None and path.exists() and not overwrite:
        raise FileExistsError(
            f"{description} {path} already exists. Use --overwrite to overwrite."
        )


@click.command("match_gt")
@paths_argument()
@napari_reader_option()
@config_option()
@click.option(
    "--ground-truth-layer",
    "-gl",
    required=False,
    type=str,
    default=None,
    help="Ground-truth layer index on napari.",
)
@click.option(
    "--output-model",
    "-om",
    type=click.Path(dir_okay=False, path_type=Path),
    required=False,
    default=None,
    help="Optional output model file path.",
)
@click.option(
    "--output-config",
    "-oc",
    type=click.Path(dir_okay=False, path_type=Path),
    required=False,
    default=None,
    help="Optional output config file path.",
)
@click.option(
    "--is-segmentation",
    is_flag=True,
    type=bool,
    default=False,
    help="Indicates ground-truth are fully curated segmentation masks. "
    "When activated different costs are used for insertions and deletions.",
)
@click.option(
    "--is-tracking",
    is_flag=True,
    type=bool,
    default=False,
    help="Indicates ground-truth are tracking instances results.",
)
@click.option(
    "--is-dense",
    is_flag=True,
    type=bool,
    default=False,
    help="Indicates ground-truth are dense annotations (everything is annotated).",
)
@click.option(
    "--insert-prob",
    is_flag=True,
    type=bool,
    default=False,
    help="Insert estimated probabilities into the database.",
)
@batch_index_option()
@overwrite_option()
@persistence_option()
def match_gt_cli(
    paths: Sequence[Path],
    reader_plugin: str,
    config: MainConfig,
    ground_truth_layer: Optional[str],
    output_model: Optional[Path],
    output_config: Optional[Path],
    is_segmentation: bool,
    is_tracking: bool,
    is_dense: bool,
    insert_prob: bool,
    batch_index: Optional[int],
    overwrite: bool,
    persistence: bool,
) -> None:
    """
    Match ground-truth labels to the segmentation/tracking database.
    """
    # Flow: validate the requested outputs, load the ground-truth layer through
    # a napari reader, match it against the database and, when requested, save
    # the estimated config and fit/save the probability classifiers.
    #
    # Raises FileExistsError when an output file exists without `--overwrite`,
    # and ValueError when `--output-config` is used without `--is-segmentation`
    # or when the ground-truth layer cannot be chosen automatically.

    # Fail before any expensive work if an output would be clobbered.
    _raise_if_exists(output_model, "Output model", overwrite)

    if output_config is not None:
        if not is_segmentation:
            raise ValueError(
                "Output config is only available for segmentation ground-truth `--is-segmentation`."
            )
        _raise_if_exists(output_config, "Output config", overwrite)

    # Data loading
    _initialize_plugins()

    viewer = ViewerModel()
    viewer.open(path=paths, plugin=reader_plugin)

    if ground_truth_layer is None:
        num_layers = len(viewer.layers)
        if num_layers == 0:
            # Without this guard `viewer.layers[0]` fails with a bare IndexError.
            raise ValueError(f"No layers could be loaded from {paths}.")
        if num_layers > 1:
            raise ValueError(
                "Multiple layers found, please specify `--ground-truth-layer`."
            )
        ground_truth_layer = viewer.layers[0].name

    gt = _get_layer_data(viewer, ground_truth_layer)

    # Match ground-truth to database
    gt_df, new_config = match_to_ground_truth(
        config=config,
        gt_labels=gt,
        scale=config.data_config.metadata.get("scale"),
        is_segmentation=is_segmentation,
        optimize_config=True,
        batch_index=batch_index,
    )

    if output_config is not None:
        LOG.info("Estimated new config: %s", new_config)
        LOG.info("Saving new config to %s", output_config)
        # TOML documents must be UTF-8; do not rely on the platform default.
        with open(output_config, "w", encoding="utf-8") as f:
            toml.dump(new_config.model_dump(by_alias=True), f)

    # A classifier is only fitted when its result is used, i.e. inserted into
    # the database and/or written to disk.
    if insert_prob or output_model is not None:
        model = fit_nodes_prob(
            config,
            gt_df["gt_track_id"],
            persistence_features=persistence,
            insert_prob=insert_prob,
            remove_no_overlap=not is_dense,
        )

        if is_tracking:
            link_model = fit_links_prob(
                config,
                gt_df["gt_track_id"],
                persistence_features=persistence,
                insert_prob=insert_prob,
            )
            # With tracking ground-truth both classifiers are stored, keyed by
            # the variable they score.
            model = {"nodes": model, "links": link_model}

        if output_model is not None:
            LOG.info("Saving model to %s", output_model)
            with open(output_model, "wb") as f:
                cloudpickle.dump(model, f)

Diff for: ultrack/cli/predict.py

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from pathlib import Path
2+
from typing import Literal
3+
4+
import click
5+
from cloudpickle import load
6+
7+
from ultrack.cli.utils import config_option, persistence_option
8+
from ultrack.config import MainConfig
9+
from ultrack.ml.classification import predict_links_prob, predict_nodes_prob
10+
11+
12+
@click.command("add_probs")
@click.argument(
    "classif_pickle_path",
    type=click.Path(exists=True, dir_okay=False, path_type=Path),
)
@click.option(
    "--var",
    type=click.Choice(["nodes", "links", "divisions", "appearances", "disappearances"]),
    default="nodes",
    help="Variable to assign probabilities.",
)
@config_option()
@persistence_option()
def add_probs_cli(
    classif_pickle_path: Path,
    var: Literal["nodes", "links", "divisions", "appearances", "disappearances"],
    config: MainConfig,
    persistence: bool,
) -> None:
    """Predicts and adds node or link probabilities to the database."""
    # Raises NotImplementedError for variables without a predictor, and
    # KeyError when a dict pickle has no classifier stored under `var`.
    predictors = {
        "nodes": predict_nodes_prob,
        "links": predict_links_prob,
    }

    # Validate before touching the pickle, so an unsupported variable always
    # fails the same way whatever the pickle contains.
    if var not in predictors:
        # TODO add edges and other probabilities
        raise NotImplementedError(f"Variable {var} not implemented.")

    # NOTE(security): unpickling executes arbitrary code; only load classifier
    # files coming from a trusted source.
    with open(classif_pickle_path, "rb") as f:
        classifier = load(f)

    # The pickle holds either a single classifier or a dict of classifiers
    # keyed by variable name.
    if isinstance(classifier, dict):
        if var not in classifier:
            raise KeyError(
                f"No classifier for '{var}' in {classif_pickle_path}; "
                f"available: {sorted(classifier)}."
            )
        classifier = classifier[var]

    predictors[var](config, classifier, persistence_features=persistence)

0 commit comments

Comments
 (0)