diff --git a/src/spyglass/decoding/v1/clusterless.py b/src/spyglass/decoding/v1/clusterless.py index 3c753268c..b477bb253 100644 --- a/src/spyglass/decoding/v1/clusterless.py +++ b/src/spyglass/decoding/v1/clusterless.py @@ -19,6 +19,7 @@ import pandas as pd import xarray as xr from non_local_detector.models.base import ClusterlessDetector +from scipy.ndimage import label from track_linearization import get_linearized_position from spyglass.common.common_interval import IntervalList # noqa: F401 @@ -237,7 +238,9 @@ def _run_decoder( classifier : ClusterlessDetector Fitted classifier instance results : xr.Dataset - Decoding results with posteriors + Decoding results with posteriors. Results from multiple intervals + are concatenated along the time dimension with an interval_labels + coordinate to track which interval each time point belongs to. """ classifier = ClusterlessDetector(**decoding_params) @@ -266,6 +269,14 @@ def _run_decoder( time=position_info.index.to_numpy(), **decoding_kwargs, ) + # Add interval_labels coordinate for consistency with predict branch + # label() returns 1-indexed labels; subtract 1 for 0-indexed intervals + # Result: -1 = outside intervals, 0, 1, 2... 
= interval index + labels, _ = label(~is_missing) + interval_labels = labels - 1 + results = results.assign_coords( + interval_labels=("time", interval_labels) + ) else: VALID_FIT_KWARGS = [ "is_training", @@ -300,7 +311,10 @@ def _run_decoder( # We treat each decoding interval as a separate sequence results = [] - for interval_start, interval_end in decoding_interval: + interval_labels = [] + for interval_idx, (interval_start, interval_end) in enumerate( + decoding_interval + ): interval_time = position_info.loc[ interval_start:interval_end ].index.to_numpy() @@ -310,19 +324,29 @@ def _run_decoder( f"Interval {interval_start}:{interval_end} is empty" ) continue - results.append( - classifier.predict( - position_time=interval_time, - position=position_info.loc[interval_start:interval_end][ - position_variable_names - ].to_numpy(), - spike_times=spike_times, - spike_waveform_features=spike_waveform_features, - time=interval_time, - **predict_kwargs, - ) + interval_result = classifier.predict( + position_time=interval_time, + position=position_info.loc[interval_start:interval_end][ + position_variable_names + ].to_numpy(), + spike_times=spike_times, + spike_waveform_features=spike_waveform_features, + time=interval_time, + **predict_kwargs, ) - results = xr.concat(results, dim="intervals") + results.append(interval_result) + # Track which interval each time point belongs to + interval_labels.extend( + [interval_idx] * len(interval_result.time) + ) + # Concatenate along time dimension instead of intervals dimension + if not results: + raise ValueError("All decoding intervals are empty") + results = xr.concat(results, dim="time") + # Add interval_labels as a coordinate for groupby/selection operations + results = results.assign_coords( + interval_labels=("time", interval_labels) + ) # Save discrete transition and initial conditions results["initial_conditions"] = xr.DataArray( diff --git a/src/spyglass/decoding/v1/sorted_spikes.py 
b/src/spyglass/decoding/v1/sorted_spikes.py index 6a773bce9..59a214881 100644 --- a/src/spyglass/decoding/v1/sorted_spikes.py +++ b/src/spyglass/decoding/v1/sorted_spikes.py @@ -18,6 +18,7 @@ import numpy as np import pandas as pd import xarray as xr +from scipy.ndimage import label from non_local_detector.models.base import SortedSpikesDetector from track_linearization import get_linearized_position @@ -195,7 +196,9 @@ def _run_decoder( classifier : SortedSpikesDetector Fitted classifier instance results : xr.Dataset - Decoding results with posteriors + Decoding results with posteriors. Results from multiple intervals + are concatenated along the time dimension with an interval_labels + coordinate to track which interval each time point belongs to. """ classifier = SortedSpikesDetector(**decoding_params) @@ -223,6 +226,14 @@ def _run_decoder( time=position_info.index.to_numpy(), **decoding_kwargs, ) + # Add interval_labels coordinate for consistency with predict branch + # label() returns 1-indexed labels; subtract 1 for 0-indexed intervals + # Result: -1 = outside intervals, 0, 1, 2... 
= interval index + labels, _ = label(~is_missing) + interval_labels = labels - 1 + results = results.assign_coords( + interval_labels=("time", interval_labels) + ) else: VALID_FIT_KWARGS = [ "is_training", @@ -255,7 +266,10 @@ def _run_decoder( # We treat each decoding interval as a separate sequence results = [] - for interval_start, interval_end in decoding_interval: + interval_labels = [] + for interval_idx, (interval_start, interval_end) in enumerate( + decoding_interval + ): interval_time = position_info.loc[ interval_start:interval_end ].index.to_numpy() @@ -265,18 +279,28 @@ def _run_decoder( f"Interval {interval_start}:{interval_end} is empty" ) continue - results.append( - classifier.predict( - position_time=interval_time, - position=position_info.loc[interval_start:interval_end][ - position_variable_names - ].to_numpy(), - spike_times=spike_times, - time=interval_time, - **predict_kwargs, - ) + interval_result = classifier.predict( + position_time=interval_time, + position=position_info.loc[interval_start:interval_end][ + position_variable_names + ].to_numpy(), + spike_times=spike_times, + time=interval_time, + **predict_kwargs, ) - results = xr.concat(results, dim="intervals") + results.append(interval_result) + # Track which interval each time point belongs to + interval_labels.extend( + [interval_idx] * len(interval_result.time) + ) + # Concatenate along time dimension instead of intervals dimension + if not results: + raise ValueError("All decoding intervals are empty") + results = xr.concat(results, dim="time") + # Add interval_labels as a coordinate for groupby/selection operations + results = results.assign_coords( + interval_labels=("time", interval_labels) + ) # Save discrete transition and initial conditions results["initial_conditions"] = xr.DataArray( diff --git a/tests/decoding/conftest.py b/tests/decoding/conftest.py index f6dbf3349..24d70ffc1 100644 --- a/tests/decoding/conftest.py +++ b/tests/decoding/conftest.py @@ -343,11 +343,15 @@ def 
decode_interval(common, mini_dict): raw_begin = (common.IntervalList & 'interval_list_name LIKE "raw%"').fetch1( "valid_times" )[0][0] + # Use a subset of the encoding interval (raw_begin+2 to raw_begin+13) + # This creates gaps at start and end, ensuring that when + # estimate_decoding_params=True, there are time points outside the + # decoding interval that will get interval_labels=-1 common.IntervalList.insert1( { **mini_dict, "interval_list_name": decode_interval_name, - "valid_times": [[raw_begin, raw_begin + 15]], + "valid_times": [[raw_begin + 2, raw_begin + 13]], }, skip_duplicates=True, ) @@ -599,22 +603,21 @@ def create_fake_decoding_results(n_time=100, n_position_bins=50, n_states=2): posterior[t] /= posterior[t].sum() # Create all expected coordinates for decoding results + # Note: Use "states" (plural) as the dimension name to match + # the real non_local_detector output format results = xr.Dataset( { - "posterior": (["time", "position", "state"], posterior), - "likelihood": (["time", "position", "state"], posterior * 0.8), + "posterior": (["time", "position", "states"], posterior), + "likelihood": (["time", "position", "states"], posterior * 0.8), }, coords={ "time": time, "position": position_bins, - "state": states, - "state_names": ("state", state_names), - # Additional coordinates expected by tests - "states": ("state", states), # Alias for state + "states": state_names, # Dimension coordinate with state names + "state_ind": ("states", states), # State indices "state_bins": ("position", position_bins), # Alias for position - "state_ind": ("state", states), # State indices - "encoding_groups": ("state", state_names), # Encoding group names - "environments": ("state", state_names), # Environment names + "encoding_groups": ("states", state_names), # Encoding group names + "environments": ("states", state_names), # Environment names }, ) @@ -777,6 +780,7 @@ def mock_clusterless_decoder(): the real _run_decoder method. 
""" import xarray as xr + from scipy.ndimage import label def _mock_run_decoder( self, @@ -793,11 +797,81 @@ def _mock_run_decoder( This mocks the expensive non_local_detector operations (~220s) while preserving all the Spyglass logic in make(). + + Handles both estimate_decoding_params=True and False branches: + - True: Returns results for ALL time points, with interval_labels + using scipy.ndimage.label approach (-1 for outside intervals) + - False: Returns results only for interval time points, with + interval_labels using enumerate approach (0, 1, 2, ...) """ classifier = create_fake_classifier() - results = create_fake_decoding_results( - n_time=len(position_info), n_position_bins=50, n_states=2 - ) + + if key.get("estimate_decoding_params", False): + # estimate_decoding_params=True branch: + # Results span ALL time points in position_info + all_time = position_info.index.to_numpy() + + # Create is_missing mask (same as real code) + is_missing = np.ones(len(position_info), dtype=bool) + for interval_start, interval_end in decoding_interval: + is_missing[ + np.logical_and( + position_info.index >= interval_start, + position_info.index <= interval_end, + ) + ] = False + + # Create fake results for all time points + results = create_fake_decoding_results( + n_time=len(all_time), n_position_bins=50, n_states=2 + ) + results = results.assign_coords(time=all_time) + + # Create interval_labels using scipy.ndimage.label (same as real code) + labels_arr, _ = label(~is_missing) + interval_labels = labels_arr - 1 + + results = results.assign_coords( + interval_labels=("time", interval_labels) + ) + else: + # estimate_decoding_params=False branch: + # Results only for time points within intervals + results_list = [] + interval_labels = [] + + for interval_idx, (interval_start, interval_end) in enumerate( + decoding_interval + ): + # Get time points for this interval + interval_time = position_info.loc[ + interval_start:interval_end + ].index.to_numpy() + + if 
interval_time.size == 0: + continue + + # Create fake results for this interval + interval_results = create_fake_decoding_results( + n_time=len(interval_time), n_position_bins=50, n_states=2 + ) + # Update time coordinates to match actual interval times + interval_results = interval_results.assign_coords( + time=interval_time + ) + results_list.append(interval_results) + interval_labels.extend([interval_idx] * len(interval_time)) + + # Concatenate along time dimension (as the real code now does) + if len(results_list) == 1: + results = results_list[0] + else: + results = xr.concat(results_list, dim="time") + + # Add interval_labels coordinate (as the real code now does) + results = results.assign_coords( + interval_labels=("time", interval_labels) + ) # Add metadata (same as real implementation) # initial_conditions: shape (n_states,) with explicit dims @@ -932,6 +1006,7 @@ def _mock_load_model(filename): def mock_sorted_spikes_decoder(): """Mock the _run_decoder helper for SortedSpikesDecodingV1.""" import xarray as xr + from scipy.ndimage import label def _mock_run_decoder( self, @@ -943,11 +1018,82 @@ def _mock_run_decoder( spike_times, decoding_interval, ): - """Mocked version that returns fake results instantly.""" + """Mocked version that returns fake results instantly. + + Handles both estimate_decoding_params=True and False branches: + - True: Returns results for ALL time points, with interval_labels + using scipy.ndimage.label approach (-1 for outside intervals) + - False: Returns results only for interval time points, with + interval_labels using enumerate approach (0, 1, 2, ...) 
+ """ classifier = create_fake_classifier() - results = create_fake_decoding_results( - n_time=len(position_info), n_position_bins=50, n_states=2 - ) + + if key.get("estimate_decoding_params", False): + # estimate_decoding_params=True branch: + # Results span ALL time points in position_info + all_time = position_info.index.to_numpy() + + # Create is_missing mask (same as real code) + is_missing = np.ones(len(position_info), dtype=bool) + for interval_start, interval_end in decoding_interval: + is_missing[ + np.logical_and( + position_info.index >= interval_start, + position_info.index <= interval_end, + ) + ] = False + + # Create fake results for all time points + results = create_fake_decoding_results( + n_time=len(all_time), n_position_bins=50, n_states=2 + ) + results = results.assign_coords(time=all_time) + + # Create interval_labels using scipy.ndimage.label (same as real code) + labels_arr, _ = label(~is_missing) + interval_labels = labels_arr - 1 + + results = results.assign_coords( + interval_labels=("time", interval_labels) + ) + else: + # estimate_decoding_params=False branch: + # Results only for time points within intervals + results_list = [] + interval_labels = [] + + for interval_idx, (interval_start, interval_end) in enumerate( + decoding_interval + ): + # Get time points for this interval + interval_time = position_info.loc[ + interval_start:interval_end + ].index.to_numpy() + + if interval_time.size == 0: + continue + + # Create fake results for this interval + interval_results = create_fake_decoding_results( + n_time=len(interval_time), n_position_bins=50, n_states=2 + ) + # Update time coordinates to match actual interval times + interval_results = interval_results.assign_coords( + time=interval_time + ) + results_list.append(interval_results) + interval_labels.extend([interval_idx] * len(interval_time)) + + # Concatenate along time dimension (as the real code now does) + if len(results_list) == 1: + results = results_list[0] + else: + results 
= xr.concat(results_list, dim="time") + + # Add interval_labels coordinate (as the real code now does) + results = results.assign_coords( + interval_labels=("time", interval_labels) + ) # Add metadata (same as real implementation) # initial_conditions: shape (n_states,) with explicit dims diff --git a/tests/decoding/test_intervals_removal.py b/tests/decoding/test_intervals_removal.py new file mode 100644 index 000000000..0b72f9a4f --- /dev/null +++ b/tests/decoding/test_intervals_removal.py @@ -0,0 +1,510 @@ +"""Tests for intervals dimension removal. + +These tests verify that decoding results are stored as a single time series +with interval tracking, rather than using the intervals dimension which causes +padding and memory waste. + +Tests cover both estimate_decoding_params=False (predict branch) and +estimate_decoding_params=True (estimate_parameters branch). +""" + +import numpy as np + + +def test_no_intervals_dimension_clusterless( + decode_v1, + monkeypatch, + mock_clusterless_decoder, + mock_decoder_save, + decode_sel_key, + group_name, + decode_clusterless_params_insert, + pop_pos_group, + group_unitwave, + mock_results_storage, +): + """Test that clusterless decoding results don't use intervals dimension.""" + _ = pop_pos_group, group_unitwave # ensure populated + + # Apply mocks to ClusterlessDecodingV1 + monkeypatch.setattr( + decode_v1.clusterless.ClusterlessDecodingV1, + "_run_decoder", + mock_clusterless_decoder, + ) + monkeypatch.setattr( + decode_v1.clusterless.ClusterlessDecodingV1, + "_save_decoder_results", + mock_decoder_save, + ) + + # Create selection key + selection_key = { + **decode_sel_key, + **decode_clusterless_params_insert, + "waveform_features_group_name": group_name, + "estimate_decoding_params": False, + } + + # Insert selection + decode_v1.clusterless.ClusterlessDecodingSelection.insert1( + selection_key, + skip_duplicates=True, + ) + + # Run populate + decode_v1.clusterless.ClusterlessDecodingV1.populate(selection_key) + + # 
Fetch results + table = decode_v1.clusterless.ClusterlessDecodingV1 & selection_key + results = table.fetch_results() + + # Verify that intervals is NOT a dimension + assert "intervals" not in results.dims, ( + "Results should not have 'intervals' dimension - " + "data should be concatenated along time instead" + ) + + # Verify that time is a dimension + assert "time" in results.dims, "Results should have 'time' dimension" + + # Verify that interval_labels exists as a coordinate or variable + assert ( + "interval_labels" in results.coords + or "interval_labels" in results.data_vars + ), ( + "Results should have 'interval_labels' to track which interval " + "each time point belongs to" + ) + + +def test_interval_labels_tracking_clusterless( + decode_v1, + monkeypatch, + mock_clusterless_decoder, + mock_decoder_save, + decode_sel_key, + group_name, + decode_clusterless_params_insert, + pop_pos_group, + group_unitwave, +): + """Test that interval_labels correctly tracks intervals in clusterless decoding.""" + _ = pop_pos_group, group_unitwave # ensure populated + + # Apply mocks + monkeypatch.setattr( + decode_v1.clusterless.ClusterlessDecodingV1, + "_run_decoder", + mock_clusterless_decoder, + ) + monkeypatch.setattr( + decode_v1.clusterless.ClusterlessDecodingV1, + "_save_decoder_results", + mock_decoder_save, + ) + + # Create selection key + selection_key = { + **decode_sel_key, + **decode_clusterless_params_insert, + "waveform_features_group_name": group_name, + "estimate_decoding_params": False, + } + + # Insert and populate + decode_v1.clusterless.ClusterlessDecodingSelection.insert1( + selection_key, + skip_duplicates=True, + ) + decode_v1.clusterless.ClusterlessDecodingV1.populate(selection_key) + + # Fetch results + table = decode_v1.clusterless.ClusterlessDecodingV1 & selection_key + results = table.fetch_results() + + # Get interval_labels + if "interval_labels" in results.coords: + interval_labels = results.coords["interval_labels"] + else: + 
interval_labels = results["interval_labels"] + + # Verify interval_labels has same length as time + assert len(interval_labels) == len( + results.time + ), "interval_labels should have same length as time dimension" + + # Verify interval_labels are integers starting from 0 + unique_labels = np.unique(interval_labels) + assert np.all( + unique_labels == np.arange(len(unique_labels)) + ), "interval_labels should be consecutive integers starting from 0" + + +def test_groupby_interval_labels_clusterless( + decode_v1, + monkeypatch, + mock_clusterless_decoder, + mock_decoder_save, + decode_sel_key, + group_name, + decode_clusterless_params_insert, + pop_pos_group, + group_unitwave, +): + """Test that results can be grouped by interval_labels.""" + _ = pop_pos_group, group_unitwave # ensure populated + + # Apply mocks + monkeypatch.setattr( + decode_v1.clusterless.ClusterlessDecodingV1, + "_run_decoder", + mock_clusterless_decoder, + ) + monkeypatch.setattr( + decode_v1.clusterless.ClusterlessDecodingV1, + "_save_decoder_results", + mock_decoder_save, + ) + + # Create selection key + selection_key = { + **decode_sel_key, + **decode_clusterless_params_insert, + "waveform_features_group_name": group_name, + "estimate_decoding_params": False, + } + + # Insert and populate + decode_v1.clusterless.ClusterlessDecodingSelection.insert1( + selection_key, + skip_duplicates=True, + ) + decode_v1.clusterless.ClusterlessDecodingV1.populate(selection_key) + + # Fetch results + table = decode_v1.clusterless.ClusterlessDecodingV1 & selection_key + results = table.fetch_results() + + # Test groupby operation + grouped = results.groupby("interval_labels") + + # Verify groupby works + assert grouped is not None, "Should be able to groupby interval_labels" + + # Verify we can iterate through groups + for label, group in grouped: + assert isinstance( + label, (int, np.integer) + ), "Group labels should be integers" + assert "time" in group.dims, "Each group should have time dimension" + + 
+def test_no_intervals_dimension_sorted_spikes( + decode_v1, + monkeypatch, + mock_sorted_spikes_decoder, + mock_decoder_save, + decode_sel_key, + group_name, + decode_spike_params_insert, + pop_pos_group, + pop_spikes_group, +): + """Test that sorted spikes decoding results don't use intervals dimension.""" + _ = pop_pos_group, pop_spikes_group # ensure populated + + # Apply mocks + monkeypatch.setattr( + decode_v1.sorted_spikes.SortedSpikesDecodingV1, + "_run_decoder", + mock_sorted_spikes_decoder, + ) + monkeypatch.setattr( + decode_v1.sorted_spikes.SortedSpikesDecodingV1, + "_save_decoder_results", + mock_decoder_save, + ) + + # Create selection key + selection_key = { + **decode_sel_key, + **decode_spike_params_insert, + "sorted_spikes_group_name": group_name, + "unit_filter_params_name": "default_exclusion", + "estimate_decoding_params": False, + } + + # Insert and populate + decode_v1.sorted_spikes.SortedSpikesDecodingSelection.insert1( + selection_key, + skip_duplicates=True, + ) + decode_v1.sorted_spikes.SortedSpikesDecodingV1.populate(selection_key) + + # Fetch results + table = decode_v1.sorted_spikes.SortedSpikesDecodingV1 & selection_key + results = table.fetch_results() + + # Verify that intervals is NOT a dimension + assert "intervals" not in results.dims, ( + "Results should not have 'intervals' dimension - " + "data should be concatenated along time instead" + ) + + # Verify that time is a dimension + assert "time" in results.dims, "Results should have 'time' dimension" + + # Verify that interval_labels exists + assert ( + "interval_labels" in results.coords + or "interval_labels" in results.data_vars + ), ( + "Results should have 'interval_labels' to track which interval " + "each time point belongs to" + ) + + +# ============================================================================ +# Tests for estimate_decoding_params=True branch +# ============================================================================ + + +def 
test_interval_labels_estimate_params_clusterless( + decode_v1, + monkeypatch, + mock_clusterless_decoder, + mock_decoder_save, + decode_sel_key, + group_name, + decode_clusterless_params_insert, + pop_pos_group, + group_unitwave, +): + """Test interval_labels when estimate_decoding_params=True (clusterless). + + When estimate_decoding_params=True, results span ALL time points and + interval_labels should be: + - -1 for time points outside any interval + - 0, 1, 2, ... for time points inside intervals + """ + _ = pop_pos_group, group_unitwave # ensure populated + + # Apply mocks + monkeypatch.setattr( + decode_v1.clusterless.ClusterlessDecodingV1, + "_run_decoder", + mock_clusterless_decoder, + ) + monkeypatch.setattr( + decode_v1.clusterless.ClusterlessDecodingV1, + "_save_decoder_results", + mock_decoder_save, + ) + + # Create selection key with estimate_decoding_params=True + selection_key = { + **decode_sel_key, + **decode_clusterless_params_insert, + "waveform_features_group_name": group_name, + "estimate_decoding_params": True, + } + + # Insert and populate + decode_v1.clusterless.ClusterlessDecodingSelection.insert1( + selection_key, + skip_duplicates=True, + ) + decode_v1.clusterless.ClusterlessDecodingV1.populate(selection_key) + + # Fetch results + table = decode_v1.clusterless.ClusterlessDecodingV1 & selection_key + results = table.fetch_results() + + # Verify interval_labels exists + assert "interval_labels" in results.coords, ( + "Results should have 'interval_labels' coordinate when " + "estimate_decoding_params=True" + ) + + # Get interval_labels + interval_labels = results.coords["interval_labels"].values + + # Verify interval_labels has same length as time + assert len(interval_labels) == len( + results.time + ), "interval_labels should have same length as time dimension" + + # Verify that -1 exists (for times outside intervals) + assert -1 in interval_labels, ( + "interval_labels should contain -1 for time points outside intervals " + "when 
estimate_decoding_params=True" + ) + + # Verify that non-negative labels exist (for times inside intervals) + assert np.any(interval_labels >= 0), ( + "interval_labels should contain non-negative values for time points " + "inside intervals" + ) + + # Verify labels are consecutive integers starting from 0 (excluding -1) + positive_labels = interval_labels[interval_labels >= 0] + unique_positive = np.unique(positive_labels) + expected_labels = np.arange(len(unique_positive)) + np.testing.assert_array_equal( + unique_positive, + expected_labels, + err_msg="Positive interval_labels should be consecutive integers from 0", + ) + + +def test_interval_labels_estimate_params_sorted_spikes( + decode_v1, + monkeypatch, + mock_sorted_spikes_decoder, + mock_decoder_save, + decode_sel_key, + group_name, + decode_spike_params_insert, + pop_pos_group, + pop_spikes_group, +): + """Test interval_labels when estimate_decoding_params=True (sorted spikes). + + When estimate_decoding_params=True, results span ALL time points and + interval_labels should be: + - -1 for time points outside any interval + - 0, 1, 2, ... 
for time points inside intervals + """ + _ = pop_pos_group, pop_spikes_group # ensure populated + + # Apply mocks + monkeypatch.setattr( + decode_v1.sorted_spikes.SortedSpikesDecodingV1, + "_run_decoder", + mock_sorted_spikes_decoder, + ) + monkeypatch.setattr( + decode_v1.sorted_spikes.SortedSpikesDecodingV1, + "_save_decoder_results", + mock_decoder_save, + ) + + # Create selection key with estimate_decoding_params=True + selection_key = { + **decode_sel_key, + **decode_spike_params_insert, + "sorted_spikes_group_name": group_name, + "unit_filter_params_name": "default_exclusion", + "estimate_decoding_params": True, + } + + # Insert and populate + decode_v1.sorted_spikes.SortedSpikesDecodingSelection.insert1( + selection_key, + skip_duplicates=True, + ) + decode_v1.sorted_spikes.SortedSpikesDecodingV1.populate(selection_key) + + # Fetch results + table = decode_v1.sorted_spikes.SortedSpikesDecodingV1 & selection_key + results = table.fetch_results() + + # Verify interval_labels exists + assert "interval_labels" in results.coords, ( + "Results should have 'interval_labels' coordinate when " + "estimate_decoding_params=True" + ) + + # Get interval_labels + interval_labels = results.coords["interval_labels"].values + + # Verify interval_labels has same length as time + assert len(interval_labels) == len( + results.time + ), "interval_labels should have same length as time dimension" + + # Verify that -1 exists (for times outside intervals) + assert -1 in interval_labels, ( + "interval_labels should contain -1 for time points outside intervals " + "when estimate_decoding_params=True" + ) + + # Verify that non-negative labels exist (for times inside intervals) + assert np.any(interval_labels >= 0), ( + "interval_labels should contain non-negative values for time points " + "inside intervals" + ) + + +def test_groupby_works_with_negative_labels( + decode_v1, + monkeypatch, + mock_clusterless_decoder, + mock_decoder_save, + decode_sel_key, + group_name, + 
decode_clusterless_params_insert, + pop_pos_group, + group_unitwave, +): + """Test that groupby works correctly with -1 labels. + + Users should be able to: + - Group by interval_labels to iterate through intervals + - Filter out -1 labels to get only interval data + """ + _ = pop_pos_group, group_unitwave # ensure populated + + # Apply mocks + monkeypatch.setattr( + decode_v1.clusterless.ClusterlessDecodingV1, + "_run_decoder", + mock_clusterless_decoder, + ) + monkeypatch.setattr( + decode_v1.clusterless.ClusterlessDecodingV1, + "_save_decoder_results", + mock_decoder_save, + ) + + # Create selection key with estimate_decoding_params=True + selection_key = { + **decode_sel_key, + **decode_clusterless_params_insert, + "waveform_features_group_name": group_name, + "estimate_decoding_params": True, + } + + # Insert and populate + decode_v1.clusterless.ClusterlessDecodingSelection.insert1( + selection_key, + skip_duplicates=True, + ) + decode_v1.clusterless.ClusterlessDecodingV1.populate(selection_key) + + # Fetch results + table = decode_v1.clusterless.ClusterlessDecodingV1 & selection_key + results = table.fetch_results() + + # Test groupby operation works + grouped = results.groupby("interval_labels") + assert grouped is not None, "Should be able to groupby interval_labels" + + # Verify we can iterate and get the -1 group + labels_seen = [] + for label, group in grouped: + labels_seen.append(label) + assert "time" in group.dims, "Each group should have time dimension" + + # Verify -1 is one of the groups + assert -1 in labels_seen, "Should have a group for -1 (outside intervals)" + + # Test filtering to only interval data + interval_data = results.where(results.interval_labels >= 0, drop=True) + assert len(interval_data.time) < len( + results.time + ), "Filtering to interval_labels >= 0 should reduce data size" + assert np.all( + interval_data.interval_labels >= 0 + ), "Filtered data should only have non-negative interval_labels" diff --git 
a/tests/decoding/test_intervals_removal_simple.py b/tests/decoding/test_intervals_removal_simple.py new file mode 100644 index 000000000..340f13241 --- /dev/null +++ b/tests/decoding/test_intervals_removal_simple.py @@ -0,0 +1,132 @@ +"""Simple unit test for intervals dimension removal. + +This test validates the xarray concatenation logic without requiring +database infrastructure or external dependencies. +""" + +import numpy as np +import xarray as xr + + +def test_concatenation_without_intervals_dimension(): + """Test that concatenating along time instead of intervals works correctly.""" + + # Simulate two intervals with different lengths (like real decoding) + interval1_time = np.arange(0, 10, 0.1) + interval2_time = np.arange(15, 22, 0.1) + n_position = 50 + + # Create results for each interval + results = [] + interval_labels = [] + + for interval_idx, interval_time in enumerate( + [interval1_time, interval2_time] + ): + n_time = len(interval_time) + interval_result = xr.Dataset( + { + "posterior": ( + ["time", "position"], + np.random.rand(n_time, n_position), + ), + }, + coords={"time": interval_time, "position": np.arange(n_position)}, + ) + results.append(interval_result) + interval_labels.extend([interval_idx] * n_time) + + # Concatenate along time dimension (new approach) + concatenated = xr.concat(results, dim="time") + concatenated = concatenated.assign_coords( + interval_labels=("time", interval_labels) + ) + + # Verify structure + assert ( + "intervals" not in concatenated.dims + ), "Should not have intervals dimension" + assert "time" in concatenated.dims, "Should have time dimension" + assert ( + "interval_labels" in concatenated.coords + ), "Should have interval_labels coordinate" + + # Verify shape - should be (total_time, n_position) not (n_intervals, max_time, n_position) + expected_time_points = len(interval1_time) + len(interval2_time) + assert concatenated.posterior.shape == ( + expected_time_points, + n_position, + ), f"Expected shape 
({expected_time_points}, {n_position}), got {concatenated.posterior.shape}" + + # Verify interval_labels + assert ( + len(concatenated.coords["interval_labels"]) == expected_time_points + ), "interval_labels should have same length as time" + unique_labels = np.unique(concatenated.coords["interval_labels"].values) + assert len(unique_labels) == 2, "Should have 2 unique interval labels" + assert list(unique_labels) == [0, 1], "Interval labels should be [0, 1]" + + # Verify groupby works + grouped = concatenated.groupby("interval_labels") + group_sizes = {label: len(group.time) for label, group in grouped} + assert group_sizes[0] == len( + interval1_time + ), f"Interval 0 should have {len(interval1_time)} time points" + assert group_sizes[1] == len( + interval2_time + ), f"Interval 1 should have {len(interval2_time)} time points" + + print("✓ All checks passed!") + + +def test_memory_efficiency(): + """Test that new approach uses less memory than intervals dimension.""" + + # Create two intervals with different lengths + interval1_data = xr.Dataset( + { + "posterior": (["time", "position"], np.random.rand(100, 50)), + }, + coords={"time": np.arange(100), "position": np.arange(50)}, + ) + + interval2_data = xr.Dataset( + { + "posterior": (["time", "position"], np.random.rand(70, 50)), + }, + coords={"time": np.arange(70), "position": np.arange(50)}, + ) + + # Old approach: concat with intervals dimension (creates padding) + old_approach = xr.concat([interval1_data, interval2_data], dim="intervals") + + # New approach: concat along time + interval2_shifted = interval2_data.assign_coords( + time=interval2_data.time + 100 + ) + new_approach = xr.concat([interval1_data, interval2_shifted], dim="time") + new_approach = new_approach.assign_coords( + interval_labels=("time", [0] * 100 + [1] * 70) + ) + + # Old approach has shape (2, 100, 50) = 10000 values, 1500 of them NaN padding + # New approach has shape (170, 50) = 8500 values with no padding + assert ( +
old_approach.posterior.size > new_approach.posterior.size + ), "Old approach should use more memory due to padding" + + # Calculate memory savings + old_bytes = old_approach.nbytes + new_bytes = new_approach.nbytes + savings_percent = ((old_bytes - new_bytes) / old_bytes) * 100 + + print( + f"✓ Memory savings: {savings_percent:.1f}% ({old_bytes} -> {new_bytes} bytes)" + ) + assert savings_percent > 0, "New approach should save memory" + + +if __name__ == "__main__": + test_concatenation_without_intervals_dimension() + test_memory_efficiency() + print("\n✓✓ All simple tests passed!")