Skip to content

Commit 7272261

Browse files
TeunHuijben and JoOkuma
authored
updated to_geff export function and added cli (#248)
* updated to_geff export function and added cli * Update ultrack/cli/export.py Co-authored-by: Jordão Bragantini <jordao.bragantini@czbiohub.org> * handle Jordaos review --------- Co-authored-by: Jordão Bragantini <jordao.bragantini@czbiohub.org>
1 parent a99bc0f commit 7272261

File tree

3 files changed

+209
-51
lines changed

3 files changed

+209
-51
lines changed

ultrack/cli/export.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,13 @@
1515
tuple_callback,
1616
)
1717
from ultrack.config import MainConfig
18-
from ultrack.core.export import to_ctc, to_trackmate, to_tracks_layer, tracks_to_zarr
18+
from ultrack.core.export import (
19+
to_ctc,
20+
to_geff_from_database,
21+
to_trackmate,
22+
to_tracks_layer,
23+
tracks_to_zarr,
24+
)
1925
from ultrack.core.solve.sqltracking import SQLTracking
2026
from ultrack.imgproc.measure import tracks_properties
2127
from ultrack.utils.data import validate_and_overwrite_path
@@ -187,6 +193,45 @@ def trackmate_cli(
187193
to_trackmate(config, output_path, overwrite)
188194

189195

196+
@click.command("geff")
197+
@click.argument(
198+
"database_path",
199+
type=click.Path(path_type=Path, exists=True),
200+
)
201+
@click.option(
202+
"--output-path",
203+
"-o",
204+
required=False,
205+
type=click.Path(path_type=Path),
206+
default=None,
207+
help=(
208+
"Geff (Graph Exchange File Format) output path. "
209+
"If not provided, saves to same directory as database with '_geff.geff' extension."
210+
),
211+
)
212+
@overwrite_option()
213+
def geff_cli(
214+
database_path: Path,
215+
output_path: Optional[Path],
216+
overwrite: bool,
217+
) -> None:
218+
"""
219+
Exports tracking results to Geff (Graph Exchange File Format) format.
220+
"""
221+
if output_path is None:
222+
# Generate output path from database path
223+
output_path = database_path.parent / f"{database_path.stem}.geff"
224+
else:
225+
# Validate that the output path has a geff extension
226+
output_str = str(output_path)
227+
if not (output_str.endswith(".geff") or output_str.endswith(".geff.zarr")):
228+
raise click.BadParameter(
229+
f"Output path must have a .geff or .geff.zarr extension, got: {output_path}"
230+
)
231+
232+
to_geff_from_database(database_path, output_path, overwrite)
233+
234+
190235
@click.command("lp")
191236
@click.option(
192237
"--output-path",
@@ -229,6 +274,7 @@ def export_cli() -> None:
229274

230275

231276
export_cli.add_command(ctc_cli)
277+
export_cli.add_command(geff_cli)
232278
export_cli.add_command(lp_cli)
233279
export_cli.add_command(trackmate_cli)
234280
export_cli.add_command(zarr_napari_cli)

ultrack/core/export/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from ultrack.core.export.ctc import to_ctc
22
from ultrack.core.export.exporter import export_tracks_by_extension
3-
from ultrack.core.export.geff import to_geff
3+
from ultrack.core.export.geff import to_geff, to_geff_from_database
44
from ultrack.core.export.networkx import to_networkx, tracks_layer_to_networkx
55
from ultrack.core.export.trackmate import to_trackmate, tracks_layer_to_trackmate
66
from ultrack.core.export.tracks_layer import to_tracks_layer

ultrack/core/export/geff.py

Lines changed: 161 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,29 @@
1515
from ultrack.core.database import NO_PARENT, LinkDB, NodeDB, OverlapDB
1616

1717

18-
def to_geff(
19-
config: MainConfig,
18+
# Helper that normalizes pandas/numpy dtypes into plain string names
def dtype_to_str(dtype) -> str:
    """Convert pandas/numpy dtype to string dtype name for PropMetadata."""
    # np.dtype(...) accepts any dtype-like input (numpy dtype, pandas dtype,
    # python type, or string) and its `.name` is the canonical string form
    # ("int64", "float64", "bool", ...), which PropMetadata expects.
    return np.dtype(dtype).name
27+
28+
29+
def to_geff_from_database(
30+
database_path: Union[str, Path],
2031
filename: Union[str, Path],
2132
overwrite: bool = False,
2233
) -> None:
2334
"""
24-
Export tracks to a geff (Graph Exchange File Format) file.
35+
Export tracks to a geff (Graph Exchange File Format) file from a database.
2536
2637
Parameters
2738
----------
28-
config : MainConfig
29-
The configuration object.
39+
database_path : str or Path
40+
The path to the database file.
3041
filename : str or Path
3142
The name of the file to save the tracks to.
3243
overwrite : bool, optional
@@ -46,9 +57,39 @@ def to_geff(
4657
else:
4758
shutil.rmtree(filename)
4859

49-
engine = sqla.create_engine(config.data_config.database_path)
60+
# Convert database_path to SQLAlchemy URL format if needed
61+
database_path_str = str(database_path)
62+
# If it's not already a SQLAlchemy URL (doesn't start with a protocol), assume it's a SQLite file path
63+
if not database_path_str.startswith(
64+
("sqlite://", "postgresql://", "mysql://", "postgresql+psycopg2://")
65+
):
66+
# Convert file path to SQLite URL format
67+
database_path_str = f"sqlite:///{Path(database_path).absolute()}"
68+
engine = sqla.create_engine(database_path_str)
5069
with Session(engine) as session:
51-
node_stmt = session.query(
70+
# Collect nodes data, storing masks and bboxes separately
71+
all_nodes_data = []
72+
all_masks = []
73+
all_bboxes = []
74+
solution_source = []
75+
solution_target = []
76+
77+
for (
78+
node_id,
79+
t,
80+
parent_id,
81+
z,
82+
y,
83+
x,
84+
z_shift,
85+
y_shift,
86+
x_shift,
87+
area,
88+
frontier,
89+
height,
90+
selected,
91+
pickle_obj,
92+
) in session.query(
5293
NodeDB.id,
5394
NodeDB.t,
5495
NodeDB.parent_id,
@@ -63,59 +104,96 @@ def to_geff(
63104
NodeDB.height,
64105
NodeDB.selected,
65106
NodeDB.pickle,
66-
).statement
67-
node_df = pd.read_sql(node_stmt, session.bind, index_col="id")
68-
node_df["id"] = node_df.index
107+
):
108+
node_dict = {
109+
"id": node_id,
110+
"parent_id": parent_id,
111+
"t": t,
112+
"z": z,
113+
"y": y,
114+
"x": x,
115+
"z_shift": z_shift,
116+
"y_shift": y_shift,
117+
"x_shift": x_shift,
118+
"area": area,
119+
"frontier": frontier,
120+
"height": height,
121+
"solution": selected,
122+
}
123+
all_nodes_data.append(node_dict)
124+
# Store masks and bboxes separately
125+
all_masks.append(pickle_obj.mask.astype(np.uint64))
126+
all_bboxes.append(pickle_obj.bbox.astype(np.int64))
127+
128+
# Collect solution edges (parent-child relationships)
129+
if selected and parent_id != NO_PARENT:
130+
solution_source.append(parent_id)
131+
solution_target.append(node_id)
132+
133+
# Create nodes dataframe (only scalar values, no pickle objects)
134+
node_df = pd.DataFrame(all_nodes_data)
135+
node_df.set_index("id", inplace=True)
136+
node_df["solution"] = node_df["solution"].astype(bool)
69137

138+
# Query edges
70139
edge_stmt = session.query(
71140
LinkDB.source_id, LinkDB.target_id, LinkDB.weight
72141
).statement
73142
edge_df = pd.read_sql(edge_stmt, session.bind)
74143

75-
sol_links_df = node_df.loc[
76-
node_df["selected"] & node_df["parent_id"] != NO_PARENT,
77-
["id", "parent_id"],
78-
]
79-
sol_links_df = sol_links_df.rename(
80-
columns={"parent_id": "source_id", "id": "target_id"},
144+
# Add solution column to edges
145+
sol_links_df = pd.DataFrame(
146+
{
147+
"source_id": solution_source,
148+
"target_id": solution_target,
149+
"solution": True,
150+
}
81151
)
82-
sol_links_df["solution"] = True
83-
edge_df = edge_df.merge(sol_links_df, on=["source_id", "target_id"])
84-
edge_df["solution"] = edge_df["solution"].fillna(False)
85-
86-
node_df.rename(columns={"selected": "solution"}, inplace=True)
87-
node_df.drop(["id", "parent_id"], axis=1, inplace=True)
152+
edge_df = edge_df.merge(sol_links_df, on=["source_id", "target_id"], how="left")
153+
edge_df.loc[edge_df["solution"].isna(), "solution"] = False
154+
edge_df["solution"] = edge_df["solution"].astype(bool)
155+
if "weight" in edge_df.columns:
156+
edge_df["weight"] = edge_df["weight"].astype(np.float64)
88157

158+
# Query overlaps
89159
overlap_stmt = session.query(
90160
OverlapDB.node_id,
91161
OverlapDB.ancestor_id,
92162
).statement
93163
overlap_df = pd.read_sql(overlap_stmt, session.bind)
94164

95-
node_props_metadata = {
96-
c: PropMetadata(
165+
# Create node properties metadata
166+
node_props_metadata = {}
167+
for c in node_df.columns:
168+
node_props_metadata[c] = PropMetadata(
97169
identifier=c,
98-
dtype=node_df[c].dtype,
170+
dtype=dtype_to_str(node_df[c].dtype),
99171
)
100-
for c in node_df.columns
101-
if c != "pickle"
102-
}
103172
node_props_metadata["mask"] = PropMetadata(
104173
identifier="mask",
105-
dtype=np.uint64,
174+
dtype="uint64",
106175
varlength=True,
107176
)
108177
node_props_metadata["bbox"] = PropMetadata(
109178
identifier="bbox",
110-
dtype=np.int64,
179+
dtype="int64",
111180
)
112181

113-
edge_ids = edge_df[["source_id", "target_id"]].to_numpy(dtype=np.uint64)
182+
# Prepare edge IDs and properties
183+
edge_ids = np.column_stack(
184+
[
185+
edge_df["source_id"].to_numpy(dtype=np.uint64),
186+
edge_df["target_id"].to_numpy(dtype=np.uint64),
187+
]
188+
)
114189
edge_df = edge_df.drop(columns=["source_id", "target_id"])
115190

116-
edge_props_metadata = {
117-
c: PropMetadata(identifier=c, dtype=edge_df[c].dtype) for c in edge_df.columns
118-
}
191+
# Create edge properties metadata
192+
edge_props_metadata = {}
193+
for c in edge_df.columns:
194+
edge_props_metadata[c] = PropMetadata(
195+
identifier=c, dtype=dtype_to_str(edge_df[c].dtype)
196+
)
119197

120198
geff_metadata = geff.GeffMetadata(
121199
directed=True,
@@ -129,27 +207,32 @@ def to_geff(
129207
edge_props_metadata=edge_props_metadata,
130208
)
131209

132-
node_props = {
133-
c: {"values": node_df[c].to_numpy(), "missing": None}
134-
for c in node_df.columns
135-
if c != "pickle"
136-
}
137-
node_props["mask"] = construct_var_len_props(
138-
[v.mask.astype(np.uint64) for v in node_df["pickle"]]
139-
)
140-
node_props["bbox"] = {
141-
"values": np.stack([v.bbox for v in node_df["pickle"]]),
142-
"missing": None,
143-
}
210+
# Prepare node properties (using separately stored masks and bboxes)
211+
node_props = {}
212+
for c in node_df.columns:
213+
# Convert to appropriate numpy dtype
214+
values = node_df[c].to_numpy()
215+
node_props[c] = {"values": values, "missing": None}
216+
217+
# Handle mask - use the separately stored masks
218+
node_props["mask"] = construct_var_len_props(all_masks)
219+
220+
# Handle bbox - stack into 2D array from separately stored bboxes
221+
bbox_array = np.stack(all_bboxes)
222+
node_props["bbox"] = {"values": bbox_array, "missing": None}
223+
224+
# Prepare edge properties with proper dtypes
225+
edge_props = {}
226+
for c in edge_df.columns:
227+
values = edge_df[c].to_numpy()
228+
edge_props[c] = {"values": values, "missing": None}
144229

145230
write_arrays(
146231
filename,
147232
node_ids=node_df.index.to_numpy(dtype=np.uint64),
148233
node_props=node_props,
149234
edge_ids=edge_ids,
150-
edge_props={
151-
c: {"values": edge_df[c].to_numpy(), "missing": None} for c in edge_df
152-
},
235+
edge_props=edge_props,
153236
metadata=geff_metadata,
154237
)
155238

@@ -159,3 +242,32 @@ def to_geff(
159242
dtype=np.uint64
160243
)
161244
store.create_group("overlaps/props")
245+
246+
247+
def to_geff(
    config: MainConfig,
    filename: Union[str, Path],
    overwrite: bool = False,
) -> None:
    """
    Export tracks to a geff (Graph Exchange File Format) file.

    Thin convenience wrapper that resolves the database location from the
    configuration object and delegates to ``to_geff_from_database``.

    Parameters
    ----------
    config : MainConfig
        The configuration object.
    filename : str or Path
        The name of the file to save the tracks to.
    overwrite : bool, optional
        Whether to overwrite the file if it already exists, by default False.

    Raises
    ------
    FileExistsError
        If the file already exists and overwrite is False.
    """
    db_path = config.data_config.database_path
    to_geff_from_database(db_path, filename, overwrite)

0 commit comments

Comments
 (0)