Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 85 additions & 34 deletions sleap/gui/learning/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from sleap.gui.dialogs.formbuilder import FieldComboWidget
from omegaconf import OmegaConf
from sleap.util import show_sleap_nn_installation_message
from sleap.gui.learning.load_legacy_metrics import load_npz_extract_arrays

logging.basicConfig(level=logging.DEBUG)

Expand Down Expand Up @@ -148,14 +149,38 @@ def validation_frame_count(self):
@property
def timestamp(self):
    """Timestamp encoded in the run name or path (not the OS timestamp).

    Looks for a ``YYMMDD_HHMMSS`` token, first in the config's run name
    (sleap-nn ``trainer_config.run_name``, or legacy ``outputs.run_name``
    for SLEAP <= v1.4.1 configs stored as OmegaConf or dict), then in
    ``self.path``.

    Returns:
        A ``datetime.datetime`` parsed from the first matching token, or
        ``None`` if no timestamp can be found.
    """

    def _parse_timestamp(text):
        """Return a datetime parsed from a YYMMDD_HHMMSS token, else None."""
        match = re.match(
            r".*?(?<!\d)(\d{2})(\d{2})(\d{2})_(\d{2})(\d{2})(\d{2})\b",
            text,
        )
        if not match:
            return None
        year, month, day = int(match[1]), int(match[2]), int(match[3])
        hour, minute, sec = int(match[4]), int(match[5]), int(match[6])
        # Two-digit years are assumed to be in the 2000s.
        return datetime.datetime(2000 + year, month, day, hour, minute, sec)

    # Try to get run_name from config (handles both sleap-nn and legacy formats).
    run_name = None
    try:
        # sleap-nn format.
        run_name = self.config.trainer_config.run_name
    except (AttributeError, TypeError):
        try:
            # Legacy SLEAP format (OmegaConf or dict).
            if hasattr(self.config, "outputs"):
                run_name = self.config.outputs.run_name
            elif isinstance(self.config, dict):
                run_name = self.config.get("outputs", {}).get("run_name")
        except (AttributeError, TypeError):
            pass

    # Prefer the run name; fall back to the path when it has no timestamp.
    if run_name:
        parsed = _parse_timestamp(run_name)
        if parsed is not None:
            return parsed

    if self.path:
        parsed = _parse_timestamp(self.path)
        if parsed is not None:
            return parsed

    return None

Expand All @@ -180,35 +205,61 @@ def _get_metrics(self, split_name: Text):
metrics_path_nn = self._get_file_path(f"{split_name}_0_pred_metrics.npz")

if metrics_path_nn is None:
# Loading legacy metrics from SLEAP <= v1.4.1
metrics_path = self._get_file_path(f"metrics.{split_name}.npz")
if metrics_path is not None:
metric_data = load_npz_extract_arrays(metrics_path)
return_dict = {
"vis.tp": metric_data.get("metrics[0].vis.tp").item(),
"vis.fp": metric_data.get("metrics[0].vis.fp").item(),
"vis.tn": metric_data.get("metrics[0].vis.tn").item(),
"vis.fn": metric_data.get("metrics[0].vis.fn").item(),
"vis.precision": metric_data.get("metrics[0].vis.precision").item(),
"vis.recall": metric_data.get("metrics[0].vis.recall").item(),
"dist.dists": metric_data.get("metrics[0].dist.dists"),
"dist.avg": metric_data.get("metrics[0].dist.avg").item(),
"dist.p50": metric_data.get("metrics[0].dist.p50").item(),
"dist.p75": metric_data.get("metrics[0].dist.p75").item(),
"dist.p90": metric_data.get("metrics[0].dist.p90").item(),
"dist.p95": metric_data.get("metrics[0].dist.p95").item(),
"dist.p99": metric_data.get("metrics[0].dist.p99").item(),
"pck.mPCK": metric_data.get("metrics[0].pck.mPCK").item(),
"oks.mOKS": metric_data.get("metrics[0].oks.mOKS").item(),
"oks_voc.mAP": metric_data.get("metrics[0].oks_voc.mAP").item(),
"oks_voc.mAR": metric_data.get("metrics[0].oks_voc.mAR").item(),
"pck_voc.mAP": metric_data.get("metrics[0].pck_voc.mAP").item(),
"pck_voc.mAR": metric_data.get("metrics[0].pck_voc.mAR").item(),
}
return return_dict

else:
metrics_path = metrics_path_nn

with np.load(metrics_path, allow_pickle=True) as data:
metric_data = data["metrics"].item()

