Skip to content

Commit a2dff1d

Browse files
authored
Merge pull request #31 from Image-Analysis-Hub/fix/TM_loader
Fix/tm loader
2 parents 5ae344a + 8cdf897 commit a2dff1d

3 files changed

Lines changed: 60 additions & 64 deletions

File tree

pycellin/classes/model.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2015,14 +2015,19 @@ def save_to_pickle(
20152015
pickle.dump(self, file, protocol=protocol)
20162016

20172017
@staticmethod
2018-
def load_from_pickle(path: str) -> None:
2018+
def load_from_pickle(path: str) -> "Model":
20192019
"""
20202020
Load a model from a pickled pycellin file.
20212021
20222022
Parameters
20232023
----------
20242024
path : str
20252025
Path to read the model.
2026+
2027+
Returns
2028+
-------
2029+
Model
2030+
The loaded model.
20262031
"""
20272032
with open(path, "rb") as file:
20282033
return pickle.load(file)

pycellin/io/trackmate/loader.py

Lines changed: 49 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from copy import deepcopy
55
from datetime import datetime
6-
import importlib
6+
import importlib.metadata
77
from pathlib import Path
88
from typing import Any
99
import warnings
@@ -15,6 +15,7 @@
1515
from pycellin.classes import FeaturesDeclaration, Feature, cell_ID_Feature
1616
from pycellin.classes import Data
1717
from pycellin.classes import CellLineage
18+
from pycellin.custom_types import FeatureType
1819

1920

2021
def _get_units(
@@ -238,7 +239,7 @@ def _add_all_features(
238239
def _convert_attributes(
239240
attributes: dict[str, str],
240241
features: dict[str, Feature],
241-
feature_type: str,
242+
feature_type: FeatureType,
242243
) -> None:
243244
"""
244245
Convert the values of `attributes` from string to the correct data type.
@@ -253,7 +254,7 @@ def _convert_attributes(
253254
features : dict[str, Feature]
254255
The dictionary of features that contains the information on how to convert
255256
the values of `attributes`.
256-
feature_type : str
257+
feature_type : FeatureType
257258
The type of the feature to convert (node, edge, or lineage).
258259
259260
Raises
@@ -1058,19 +1059,11 @@ def _get_specific_tags(
10581059
was found in the XML file, and the corresponding value is the
10591060
deep copied `ET._Element` object for that tag.
10601061
"""
1061-
it = ET.iterparse(xml_path, events=["start", "end"])
10621062
dict_tags = {}
1063-
for event, element in it:
1064-
if event == "start" and element.tag in tag_names:
1063+
for tag in tag_names:
1064+
it = ET.iterparse(xml_path, tag=tag)
1065+
for _, element in it:
10651066
dict_tags[element.tag] = deepcopy(element)
1066-
tag_names.remove(element.tag)
1067-
if not tag_names:
1068-
# All the tags have been found.
1069-
break
1070-
1071-
if event == "end":
1072-
element.clear()
1073-
10741067
return dict_tags
10751068

10761069

@@ -1091,11 +1084,10 @@ def _get_trackmate_version(
10911084
The version of TrackMate used to generate the XML file. If the
10921085
version cannot be found, "unknown" is returned.
10931086
"""
1094-
it = ET.iterparse(xml_path, events=["start", "end"])
1095-
for event, element in it:
1096-
if event == "start" and element.tag == "TrackMate":
1097-
version = str(element.attrib["version"])
1098-
return version
1087+
it = ET.iterparse(xml_path, tag="TrackMate")
1088+
for _, element in it:
1089+
version = str(element.attrib["version"])
1090+
return version
10991091
return "unknown"
11001092

11011093

@@ -1120,19 +1112,17 @@ def _get_time_step(settings: ET._Element) -> float:
11201112
KeyError
11211113
If the 'ImageData' element is not found in the settings.
11221114
"""
1123-
for element in settings.iterchildren():
1124-
if element.tag == "ImageData":
1125-
try:
1126-
return float(element.attrib["timeinterval"])
1127-
except KeyError:
1128-
raise KeyError(
1129-
"The 'timeinterval' attribute is missing "
1130-
"in the 'ImageData' element."
1131-
)
1132-
except ValueError:
1133-
raise ValueError(
1134-
"The 'timeinterval' attribute cannot be converted to float."
1135-
)
1115+
for element in settings.iterchildren("ImageData"):
1116+
try:
1117+
return float(element.attrib["timeinterval"])
1118+
except KeyError:
1119+
raise KeyError(
1120+
"The 'timeinterval' attribute is missing in the 'ImageData' element."
1121+
)
1122+
except ValueError:
1123+
raise ValueError(
1124+
"The 'timeinterval' attribute cannot be converted to float."
1125+
)
11361126

11371127
raise KeyError("The 'ImageData' element is not found in the settings.")
11381128

@@ -1160,24 +1150,23 @@ def _get_pixel_size(settings: ET._Element) -> dict[str, float]:
11601150
If the 'pixelwidth', 'pixelheight' or 'voxeldepth' attribute is missing,
11611151
or if the 'ImageData' element is not found in the settings.
11621152
"""
1163-
for element in settings.iterchildren():
1164-
if element.tag == "ImageData":
1165-
pixel_size = {}
1166-
for key_TM, key_pycellin in zip(
1167-
["pixelwidth", "pixelheight", "voxeldepth"],
1168-
["width", "height", "depth"],
1169-
):
1170-
try:
1171-
pixel_size[key_pycellin] = float(element.attrib[key_TM])
1172-
except KeyError:
1173-
raise KeyError(
1174-
f"The {key_TM} attribute is missing in the 'ImageData' element."
1175-
)
1176-
except ValueError:
1177-
raise ValueError(
1178-
f"The {key_TM} attribute cannot be converted to float."
1179-
)
1180-
return pixel_size
1153+
for element in settings.iterchildren("ImageData"):
1154+
pixel_size = {}
1155+
for key_TM, key_pycellin in zip(
1156+
["pixelwidth", "pixelheight", "voxeldepth"],
1157+
["width", "height", "depth"],
1158+
):
1159+
try:
1160+
pixel_size[key_pycellin] = float(element.attrib[key_TM])
1161+
except KeyError:
1162+
raise KeyError(
1163+
f"The {key_TM} attribute is missing " "in the 'ImageData' element."
1164+
)
1165+
except ValueError:
1166+
raise ValueError(
1167+
f"The {key_TM} attribute cannot be converted to float."
1168+
)
1169+
return pixel_size
11811170

11821171
raise KeyError("The 'ImageData' element is not found in the settings.")
11831172

@@ -1229,7 +1218,7 @@ def load_TrackMate_XML(
12291218
version = importlib.metadata.version("pycellin")
12301219
except importlib.metadata.PackageNotFoundError:
12311220
version = "unknown"
1232-
metadata["Pycellin_version"] = version
1221+
metadata["pycellin_version"] = version
12331222
metadata["TrackMate_version"] = _get_trackmate_version(xml_path)
12341223
dict_tags = _get_specific_tags(
12351224
xml_path, ["Log", "Settings", "GUIState", "DisplaySettings"]
@@ -1264,19 +1253,17 @@ def load_TrackMate_XML(
12641253
# xml = "sample_data/Ecoli_growth_on_agar_pad_with_fusions.xml"
12651254
xml = "sample_data/Celegans-5pc-17timepoints.xml"
12661255

1267-
xml = "C:/Users/lxenard/Documents/Code/mastodon_data_bug_edge_pycellin/test-export2.xml"
1268-
12691256
model = load_TrackMate_XML(xml) # , keep_all_spots=True, keep_all_tracks=True)
12701257
print(model)
1271-
print(model.get_fusions())
1272-
1273-
# model = load_TrackMate_XML(xml, keep_all_spots=True, keep_all_tracks=True)
1274-
# print(model)
1275-
# print(model.feat_declaration)
1276-
# print(model.metadata["Pycellin_version"])
1277-
# # print(model.metadata)
1278-
# # print(model.fdec.node_feats.keys())
1279-
# # print(model.data)
1258+
1259+
print(model.feat_declaration)
1260+
print(model.metadata["pycellin_version"])
1261+
print(model.metadata)
1262+
# print(model.fdec.node_feats.keys())
1263+
# print(model.data)
1264+
1265+
# lineage = model.data.cell_data[0]
1266+
# lineage.plot(node_hover_features=["cell_ID", "cell_name"])
12801267

12811268
# lineage = model.data.cell_data[0]
12821269
# lineage.plot(node_hover_features=["cell_ID", "cell_name"])

tests/io/trackmate/test_loader.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1159,12 +1159,16 @@ def test_update_location_related_features_one_node():
11591159

11601160
def test_get_specific_tags():
11611161
xml_path = "sample_data/FakeTracks.xml"
1162-
tag_names = ["GUIState", "FeaturePenalties"]
1162+
tag_names = ["GUIState", "FeaturePenalties", "FilteredTracks"]
11631163
obtained = tml._get_specific_tags(xml_path, tag_names)
11641164

1165+
nested_element = ET.Element("FilteredTracks")
1166+
nested_element.append(ET.Element("TrackID", attrib={"TRACK_ID": "0"}))
1167+
nested_element.append(ET.Element("TrackID", attrib={"TRACK_ID": "4"}))
11651168
expected = {
11661169
"GUIState": ET.Element("GUIState", attrib={"state": "ConfigureViews"}),
11671170
"FeaturePenalties": ET.Element("FeaturePenalties"), # empty tag
1171+
"FilteredTracks": nested_element,
11681172
}
11691173

11701174
assert obtained.keys() == expected.keys()

0 commit comments

Comments
 (0)