Skip to content

Commit 00c2c9e

Browse files
committed
Refactor trackpy exporter
1 parent 286fe34 commit 00c2c9e

1 file changed

Lines changed: 154 additions & 53 deletions

File tree

pycellin/io/trackpy/exporter.py

Lines changed: 154 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -24,100 +24,200 @@
2424
from pycellin.classes.model import Model
2525

2626

27-
def export_trackpy_dataframe(model: Model) -> pd.DataFrame:
27+
def safekeep_original_lineage_IDs(model: Model) -> None:
2828
"""
29-
Export a Pycellin model to a trackpy DataFrame.
29+
Add original lineage IDs to the nodes of the model.
3030
31-
Trackpy does not support division events. They will be removed for
32-
the export so each cell cycle will be reprensented by a single
33-
trackpy track in the dataframe.
31+
We want to safekeep them since we are going to renumber
32+
the lineages later on.
3433
3534
Parameters
3635
----------
3736
model : Model
38-
The Pycellin model to export.
39-
40-
Returns
41-
-------
42-
pd.DataFrame
43-
A DataFrame containing trackpy formatted data.
37+
The Pycellin model to modify.
4438
"""
45-
model_copy = copy.deepcopy(model) # Don't want to modify the original model.
46-
47-
# We want to safekeep the original lineage IDs in the nodes of the model since
48-
# we are going to rename and/or renumber them.
49-
for lin_ID, lin in model_copy.data.cell_data.items():
39+
for lin_ID, lin in model.data.cell_data.items():
5040
for node in lin.nodes():
5141
lin.nodes[node]["lineage_ID_Pycellin"] = lin_ID
5242

53-
# Removal of division events.
54-
# We simply remove the edges involved in the divisions.
55-
for lin in model_copy.get_cell_lineages():
43+
44+
def remove_division_events(model: Model) -> None:
45+
"""
46+
Remove division events by deleting edges involved in divisions.
47+
48+
Parameters
49+
----------
50+
model : Model
51+
The Pycellin model to modify.
52+
"""
53+
for lin in model.get_cell_lineages():
5654
divs = lin.get_divisions()
5755
div_edges = [edge for div in divs for edge in lin.out_edges(div)]
5856
for edge in div_edges:
59-
model_copy.remove_link(*edge, lin.graph["lineage_ID"])
60-
model_copy.update()
57+
model.remove_link(*edge, lin.graph["lineage_ID"])
58+
model.update()
59+
6160

62-
# Trackpy might not like negative lineage IDs so we change them to positive ones.
61+
def renumber_negative_lineage_IDs(model: Model) -> None:
62+
"""
63+
Ensure lineage IDs are positive.
64+
65+
Trackpy might not support negative lineage IDs so it is safer to
66+
renumber them to positive ones.
67+
68+
Parameters
69+
----------
70+
model : Model
71+
The Pycellin model to modify.
72+
"""
6373
one_node_lin_IDs = [
6474
lin.graph["lineage_ID"]
65-
for lin in model_copy.get_cell_lineages()
75+
for lin in model.get_cell_lineages()
6676
if lin.graph["lineage_ID"] < 0
6777
]
6878
for lin_ID in one_node_lin_IDs:
69-
lin = model_copy.get_cell_lineage_from_ID(lin_ID)
79+
lin = model.get_cell_lineage_from_ID(lin_ID)
7080
assert lin is not None
71-
new_lin_ID = model_copy.get_next_available_lineage_ID()
81+
new_lin_ID = model.get_next_available_lineage_ID()
7282
# Update the lineage ID in the graph.
7383
lin.graph["lineage_ID"] = new_lin_ID
7484
# Update the lineage ID in the cell data.
75-
model_copy.data.cell_data.pop(lin_ID)
76-
model_copy.data.cell_data[new_lin_ID] = lin
85+
model.data.cell_data.pop(lin_ID)
86+
model.data.cell_data[new_lin_ID] = lin
7787

