Skip to content

Commit 5824ece

Browse files
committed
Updated the to_geff export function and added a CLI command for exporting directly from a database file.
1 parent 71537e2 commit 5824ece

File tree

3 files changed

+214
-51
lines changed

3 files changed

+214
-51
lines changed

ultrack/cli/export.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,13 @@
1515
tuple_callback,
1616
)
1717
from ultrack.config import MainConfig
18-
from ultrack.core.export import to_ctc, to_trackmate, to_tracks_layer, tracks_to_zarr
18+
from ultrack.core.export import (
19+
to_ctc,
20+
to_geff_from_database,
21+
to_trackmate,
22+
to_tracks_layer,
23+
tracks_to_zarr,
24+
)
1925
from ultrack.core.solve.sqltracking import SQLTracking
2026
from ultrack.imgproc.measure import tracks_properties
2127
from ultrack.utils.data import validate_and_overwrite_path
@@ -187,6 +193,45 @@ def trackmate_cli(
187193
to_trackmate(config, output_path, overwrite)
188194

189195

196+
@click.command("geff")
197+
@click.argument(
198+
"database_path",
199+
type=click.Path(path_type=Path, exists=True),
200+
)
201+
@click.option(
202+
"--output-path",
203+
"-o",
204+
required=False,
205+
type=click.Path(path_type=Path),
206+
default=None,
207+
help=(
208+
"Geff (Graph Exchange File Format) output path. "
209+
"If not provided, saves to same directory as database with '_geff.geff' extension."
210+
),
211+
)
212+
@overwrite_option()
213+
def geff_cli(
214+
database_path: Path,
215+
output_path: Optional[Path],
216+
overwrite: bool,
217+
) -> None:
218+
"""
219+
Exports tracking results to Geff (Graph Exchange File Format) format.
220+
"""
221+
if output_path is None:
222+
# Generate output path from database path
223+
output_path = database_path.parent / f"{database_path.stem}_geff.geff"
224+
else:
225+
# Validate that the output path has a geff extension
226+
output_str = str(output_path)
227+
if not (output_str.endswith(".geff") or output_str.endswith(".geff.zarr")):
228+
raise click.BadParameter(
229+
f"Output path must have a .geff or .geff.zarr extension, got: {output_path}"
230+
)
231+
232+
to_geff_from_database(database_path, output_path, overwrite)
233+
234+
190235
@click.command("lp")
191236
@click.option(
192237
"--output-path",
@@ -229,6 +274,7 @@ def export_cli() -> None:
229274

230275

231276
export_cli.add_command(ctc_cli)
277+
export_cli.add_command(geff_cli)
232278
export_cli.add_command(lp_cli)
233279
export_cli.add_command(trackmate_cli)
234280
export_cli.add_command(zarr_napari_cli)

ultrack/core/export/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from ultrack.core.export.ctc import to_ctc
22
from ultrack.core.export.exporter import export_tracks_by_extension
3-
from ultrack.core.export.geff import to_geff
3+
from ultrack.core.export.geff import to_geff, to_geff_from_database
44
from ultrack.core.export.networkx import to_networkx, tracks_layer_to_networkx
55
from ultrack.core.export.trackmate import to_trackmate, tracks_layer_to_trackmate
66
from ultrack.core.export.tracks_layer import to_tracks_layer

ultrack/core/export/geff.py

Lines changed: 166 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,18 @@
1515
from ultrack.core.database import NO_PARENT, LinkDB, NodeDB, OverlapDB
1616

1717

18-
def to_geff(
19-
config: MainConfig,
18+
def to_geff_from_database(
19+
database_path: Union[str, Path],
2020
filename: Union[str, Path],
2121
overwrite: bool = False,
2222
) -> None:
2323
"""
24-
Export tracks to a geff (Graph Exchange File Format) file.
24+
Export tracks to a geff (Graph Exchange File Format) file from a database.
2525
2626
Parameters
2727
----------
28-
config : MainConfig
29-
The configuration object.
28+
database_path : str or Path
29+
The path to the database file.
3030
filename : str or Path
3131
The name of the file to save the tracks to.
3232
overwrite : bool, optional
@@ -46,9 +46,39 @@ def to_geff(
4646
else:
4747
shutil.rmtree(filename)
4848

49-
engine = sqla.create_engine(config.data_config.database_path)
49+
# Convert database_path to SQLAlchemy URL format if needed
50+
database_path_str = str(database_path)
51+
# If it's not already a SQLAlchemy URL (doesn't start with a protocol), assume it's a SQLite file path
52+
if not database_path_str.startswith(
53+
("sqlite://", "postgresql://", "mysql://", "postgresql+psycopg2://")
54+
):
55+
# Convert file path to SQLite URL format
56+
database_path_str = f"sqlite:///{Path(database_path).absolute()}"
57+
engine = sqla.create_engine(database_path_str)
5058
with Session(engine) as session:
51-
node_stmt = session.query(
59+
# Collect nodes data, storing masks and bboxes separately
60+
all_nodes_data = []
61+
all_masks = []
62+
all_bboxes = []
63+
solution_source = []
64+
solution_target = []
65+
66+
for (
67+
node_id,
68+
t,
69+
parent_id,
70+
z,
71+
y,
72+
x,
73+
z_shift,
74+
y_shift,
75+
x_shift,
76+
area,
77+
frontier,
78+
height,
79+
selected,
80+
pickle_obj,
81+
) in session.query(
5282
NodeDB.id,
5383
NodeDB.t,
5484
NodeDB.parent_id,
@@ -63,59 +93,103 @@ def to_geff(
6393
NodeDB.height,
6494
NodeDB.selected,
6595
NodeDB.pickle,
66-
).statement
67-
node_df = pd.read_sql(node_stmt, session.bind, index_col="id")
68-
node_df["id"] = node_df.index
96+
):
97+
node_dict = {
98+
"id": node_id,
99+
"parent_id": parent_id,
100+
"t": t,
101+
"z": z,
102+
"y": y,
103+
"x": x,
104+
"z_shift": z_shift,
105+
"y_shift": y_shift,
106+
"x_shift": x_shift,
107+
"area": area,
108+
"frontier": frontier,
109+
"height": height,
110+
"solution": selected,
111+
}
112+
all_nodes_data.append(node_dict)
113+
# Store masks and bboxes separately
114+
all_masks.append(pickle_obj.mask.astype(np.uint64))
115+
all_bboxes.append(pickle_obj.bbox.astype(np.int64))
116+
117+
# Collect solution edges (parent-child relationships)
118+
if selected and parent_id != NO_PARENT:
119+
solution_source.append(parent_id)
120+
solution_target.append(node_id)
121+
122+
# Create nodes dataframe (only scalar values, no pickle objects)
123+
node_df = pd.DataFrame(all_nodes_data)
124+
node_df.set_index("id", inplace=True)
69125

126+
# Query edges
70127
edge_stmt = session.query(
71128
LinkDB.source_id, LinkDB.target_id, LinkDB.weight
72129
).statement
73130
edge_df = pd.read_sql(edge_stmt, session.bind)
74131

75-
sol_links_df = node_df.loc[
76-
node_df["selected"] & node_df["parent_id"] != NO_PARENT,
77-
["id", "parent_id"],
78-
]
79-
sol_links_df = sol_links_df.rename(
80-
columns={"parent_id": "source_id", "id": "target_id"},
132+
# Add solution column to edges
133+
sol_links_df = pd.DataFrame(
134+
{
135+
"source_id": solution_source,
136+
"target_id": solution_target,
137+
"solution": True,
138+
}
81139
)
82-
sol_links_df["solution"] = True
83-
edge_df = edge_df.merge(sol_links_df, on=["source_id", "target_id"])
84-
edge_df["solution"] = edge_df["solution"].fillna(False)
85-
86-
node_df.rename(columns={"selected": "solution"}, inplace=True)
87-
node_df.drop(["id", "parent_id"], axis=1, inplace=True)
140+
edge_df = edge_df.merge(sol_links_df, on=["source_id", "target_id"], how="left")
141+
edge_df.loc[edge_df["solution"].isna(), "solution"] = False
142+
edge_df["solution"] = edge_df["solution"].astype(bool)
88143

144+
# Query overlaps
89145
overlap_stmt = session.query(
90146
OverlapDB.node_id,
91147
OverlapDB.ancestor_id,
92148
).statement
93149
overlap_df = pd.read_sql(overlap_stmt, session.bind)
94150

95-
node_props_metadata = {
96-
c: PropMetadata(
151+
# Helper function to convert pandas/numpy dtypes to string dtype names
152+
def dtype_to_str(dtype) -> str:
153+
"""Convert pandas/numpy dtype to string dtype name for PropMetadata."""
154+
# Convert to numpy dtype first to get consistent .name attribute
155+
np_dtype = np.dtype(dtype)
156+
dtype_name = np_dtype.name
157+
158+
# Most dtypes work directly (int64, float64, bool, etc.)
159+
return dtype_name
160+
161+
# Create node properties metadata
162+
node_props_metadata = {}
163+
for c in node_df.columns:
164+
node_props_metadata[c] = PropMetadata(
97165
identifier=c,
98-
dtype=node_df[c].dtype,
166+
dtype=dtype_to_str(node_df[c].dtype),
99167
)
100-
for c in node_df.columns
101-
if c != "pickle"
102-
}
103168
node_props_metadata["mask"] = PropMetadata(
104169
identifier="mask",
105-
dtype=np.uint64,
170+
dtype="uint64",
106171
varlength=True,
107172
)
108173
node_props_metadata["bbox"] = PropMetadata(
109174
identifier="bbox",
110-
dtype=np.int64,
175+
dtype="int64",
111176
)
112177

113-
edge_ids = edge_df[["source_id", "target_id"]].to_numpy(dtype=np.uint64)
178+
# Prepare edge IDs and properties
179+
edge_ids = np.column_stack(
180+
[
181+
edge_df["source_id"].to_numpy(dtype=np.uint64),
182+
edge_df["target_id"].to_numpy(dtype=np.uint64),
183+
]
184+
)
114185
edge_df = edge_df.drop(columns=["source_id", "target_id"])
115186

116-
edge_props_metadata = {
117-
c: PropMetadata(identifier=c, dtype=edge_df[c].dtype) for c in edge_df.columns
118-
}
187+
# Create edge properties metadata
188+
edge_props_metadata = {}
189+
for c in edge_df.columns:
190+
edge_props_metadata[c] = PropMetadata(
191+
identifier=c, dtype=dtype_to_str(edge_df[c].dtype)
192+
)
119193

120194
geff_metadata = geff.GeffMetadata(
121195
directed=True,
@@ -129,27 +203,41 @@ def to_geff(
129203
edge_props_metadata=edge_props_metadata,
130204
)
131205

132-
node_props = {
133-
c: {"values": node_df[c].to_numpy(), "missing": None}
134-
for c in node_df.columns
135-
if c != "pickle"
136-
}
137-
node_props["mask"] = construct_var_len_props(
138-
[v.mask.astype(np.uint64) for v in node_df["pickle"]]
139-
)
140-
node_props["bbox"] = {
141-
"values": np.stack([v.bbox for v in node_df["pickle"]]),
142-
"missing": None,
143-
}
206+
# Prepare node properties (using separately stored masks and bboxes)
207+
node_props = {}
208+
for c in node_df.columns:
209+
# Convert to appropriate numpy dtype
210+
values = node_df[c].to_numpy()
211+
# Ensure bool columns are properly typed
212+
if c == "solution":
213+
values = values.astype(bool)
214+
node_props[c] = {"values": values, "missing": None}
215+
216+
# Handle mask - use the separately stored masks
217+
node_props["mask"] = construct_var_len_props(all_masks)
218+
219+
# Handle bbox - stack into 2D array from separately stored bboxes
220+
bbox_array = np.stack(all_bboxes)
221+
node_props["bbox"] = {"values": bbox_array, "missing": None}
222+
223+
# Prepare edge properties with proper dtypes
224+
edge_props = {}
225+
for c in edge_df.columns:
226+
values = edge_df[c].to_numpy()
227+
# Ensure bool columns are properly typed
228+
if c == "solution":
229+
values = values.astype(bool)
230+
# Ensure weight is float
231+
elif c == "weight":
232+
values = values.astype(np.float64)
233+
edge_props[c] = {"values": values, "missing": None}
144234

145235
write_arrays(
146236
filename,
147237
node_ids=node_df.index.to_numpy(dtype=np.uint64),
148238
node_props=node_props,
149239
edge_ids=edge_ids,
150-
edge_props={
151-
c: {"values": edge_df[c].to_numpy(), "missing": None} for c in edge_df
152-
},
240+
edge_props=edge_props,
153241
metadata=geff_metadata,
154242
)
155243

@@ -159,3 +247,32 @@ def to_geff(
159247
dtype=np.uint64
160248
)
161249
store.create_group("overlaps/props")
250+
251+
def to_geff(
    config: MainConfig,
    filename: Union[str, Path],
    overwrite: bool = False,
) -> None:
    """
    Export tracks to a geff (Graph Exchange File Format) file.

    Thin wrapper that reads the database location from *config* and
    delegates the actual export to ``to_geff_from_database``.

    Parameters
    ----------
    config : MainConfig
        The configuration object.
    filename : str or Path
        The name of the file to save the tracks to.
    overwrite : bool, optional
        Whether to overwrite the file if it already exists, by default False.

    Raises
    ------
    FileExistsError
        If the file already exists and overwrite is False.
    """
    db_path = config.data_config.database_path
    to_geff_from_database(db_path, filename, overwrite)

0 commit comments

Comments
 (0)