Skip to content
52 changes: 38 additions & 14 deletions src/spyglass/decoding/v1/clusterless.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import pandas as pd
import xarray as xr
from non_local_detector.models.base import ClusterlessDetector
from scipy.ndimage import label
from track_linearization import get_linearized_position

from spyglass.common.common_interval import IntervalList # noqa: F401
Expand Down Expand Up @@ -237,7 +238,9 @@ def _run_decoder(
classifier : ClusterlessDetector
Fitted classifier instance
results : xr.Dataset
Decoding results with posteriors
Decoding results with posteriors. Results from multiple intervals
are concatenated along the time dimension with an interval_labels
coordinate to track which interval each time point belongs to.
"""
classifier = ClusterlessDetector(**decoding_params)

Expand Down Expand Up @@ -266,6 +269,14 @@ def _run_decoder(
time=position_info.index.to_numpy(),
**decoding_kwargs,
)
# Add interval_labels coordinate for consistency with predict branch
# label() returns 1-indexed labels; subtract 1 for 0-indexed intervals
# Result: -1 = outside intervals, 0, 1, 2... = interval index
labels, _ = label(~is_missing)
interval_labels = labels - 1
results = results.assign_coords(
interval_labels=("time", interval_labels)
)
else:
VALID_FIT_KWARGS = [
"is_training",
Expand Down Expand Up @@ -300,7 +311,10 @@ def _run_decoder(

# We treat each decoding interval as a separate sequence
results = []
for interval_start, interval_end in decoding_interval:
interval_labels = []
for interval_idx, (interval_start, interval_end) in enumerate(
decoding_interval
):
interval_time = position_info.loc[
interval_start:interval_end
].index.to_numpy()
Expand All @@ -310,19 +324,29 @@ def _run_decoder(
f"Interval {interval_start}:{interval_end} is empty"
)
continue
results.append(
classifier.predict(
position_time=interval_time,
position=position_info.loc[interval_start:interval_end][
position_variable_names
].to_numpy(),
spike_times=spike_times,
spike_waveform_features=spike_waveform_features,
time=interval_time,
**predict_kwargs,
)
interval_result = classifier.predict(
position_time=interval_time,
position=position_info.loc[interval_start:interval_end][
position_variable_names
].to_numpy(),
spike_times=spike_times,
spike_waveform_features=spike_waveform_features,
time=interval_time,
**predict_kwargs,
)
results = xr.concat(results, dim="intervals")
results.append(interval_result)
# Track which interval each time point belongs to
interval_labels.extend(
[interval_idx] * len(interval_result.time)
)
# Concatenate along time dimension instead of intervals dimension
if not results:
raise ValueError("All decoding intervals are empty")
results = xr.concat(results, dim="time")
# Add interval_labels as a coordinate for groupby/selection operations
results = results.assign_coords(
interval_labels=("time", interval_labels)
)

# Save discrete transition and initial conditions
results["initial_conditions"] = xr.DataArray(
Expand Down
50 changes: 37 additions & 13 deletions src/spyglass/decoding/v1/sorted_spikes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numpy as np
import pandas as pd
import xarray as xr
from scipy.ndimage import label
from non_local_detector.models.base import SortedSpikesDetector
from track_linearization import get_linearized_position

Expand Down Expand Up @@ -195,7 +196,9 @@ def _run_decoder(
classifier : SortedSpikesDetector
Fitted classifier instance
results : xr.Dataset
Decoding results with posteriors
Decoding results with posteriors. Results from multiple intervals
are concatenated along the time dimension with an interval_labels
coordinate to track which interval each time point belongs to.
"""
classifier = SortedSpikesDetector(**decoding_params)

Expand Down Expand Up @@ -223,6 +226,14 @@ def _run_decoder(
time=position_info.index.to_numpy(),
**decoding_kwargs,
)
# Add interval_labels coordinate for consistency with predict branch
# label() returns 1-indexed labels; subtract 1 for 0-indexed intervals
# Result: -1 = outside intervals, 0, 1, 2... = interval index
labels, _ = label(~is_missing)
interval_labels = labels - 1
results = results.assign_coords(
interval_labels=("time", interval_labels)
)
else:
VALID_FIT_KWARGS = [
"is_training",
Expand Down Expand Up @@ -255,7 +266,10 @@ def _run_decoder(

# We treat each decoding interval as a separate sequence
results = []
for interval_start, interval_end in decoding_interval:
interval_labels = []
for interval_idx, (interval_start, interval_end) in enumerate(
decoding_interval
):
interval_time = position_info.loc[
interval_start:interval_end
].index.to_numpy()
Expand All @@ -265,18 +279,28 @@ def _run_decoder(
f"Interval {interval_start}:{interval_end} is empty"
)
continue
results.append(
classifier.predict(
position_time=interval_time,
position=position_info.loc[interval_start:interval_end][
position_variable_names
].to_numpy(),
spike_times=spike_times,
time=interval_time,
**predict_kwargs,
)
interval_result = classifier.predict(
position_time=interval_time,
position=position_info.loc[interval_start:interval_end][
position_variable_names
].to_numpy(),
spike_times=spike_times,
time=interval_time,
**predict_kwargs,
)
results = xr.concat(results, dim="intervals")
results.append(interval_result)
# Track which interval each time point belongs to
interval_labels.extend(
[interval_idx] * len(interval_result.time)
)
# Concatenate along time dimension instead of intervals dimension
if not results:
raise ValueError("All decoding intervals are empty")
results = xr.concat(results, dim="time")
# Add interval_labels as a coordinate for groupby/selection operations
results = results.assign_coords(
interval_labels=("time", interval_labels)
)

# Save discrete transition and initial conditions
results["initial_conditions"] = xr.DataArray(
Expand Down
Loading
Loading