78-
# Creation of the trackpy DataFrame.
79-
df = model_copy.to_cell_dataframe()
80-
# We have to rename some columns to be compatible with trackpy.
81-
if "particle" in df.columns:
82-
# If we already have this column, it means the data is coming from
83-
# trackpy, but it might not be up to date. Safer to remove it and
84-
# rename it from "lineage_ID".
85-
df.drop(columns=["particle"], inplace=True)
86-
df.rename(columns={"lineage_ID": "particle"}, inplace=True)
87-
df.rename(columns={"cell_x": "x"}, inplace=True)
88-
df.rename(columns={"cell_y": "y"}, inplace=True)
89-
if "cell_z" in df.columns:
90-
df.rename(columns={"cell_z": "z"}, inplace=True)
91-
if "ROI_coords" in df.columns:
92-
# We need to remove the ROI_coords column.
93-
df.drop(columns=["ROI_coords"], inplace=True)
94-
# Reorder the columns to match trackpy format.
95-
if "z" in df.columns:
96-
dim_columns = ["z", "y", "x"]
97-
else:
98-
dim_columns = ["y", "x"]
88+
89+
def rename_columns_if_exist(df, columns_map):
90+
"""
91+
Helper function to rename columns if they exist in the DataFrame.
92+
93+
Parameters
94+
----------
95+
df : pd.DataFrame
96+
The DataFrame to modify.
97+
columns_map : dict
98+
A dictionary mapping old column names to new column names.
99+
"""
100+
for old_name, new_name in columns_map.items():
101+
if old_name in df.columns:
102+
df.rename(columns={old_name: new_name}, inplace=True)
103+
104+
105+
def drop_columns_if_exist(df, columns):
106+
"""
107+
Helper function to drop columns if they exist in the DataFrame.
108+
109+
Parameters
110+
----------
111+
df : pd.DataFrame
112+
The DataFrame to modify.
113+
columns : list
114+
The names of the columns to drop.
115+
"""
116+
for column in columns:
117+
if column in df.columns:
118+
df.drop(columns=[column], inplace=True)
119+
120+
121+
def format_dataframe(df: pd.DataFrame) -> pd.DataFrame:
122+
"""
123+
Format the DataFrame to be compatible with trackpy.
124+
125+
Parameters
126+
----------
127+
df : pd.DataFrame
128+
The DataFrame to format.
129+
130+
Returns
131+
-------
132+
pd.DataFrame
133+
The formatted DataFrame.
134+
"""
135+
# Drop unnecessary columns.
136+
drop_columns_if_exist(df, ["ROI_coords", "particle"])
137+
# If we already have the "particle" column, it means the data is coming from
138+
# trackpy, but it might not be up to date. Safer to remove it then recreate
139+
# it from "lineage_ID".
140+
141+
# Rename columns to match trackpy format.
142+
rename_columns_if_exist(
143+
df,
144+
{
145+
"cell_x": "x",
146+
"cell_y": "y",
147+
"cell_z": "z",
148+
"lineage_ID": "particle",
149+
},
150+
)
151+
152+
# Reorder columns to match trackpy format
153+
dim_columns = ["z", "y", "x"] if "z" in df.columns else ["y", "x"]
99154
df = df[
100155
dim_columns
101-
+ [col for col in df.columns if col not in ["z", "y", "x", "frame", "particle"]]
156+
+ [col for col in df.columns if col not in dim_columns + ["frame", "particle"]]
102157
+ ["frame", "particle"]
103158
]
159+
104160
# Sort the rows.
105161
df.sort_values(by=["particle", "frame"], inplace=True)
106162

107163
return df
108164

109165

166+
def export_trackpy_dataframe(model: Model) -> pd.DataFrame:
167+
"""
168+
Export a Pycellin model to a trackpy DataFrame.
169+
170+
Trackpy does not support division events. They will be removed for
171+
the export so each cell cycle will be reprensented by a single
172+
trackpy track in the dataframe.
173+
174+
Parameters
175+
----------
176+
model : Model
177+
The Pycellin model to export.
178+
179+
Returns
180+
-------
181+
pd.DataFrame
182+
A DataFrame containing trackpy formatted data.
183+
"""
184+
# Prepare the model for export.
185+
model_copy = copy.deepcopy(model) # Don't want to modify the original model.
186+
safekeep_original_lineage_IDs(model_copy)
187+
remove_division_events(model_copy) # Trackpy does not support division events.
188+
renumber_negative_lineage_IDs(model_copy)
189+
190+
# Creation of the trackpy DataFrame.
191+
df = model_copy.to_cell_dataframe()
192+
df = format_dataframe(df)
193+
194+
return df
195+
196+
110197
if __name__ == "__main__":
198+
199+
# # Test with a sample TrackMate XML file.
200+
# from pycellin import load_TrackMate_XML
201+
202+
# xml = "sample_data/Ecoli_growth_on_agar_pad.xml"
203+
204+
# model = load_TrackMate_XML(xml)
205+
# for lin in model.get_cell_lineages():
206+
# print(lin)
207+
208+
# df = export_trackpy_dataframe(model)
209+
# print(df.head())
210+
211+
# Test with a sample trackpy DataFrame.
212+
from pycellin import load_trackpy_dataframe
213+
111214
folder = "/mnt/data/Code/trackpy-examples-master/sample_data/"
112215
tracks = "FakeTracks_trackpy.pkl"
113-
xml = "sample_data/Ecoli_growth_on_agar_pad.xml"
114216

115217
df = pd.read_pickle(folder + tracks)
116218
print(df.head())
219+
print(df.shape)
117220

118-
from pycellin import load_trackpy_dataframe, load_TrackMate_XML
119-
120-
# model = load_TrackMate_XML(xml)
121221
model = load_trackpy_dataframe(df)
122222
for lin in model.get_cell_lineages():
123223
print(lin)
@@ -128,3 +228,4 @@ def export_trackpy_dataframe(model: Model) -> pd.DataFrame:
128228

129229
df = export_trackpy_dataframe(model)
130230
print(df.head())
231+
print(df.shape)

0 commit comments

Comments
 (0)