return_dict = {
"vis.tp": metric_data["visibility_metrics"].get("tp"),
"vis.fp": metric_data["visibility_metrics"].get("fp"),
"vis.tn": metric_data["visibility_metrics"].get("tn"),
"vis.fn": metric_data["visibility_metrics"].get("fn"),
"vis.precision": metric_data["visibility_metrics"].get("precision"),
"vis.recall": metric_data["visibility_metrics"].get("recall"),
"dist.dists": metric_data["distance_metrics"].get("dists"),
"dist.avg": metric_data["distance_metrics"].get("avg"),
"dist.p50": metric_data["distance_metrics"].get("p50"),
"dist.p75": metric_data["distance_metrics"].get("p75"),
"dist.p90": metric_data["distance_metrics"].get("p90"),
"dist.p95": metric_data["distance_metrics"].get("p95"),
"dist.p99": metric_data["distance_metrics"].get("p99"),
"pck.mPCK": metric_data["pck_metrics"].get("mPCK"),
"oks.mOKS": metric_data["mOKS"].get("mOKS"),
"oks_voc.mAP": metric_data["voc_metrics"].get("oks_voc.mAP"),
"oks_voc.mAR": metric_data["voc_metrics"].get("oks_voc.mAR"),
"pck_voc.mAP": metric_data["voc_metrics"].get("pck_voc.mAP"),
"pck_voc.mAR": metric_data["voc_metrics"].get("pck_voc.mAR"),
}
return return_dict
with np.load(metrics_path, allow_pickle=True) as data:
metric_data = data["metrics"].item()

return_dict = {
"vis.tp": metric_data["visibility_metrics"].get("tp"),
"vis.fp": metric_data["visibility_metrics"].get("fp"),
"vis.tn": metric_data["visibility_metrics"].get("tn"),
"vis.fn": metric_data["visibility_metrics"].get("fn"),
"vis.precision": metric_data["visibility_metrics"].get("precision"),
"vis.recall": metric_data["visibility_metrics"].get("recall"),
"dist.dists": metric_data["distance_metrics"].get("dists"),
"dist.avg": metric_data["distance_metrics"].get("avg"),
"dist.p50": metric_data["distance_metrics"].get("p50"),
"dist.p75": metric_data["distance_metrics"].get("p75"),
"dist.p90": metric_data["distance_metrics"].get("p90"),
"dist.p95": metric_data["distance_metrics"].get("p95"),
"dist.p99": metric_data["distance_metrics"].get("p99"),
"pck.mPCK": metric_data["pck_metrics"].get("mPCK"),
"oks.mOKS": metric_data["mOKS"].get("mOKS"),
"oks_voc.mAP": metric_data["voc_metrics"].get("oks_voc.mAP"),
"oks_voc.mAR": metric_data["voc_metrics"].get("oks_voc.mAR"),
"pck_voc.mAP": metric_data["voc_metrics"].get("pck_voc.mAP"),
"pck_voc.mAR": metric_data["voc_metrics"].get("pck_voc.mAR"),
}
return return_dict

@classmethod
def from_config_file(cls, path: Text) -> "ConfigFileInfo":
Expand Down Expand Up @@ -572,7 +623,7 @@ def try_loading_path(self, path: Text) -> Optional[ConfigFileInfo]:
return None
except Exception as e:
# Couldn't load so just ignore
print(f"Couldn't load config: {e}")
print(f"Couldn't load config from `{path}`: {e}")
pass
else:
# Get the head from the model (i.e., what the model will predict)
Expand Down
129 changes: 129 additions & 0 deletions sleap/gui/learning/load_legacy_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
"""Load legacy metrics from SLEAP <= v1.4.1"""

import numpy as np
import pickle


class MockSLEAPArray(np.ndarray):
    """Stand-in ndarray subclass used when unpickling SLEAP array objects.

    SLEAP's ``PointArray`` subclasses ``np.ndarray``, so the replacement
    class must also be an ndarray subclass for numpy's unpickling to work
    properly.
    """

    def __new__(cls, shape=(0,), dtype=float):
        return super().__new__(cls, shape, dtype)

    def __array_finalize__(self, obj):
        # Required hook for ndarray subclasses; nothing to initialize here.
        pass


class MockSLEAPObject:
    """Placeholder class for non-array SLEAP objects during unpickling."""


class CustomUnpickler(pickle.Unpickler):
    """Unpickler that substitutes mock classes for unavailable SLEAP classes.

    When the pickle stream references ``sleap.instance.PointArray`` (or any
    other class from a ``sleap`` module we do not have), a mock class is
    returned instead so that unpickling can complete.
    """

    def find_class(self, module, name):
        """Resolve a pickled class, mocking anything from a sleap module."""
        if "sleap" not in module.lower():
            # Everything non-SLEAP resolves normally.
            return super().find_class(module, name)
        # Array-like SLEAP classes need an ndarray-subclass mock.
        if "array" in name.lower():
            return MockSLEAPArray
        return MockSLEAPObject


def extract_arrays_from_object(obj, prefix="", arrays=None):
    """Recursively collect numpy arrays and numeric scalars from an object.

    Walks dicts, lists/tuples, object attributes, and object-dtype arrays,
    recording every numeric ndarray and scalar value under a dotted/indexed
    key derived from its location.

    Args:
        obj: Value to search.
        prefix: Key prefix accumulated during recursion.
        arrays: Dict to accumulate results into (created when ``None``).

    Returns:
        Dict mapping derived key names to numpy arrays and numeric values.
    """
    found = {} if arrays is None else arrays

    if isinstance(obj, np.ndarray):
        if obj.dtype == np.dtype("object"):
            # Object array: descend into every element.
            for idx, element in enumerate(obj.flat):
                extract_arrays_from_object(element, f"{prefix}[{idx}]", found)
        else:
            # Numeric data array: record it directly.
            found[prefix or "array"] = obj
    elif isinstance(obj, (int, float, np.number)):
        # Scalar numeric leaf value.
        found[prefix] = obj
    elif isinstance(obj, dict):
        # Recurse into dictionary values.
        for key, value in obj.items():
            child = f"{prefix}.{key}" if prefix else key
            extract_arrays_from_object(value, child, found)
    elif isinstance(obj, (list, tuple)):
        # Recurse into sequence elements.
        for idx, element in enumerate(obj):
            extract_arrays_from_object(element, f"{prefix}[{idx}]", found)
    elif hasattr(obj, "__dict__"):
        # Recurse into public object attributes.
        for key, value in obj.__dict__.items():
            if key.startswith("_"):
                continue  # Skip private attributes.
            child = f"{prefix}.{key}" if prefix else key
            extract_arrays_from_object(value, child, found)

    return found


def load_npz_extract_arrays(npz_file):
    """Load a ``.npz`` file, extracting arrays even from pickled objects.

    Entries stored with a non-object dtype are loaded directly. Entries with
    an object dtype are unpickled with :class:`CustomUnpickler` (so missing
    SLEAP classes are replaced by mocks) and then searched recursively for
    arrays and numeric values.

    Args:
        npz_file: Path or file-like object of the ``.npz`` archive.

    Returns:
        Dict mapping derived key names to numpy arrays and numeric values.

    Raises:
        ValueError: If an entry uses an unsupported ``.npy`` format version.
    """
    import io
    import zipfile
    import numpy.lib.format as fmt

    # Version-specific header readers. NumPy 2.x removed the private header
    # helpers, so only the public per-version readers are used.
    header_readers = {
        (1, 0): fmt.read_array_header_1_0,
        (2, 0): fmt.read_array_header_2_0,
    }

    extracted = {}

    # An .npz archive is just a zip of .npy members.
    with zipfile.ZipFile(npz_file, "r") as archive:
        for member in archive.namelist():
            key = member.replace(".npy", "")
            buffer = io.BytesIO(archive.read(member))

            # Read the magic prefix to learn the .npy format version.
            version = fmt.read_magic(buffer)
            read_header = header_readers.get(version)
            if read_header is None:
                raise ValueError(f"Unsupported .npy format version: {version}")
            _shape, _fortran, dtype = read_header(buffer)

            if dtype == np.dtype("object"):
                # Pickled payload: the stream is positioned right after the
                # header, so unpickle in place and mine it for arrays.
                payload = CustomUnpickler(buffer).load()
                extracted.update(extract_arrays_from_object(payload, prefix=key))
            else:
                # Plain numeric array: rewind and load it normally.
                buffer.seek(0)
                extracted[key] = np.load(buffer, allow_pickle=False)

    return extracted
28 changes: 15 additions & 13 deletions sleap/info/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sleap.sleap_io_adaptors.lf_labels_utils import get_labeled_frame_count
from sleap.sleap_io_adaptors.lf_labels_utils import labels_load_file
from sleap.util import show_sleap_nn_installation_message
from sleap.gui.learning.load_legacy_metrics import load_npz_extract_arrays


def describe_labels(data_path, verbose=False):
Expand Down Expand Up @@ -119,20 +120,21 @@ def rel_path(x):
def describe_metrics(metrics, legacy):
if legacy:
if isinstance(metrics, str):
metrics = np.load(metrics, allow_pickle=True)["metrics"].tolist()
metrics = load_npz_extract_arrays(metrics)

print(
f"Dist (90%/95%/99%): {metrics['dist.p90']} / {metrics['dist.p95']} / "
f"{metrics['dist.p99']}"
)
print(
f"OKS VOC (mAP / mAR): {metrics['oks_voc.mAP']} / "
f"{metrics['oks_voc.mAR']}"
)
print(
f"PCK (mean {metrics['pck.thresholds'][0]}-"
f"{metrics['pck.thresholds'][-1]} px): {metrics['pck.mPCK']}"
)
p90 = metrics.get("metrics[0].dist.p90").item()
p95 = metrics.get("metrics[0].dist.p95").item()
p99 = metrics.get("metrics[0].dist.p99").item()
print(f"Dist (90%/95%/99%): {p90} / {p95} / {p99}")

oks_map = metrics.get("metrics[0].oks_voc.mAP").item()
oks_mar = metrics.get("metrics[0].oks_voc.mAR").item()
print(f"OKS VOC (mAP / mAR): {oks_map} / {oks_mar}")

pck_min = metrics.get("metrics[0].pck.thresholds")[0]
pck_max = metrics.get("metrics[0].pck.thresholds")[-1]
mpck = metrics.get("metrics[0].pck.mPCK").item()
print(f"PCK (mean {pck_min}-{pck_max} px): {mpck}")
else:
if isinstance(metrics, str):
with np.load(metrics, allow_pickle=True) as data:
Expand Down