From dd14b9402f40b5416761f206a2695a8493176b39 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 20 Nov 2025 20:15:12 +0000 Subject: [PATCH 1/9] Initial plan From 3aaebec138f0e7a7fb8ddca300961eec393ed0a7 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 20 Nov 2025 20:24:20 +0000 Subject: [PATCH 2/9] Remove intervals dimension and add interval_labels coordinate Co-authored-by: edeno <8053989+edeno@users.noreply.github.com> --- src/spyglass/decoding/v1/clusterless.py | 37 +-- src/spyglass/decoding/v1/sorted_spikes.py | 35 ++- tests/decoding/conftest.py | 76 ++++++- tests/decoding/test_intervals_removal.py | 262 ++++++++++++++++++++++ 4 files changed, 381 insertions(+), 29 deletions(-) create mode 100644 tests/decoding/test_intervals_removal.py diff --git a/src/spyglass/decoding/v1/clusterless.py b/src/spyglass/decoding/v1/clusterless.py index 3c753268c..738298bfb 100644 --- a/src/spyglass/decoding/v1/clusterless.py +++ b/src/spyglass/decoding/v1/clusterless.py @@ -300,7 +300,10 @@ def _run_decoder( # We treat each decoding interval as a separate sequence results = [] - for interval_start, interval_end in decoding_interval: + interval_labels = [] + for interval_idx, (interval_start, interval_end) in enumerate( + decoding_interval + ): interval_time = position_info.loc[ interval_start:interval_end ].index.to_numpy() @@ -310,19 +313,27 @@ def _run_decoder( f"Interval {interval_start}:{interval_end} is empty" ) continue - results.append( - classifier.predict( - position_time=interval_time, - position=position_info.loc[interval_start:interval_end][ - position_variable_names - ].to_numpy(), - spike_times=spike_times, - spike_waveform_features=spike_waveform_features, - time=interval_time, - **predict_kwargs, - ) + interval_result = classifier.predict( + position_time=interval_time, + position=position_info.loc[interval_start:interval_end][ + 
position_variable_names + ].to_numpy(), + spike_times=spike_times, + spike_waveform_features=spike_waveform_features, + time=interval_time, + **predict_kwargs, ) - results = xr.concat(results, dim="intervals") + results.append(interval_result) + # Track which interval each time point belongs to + interval_labels.extend( + [interval_idx] * len(interval_result.time) + ) + # Concatenate along time dimension instead of intervals dimension + results = xr.concat(results, dim="time") + # Add interval_labels as a coordinate for groupby/selection operations + results = results.assign_coords( + interval_labels=("time", interval_labels) + ) # Save discrete transition and initial conditions results["initial_conditions"] = xr.DataArray( diff --git a/src/spyglass/decoding/v1/sorted_spikes.py b/src/spyglass/decoding/v1/sorted_spikes.py index 6a773bce9..2494e70da 100644 --- a/src/spyglass/decoding/v1/sorted_spikes.py +++ b/src/spyglass/decoding/v1/sorted_spikes.py @@ -255,7 +255,10 @@ def _run_decoder( # We treat each decoding interval as a separate sequence results = [] - for interval_start, interval_end in decoding_interval: + interval_labels = [] + for interval_idx, (interval_start, interval_end) in enumerate( + decoding_interval + ): interval_time = position_info.loc[ interval_start:interval_end ].index.to_numpy() @@ -265,18 +268,26 @@ def _run_decoder( f"Interval {interval_start}:{interval_end} is empty" ) continue - results.append( - classifier.predict( - position_time=interval_time, - position=position_info.loc[interval_start:interval_end][ - position_variable_names - ].to_numpy(), - spike_times=spike_times, - time=interval_time, - **predict_kwargs, - ) + interval_result = classifier.predict( + position_time=interval_time, + position=position_info.loc[interval_start:interval_end][ + position_variable_names + ].to_numpy(), + spike_times=spike_times, + time=interval_time, + **predict_kwargs, ) - results = xr.concat(results, dim="intervals") + results.append(interval_result) + 
# Track which interval each time point belongs to + interval_labels.extend( + [interval_idx] * len(interval_result.time) + ) + # Concatenate along time dimension instead of intervals dimension + results = xr.concat(results, dim="time") + # Add interval_labels as a coordinate for groupby/selection operations + results = results.assign_coords( + interval_labels=("time", interval_labels) + ) # Save discrete transition and initial conditions results["initial_conditions"] = xr.DataArray( diff --git a/tests/decoding/conftest.py b/tests/decoding/conftest.py index f6dbf3349..f6b66d32d 100644 --- a/tests/decoding/conftest.py +++ b/tests/decoding/conftest.py @@ -795,8 +795,42 @@ def _mock_run_decoder( while preserving all the Spyglass logic in make(). """ classifier = create_fake_classifier() - results = create_fake_decoding_results( - n_time=len(position_info), n_position_bins=50, n_states=2 + + # Simulate multiple intervals to test the concatenation logic + results_list = [] + interval_labels = [] + + for interval_idx, (interval_start, interval_end) in enumerate( + decoding_interval + ): + # Get time points for this interval + interval_time = position_info.loc[ + interval_start:interval_end + ].index.to_numpy() + + if interval_time.size == 0: + continue + + # Create fake results for this interval + interval_results = create_fake_decoding_results( + n_time=len(interval_time), n_position_bins=50, n_states=2 + ) + # Update time coordinates to match actual interval times + interval_results = interval_results.assign_coords( + time=interval_time + ) + results_list.append(interval_results) + interval_labels.extend([interval_idx] * len(interval_time)) + + # Concatenate along time dimension (as the real code now does) + if len(results_list) == 1: + results = results_list[0] + else: + results = xr.concat(results_list, dim="time") + + # Add interval_labels coordinate (as the real code now does) + results = results.assign_coords( + interval_labels=("time", interval_labels) ) # Add 
metadata (same as real implementation) @@ -945,8 +979,42 @@ def _mock_run_decoder( ): """Mocked version that returns fake results instantly.""" classifier = create_fake_classifier() - results = create_fake_decoding_results( - n_time=len(position_info), n_position_bins=50, n_states=2 + + # Simulate multiple intervals to test the concatenation logic + results_list = [] + interval_labels = [] + + for interval_idx, (interval_start, interval_end) in enumerate( + decoding_interval + ): + # Get time points for this interval + interval_time = position_info.loc[ + interval_start:interval_end + ].index.to_numpy() + + if interval_time.size == 0: + continue + + # Create fake results for this interval + interval_results = create_fake_decoding_results( + n_time=len(interval_time), n_position_bins=50, n_states=2 + ) + # Update time coordinates to match actual interval times + interval_results = interval_results.assign_coords( + time=interval_time + ) + results_list.append(interval_results) + interval_labels.extend([interval_idx] * len(interval_time)) + + # Concatenate along time dimension (as the real code now does) + if len(results_list) == 1: + results = results_list[0] + else: + results = xr.concat(results_list, dim="time") + + # Add interval_labels coordinate (as the real code now does) + results = results.assign_coords( + interval_labels=("time", interval_labels) ) # Add metadata (same as real implementation) diff --git a/tests/decoding/test_intervals_removal.py b/tests/decoding/test_intervals_removal.py new file mode 100644 index 000000000..7a150b882 --- /dev/null +++ b/tests/decoding/test_intervals_removal.py @@ -0,0 +1,262 @@ +"""Tests for intervals dimension removal. + +These tests verify that decoding results are stored as a single time series +with interval tracking, rather than using the intervals dimension which causes +padding and memory waste. 
+""" + +import numpy as np +import pytest +import xarray as xr + + +def test_no_intervals_dimension_clusterless( + decode_v1, + monkeypatch, + mock_clusterless_decoder, + mock_decoder_save, + decode_sel_key, + group_name, + decode_clusterless_params_insert, + pop_pos_group, + group_unitwave, + mock_results_storage, +): + """Test that clusterless decoding results don't use intervals dimension.""" + _ = pop_pos_group, group_unitwave # ensure populated + + # Apply mocks to ClusterlessDecodingV1 + monkeypatch.setattr( + decode_v1.clusterless.ClusterlessDecodingV1, + "_run_decoder", + mock_clusterless_decoder, + ) + monkeypatch.setattr( + decode_v1.clusterless.ClusterlessDecodingV1, + "_save_decoder_results", + mock_decoder_save, + ) + + # Create selection key + selection_key = { + **decode_sel_key, + **decode_clusterless_params_insert, + "waveform_features_group_name": group_name, + "estimate_decoding_params": False, + } + + # Insert selection + decode_v1.clusterless.ClusterlessDecodingSelection.insert1( + selection_key, + skip_duplicates=True, + ) + + # Run populate + decode_v1.clusterless.ClusterlessDecodingV1.populate(selection_key) + + # Fetch results + table = decode_v1.clusterless.ClusterlessDecodingV1 & selection_key + results = table.fetch_results() + + # Verify that intervals is NOT a dimension + assert "intervals" not in results.dims, ( + "Results should not have 'intervals' dimension - " + "data should be concatenated along time instead" + ) + + # Verify that time is a dimension + assert "time" in results.dims, "Results should have 'time' dimension" + + # Verify that interval_labels exists as a coordinate or variable + assert ( + "interval_labels" in results.coords or "interval_labels" in results.data_vars + ), ( + "Results should have 'interval_labels' to track which interval " + "each time point belongs to" + ) + + +def test_interval_labels_tracking_clusterless( + decode_v1, + monkeypatch, + mock_clusterless_decoder, + mock_decoder_save, + decode_sel_key, 
+ group_name, + decode_clusterless_params_insert, + pop_pos_group, + group_unitwave, +): + """Test that interval_labels correctly tracks intervals in clusterless decoding.""" + _ = pop_pos_group, group_unitwave # ensure populated + + # Apply mocks + monkeypatch.setattr( + decode_v1.clusterless.ClusterlessDecodingV1, + "_run_decoder", + mock_clusterless_decoder, + ) + monkeypatch.setattr( + decode_v1.clusterless.ClusterlessDecodingV1, + "_save_decoder_results", + mock_decoder_save, + ) + + # Create selection key + selection_key = { + **decode_sel_key, + **decode_clusterless_params_insert, + "waveform_features_group_name": group_name, + "estimate_decoding_params": False, + } + + # Insert and populate + decode_v1.clusterless.ClusterlessDecodingSelection.insert1( + selection_key, + skip_duplicates=True, + ) + decode_v1.clusterless.ClusterlessDecodingV1.populate(selection_key) + + # Fetch results + table = decode_v1.clusterless.ClusterlessDecodingV1 & selection_key + results = table.fetch_results() + + # Get interval_labels + if "interval_labels" in results.coords: + interval_labels = results.coords["interval_labels"] + else: + interval_labels = results["interval_labels"] + + # Verify interval_labels has same length as time + assert len(interval_labels) == len(results.time), ( + "interval_labels should have same length as time dimension" + ) + + # Verify interval_labels are integers starting from 0 + unique_labels = np.unique(interval_labels) + assert np.all(unique_labels == np.arange(len(unique_labels))), ( + "interval_labels should be consecutive integers starting from 0" + ) + + +def test_groupby_interval_labels_clusterless( + decode_v1, + monkeypatch, + mock_clusterless_decoder, + mock_decoder_save, + decode_sel_key, + group_name, + decode_clusterless_params_insert, + pop_pos_group, + group_unitwave, +): + """Test that results can be grouped by interval_labels.""" + _ = pop_pos_group, group_unitwave # ensure populated + + # Apply mocks + monkeypatch.setattr( + 
decode_v1.clusterless.ClusterlessDecodingV1, + "_run_decoder", + mock_clusterless_decoder, + ) + monkeypatch.setattr( + decode_v1.clusterless.ClusterlessDecodingV1, + "_save_decoder_results", + mock_decoder_save, + ) + + # Create selection key + selection_key = { + **decode_sel_key, + **decode_clusterless_params_insert, + "waveform_features_group_name": group_name, + "estimate_decoding_params": False, + } + + # Insert and populate + decode_v1.clusterless.ClusterlessDecodingSelection.insert1( + selection_key, + skip_duplicates=True, + ) + decode_v1.clusterless.ClusterlessDecodingV1.populate(selection_key) + + # Fetch results + table = decode_v1.clusterless.ClusterlessDecodingV1 & selection_key + results = table.fetch_results() + + # Test groupby operation + grouped = results.groupby("interval_labels") + + # Verify groupby works + assert grouped is not None, "Should be able to groupby interval_labels" + + # Verify we can iterate through groups + for label, group in grouped: + assert isinstance(label, (int, np.integer)), ( + "Group labels should be integers" + ) + assert "time" in group.dims, "Each group should have time dimension" + + +def test_no_intervals_dimension_sorted_spikes( + decode_v1, + monkeypatch, + mock_sorted_spikes_decoder, + mock_decoder_save, + decode_sel_key, + group_name, + decode_spike_params_insert, + pop_pos_group, + pop_spikes_group, +): + """Test that sorted spikes decoding results don't use intervals dimension.""" + _ = pop_pos_group, pop_spikes_group # ensure populated + + # Apply mocks + monkeypatch.setattr( + decode_v1.sorted_spikes.SortedSpikesDecodingV1, + "_run_decoder", + mock_sorted_spikes_decoder, + ) + monkeypatch.setattr( + decode_v1.sorted_spikes.SortedSpikesDecodingV1, + "_save_decoder_results", + mock_decoder_save, + ) + + # Create selection key + selection_key = { + **decode_sel_key, + **decode_spike_params_insert, + "sorted_spikes_group_name": group_name, + "unit_filter_params_name": "default_exclusion", + 
"estimate_decoding_params": False, + } + + # Insert and populate + decode_v1.sorted_spikes.SortedSpikesDecodingSelection.insert1( + selection_key, + skip_duplicates=True, + ) + decode_v1.sorted_spikes.SortedSpikesDecodingV1.populate(selection_key) + + # Fetch results + table = decode_v1.sorted_spikes.SortedSpikesDecodingV1 & selection_key + results = table.fetch_results() + + # Verify that intervals is NOT a dimension + assert "intervals" not in results.dims, ( + "Results should not have 'intervals' dimension - " + "data should be concatenated along time instead" + ) + + # Verify that time is a dimension + assert "time" in results.dims, "Results should have 'time' dimension" + + # Verify that interval_labels exists + assert ( + "interval_labels" in results.coords or "interval_labels" in results.data_vars + ), ( + "Results should have 'interval_labels' to track which interval " + "each time point belongs to" + ) From 5dd72b7b6a1b6f352990724689f34055f43c8ef4 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 20 Nov 2025 20:27:59 +0000 Subject: [PATCH 3/9] Update docstrings and format code with black Co-authored-by: edeno <8053989+edeno@users.noreply.github.com> --- src/spyglass/decoding/v1/clusterless.py | 4 +- src/spyglass/decoding/v1/sorted_spikes.py | 4 +- tests/decoding/conftest.py | 24 ++-- tests/decoding/test_intervals_removal.py | 28 ++-- .../decoding/test_intervals_removal_simple.py | 132 ++++++++++++++++++ 5 files changed, 165 insertions(+), 27 deletions(-) create mode 100644 tests/decoding/test_intervals_removal_simple.py diff --git a/src/spyglass/decoding/v1/clusterless.py b/src/spyglass/decoding/v1/clusterless.py index 738298bfb..b7661360d 100644 --- a/src/spyglass/decoding/v1/clusterless.py +++ b/src/spyglass/decoding/v1/clusterless.py @@ -237,7 +237,9 @@ def _run_decoder( classifier : ClusterlessDetector Fitted classifier instance results : xr.Dataset - Decoding results with posteriors + 
Decoding results with posteriors. Results from multiple intervals + are concatenated along the time dimension with an interval_labels + coordinate to track which interval each time point belongs to. """ classifier = ClusterlessDetector(**decoding_params) diff --git a/src/spyglass/decoding/v1/sorted_spikes.py b/src/spyglass/decoding/v1/sorted_spikes.py index 2494e70da..10dd95e6f 100644 --- a/src/spyglass/decoding/v1/sorted_spikes.py +++ b/src/spyglass/decoding/v1/sorted_spikes.py @@ -195,7 +195,9 @@ def _run_decoder( classifier : SortedSpikesDetector Fitted classifier instance results : xr.Dataset - Decoding results with posteriors + Decoding results with posteriors. Results from multiple intervals + are concatenated along the time dimension with an interval_labels + coordinate to track which interval each time point belongs to. """ classifier = SortedSpikesDetector(**decoding_params) diff --git a/tests/decoding/conftest.py b/tests/decoding/conftest.py index f6b66d32d..f003d1453 100644 --- a/tests/decoding/conftest.py +++ b/tests/decoding/conftest.py @@ -795,11 +795,11 @@ def _mock_run_decoder( while preserving all the Spyglass logic in make(). 
""" classifier = create_fake_classifier() - + # Simulate multiple intervals to test the concatenation logic results_list = [] interval_labels = [] - + for interval_idx, (interval_start, interval_end) in enumerate( decoding_interval ): @@ -807,10 +807,10 @@ def _mock_run_decoder( interval_time = position_info.loc[ interval_start:interval_end ].index.to_numpy() - + if interval_time.size == 0: continue - + # Create fake results for this interval interval_results = create_fake_decoding_results( n_time=len(interval_time), n_position_bins=50, n_states=2 @@ -821,13 +821,13 @@ def _mock_run_decoder( ) results_list.append(interval_results) interval_labels.extend([interval_idx] * len(interval_time)) - + # Concatenate along time dimension (as the real code now does) if len(results_list) == 1: results = results_list[0] else: results = xr.concat(results_list, dim="time") - + # Add interval_labels coordinate (as the real code now does) results = results.assign_coords( interval_labels=("time", interval_labels) @@ -979,11 +979,11 @@ def _mock_run_decoder( ): """Mocked version that returns fake results instantly.""" classifier = create_fake_classifier() - + # Simulate multiple intervals to test the concatenation logic results_list = [] interval_labels = [] - + for interval_idx, (interval_start, interval_end) in enumerate( decoding_interval ): @@ -991,10 +991,10 @@ def _mock_run_decoder( interval_time = position_info.loc[ interval_start:interval_end ].index.to_numpy() - + if interval_time.size == 0: continue - + # Create fake results for this interval interval_results = create_fake_decoding_results( n_time=len(interval_time), n_position_bins=50, n_states=2 @@ -1005,13 +1005,13 @@ def _mock_run_decoder( ) results_list.append(interval_results) interval_labels.extend([interval_idx] * len(interval_time)) - + # Concatenate along time dimension (as the real code now does) if len(results_list) == 1: results = results_list[0] else: results = xr.concat(results_list, dim="time") - + # Add 
interval_labels coordinate (as the real code now does) results = results.assign_coords( interval_labels=("time", interval_labels) diff --git a/tests/decoding/test_intervals_removal.py b/tests/decoding/test_intervals_removal.py index 7a150b882..b33cc40c6 100644 --- a/tests/decoding/test_intervals_removal.py +++ b/tests/decoding/test_intervals_removal.py @@ -69,7 +69,8 @@ def test_no_intervals_dimension_clusterless( # Verify that interval_labels exists as a coordinate or variable assert ( - "interval_labels" in results.coords or "interval_labels" in results.data_vars + "interval_labels" in results.coords + or "interval_labels" in results.data_vars ), ( "Results should have 'interval_labels' to track which interval " "each time point belongs to" @@ -128,15 +129,15 @@ def test_interval_labels_tracking_clusterless( interval_labels = results["interval_labels"] # Verify interval_labels has same length as time - assert len(interval_labels) == len(results.time), ( - "interval_labels should have same length as time dimension" - ) + assert len(interval_labels) == len( + results.time + ), "interval_labels should have same length as time dimension" # Verify interval_labels are integers starting from 0 unique_labels = np.unique(interval_labels) - assert np.all(unique_labels == np.arange(len(unique_labels))), ( - "interval_labels should be consecutive integers starting from 0" - ) + assert np.all( + unique_labels == np.arange(len(unique_labels)) + ), "interval_labels should be consecutive integers starting from 0" def test_groupby_interval_labels_clusterless( @@ -186,15 +187,15 @@ def test_groupby_interval_labels_clusterless( # Test groupby operation grouped = results.groupby("interval_labels") - + # Verify groupby works assert grouped is not None, "Should be able to groupby interval_labels" - + # Verify we can iterate through groups for label, group in grouped: - assert isinstance(label, (int, np.integer)), ( - "Group labels should be integers" - ) + assert isinstance( + label, 
(int, np.integer) + ), "Group labels should be integers" assert "time" in group.dims, "Each group should have time dimension" @@ -255,7 +256,8 @@ def test_no_intervals_dimension_sorted_spikes( # Verify that interval_labels exists assert ( - "interval_labels" in results.coords or "interval_labels" in results.data_vars + "interval_labels" in results.coords + or "interval_labels" in results.data_vars ), ( "Results should have 'interval_labels' to track which interval " "each time point belongs to" diff --git a/tests/decoding/test_intervals_removal_simple.py b/tests/decoding/test_intervals_removal_simple.py new file mode 100644 index 000000000..340f13241 --- /dev/null +++ b/tests/decoding/test_intervals_removal_simple.py @@ -0,0 +1,132 @@ +"""Simple unit test for intervals dimension removal. + +This test validates the xarray concatenation logic without requiring +database infrastructure or external dependencies. +""" + +import numpy as np +import xarray as xr + + +def test_concatenation_without_intervals_dimension(): + """Test that concatenating along time instead of intervals works correctly.""" + + # Simulate two intervals with different lengths (like real decoding) + interval1_time = np.arange(0, 10, 0.1) + interval2_time = np.arange(15, 22, 0.1) + n_position = 50 + + # Create results for each interval + results = [] + interval_labels = [] + + for interval_idx, interval_time in enumerate( + [interval1_time, interval2_time] + ): + n_time = len(interval_time) + interval_result = xr.Dataset( + { + "posterior": ( + ["time", "position"], + np.random.rand(n_time, n_position), + ), + }, + coords={"time": interval_time, "position": np.arange(n_position)}, + ) + results.append(interval_result) + interval_labels.extend([interval_idx] * n_time) + + # Concatenate along time dimension (new approach) + concatenated = xr.concat(results, dim="time") + concatenated = concatenated.assign_coords( + interval_labels=("time", interval_labels) + ) + + # Verify structure + assert ( + 
"intervals" not in concatenated.dims + ), "Should not have intervals dimension" + assert "time" in concatenated.dims, "Should have time dimension" + assert ( + "interval_labels" in concatenated.coords + ), "Should have interval_labels coordinate" + + # Verify shape - should be (total_time, n_position) not (n_intervals, max_time, n_position) + expected_time_points = len(interval1_time) + len(interval2_time) + assert concatenated.posterior.shape == ( + expected_time_points, + n_position, + ), f"Expected shape ({expected_time_points}, {n_position}), got {concatenated.posterior.shape}" + + # Verify interval_labels + assert ( + len(concatenated.coords["interval_labels"]) == expected_time_points + ), "interval_labels should have same length as time" + unique_labels = np.unique(concatenated.coords["interval_labels"].values) + assert len(unique_labels) == 2, "Should have 2 unique interval labels" + assert list(unique_labels) == [0, 1], "Interval labels should be [0, 1]" + + # Verify groupby works + grouped = concatenated.groupby("interval_labels") + group_sizes = {label: len(group.time) for label, group in grouped} + assert group_sizes[0] == len( + interval1_time + ), f"Interval 0 should have {len(interval1_time)} time points" + assert group_sizes[1] == len( + interval2_time + ), f"Interval 1 should have {len(interval2_time)} time points" + + print("āœ“ All checks passed!") + + +def test_memory_efficiency(): + """Test that new approach uses less memory than intervals dimension.""" + + # Create two intervals with different lengths + interval1_data = xr.Dataset( + { + "posterior": (["time", "position"], np.random.rand(100, 50)), + }, + coords={"time": np.arange(100), "position": np.arange(50)}, + ) + + interval2_data = xr.Dataset( + { + "posterior": (["time", "position"], np.random.rand(70, 50)), + }, + coords={"time": np.arange(70), "position": np.arange(50)}, + ) + + # Old approach: concat with intervals dimension (creates padding) + old_approach = 
xr.concat([interval1_data, interval2_data], dim="intervals") + + # New approach: concat along time + interval2_shifted = interval2_data.assign_coords( + time=interval2_data.time + 100 + ) + new_approach = xr.concat([interval1_data, interval2_shifted], dim="time") + new_approach = new_approach.assign_coords( + interval_labels=("time", [0] * 100 + [1] * 70) + ) + + # Old approach has shape (2, 100, 50) = 10000 values with 1500 padding zeros + # New approach has shape (170, 50) = 8500 values with no padding + assert ( + old_approach.posterior.size > new_approach.posterior.size + ), "Old approach should use more memory due to padding" + + # Calculate memory savings + old_bytes = old_approach.nbytes + new_bytes = new_approach.nbytes + savings_percent = ((old_bytes - new_bytes) / old_bytes) * 100 + + print( + f"āœ“ Memory savings: {savings_percent:.1f}% ({old_bytes} -> {new_bytes} bytes)" + ) + assert savings_percent > 0, "New approach should save memory" + + +if __name__ == "__main__": + test_concatenation_without_intervals_dimension() + test_memory_efficiency() + print("\nāœ“āœ“ All simple tests passed!") From 281d9490c4f2366d3641e636c1a9a715675f648c Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Fri, 21 Nov 2025 07:06:04 -0800 Subject: [PATCH 4/9] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/spyglass/decoding/v1/clusterless.py | 2 ++ src/spyglass/decoding/v1/sorted_spikes.py | 2 ++ tests/decoding/test_intervals_removal.py | 4 ++-- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/spyglass/decoding/v1/clusterless.py b/src/spyglass/decoding/v1/clusterless.py index b7661360d..a394029a6 100644 --- a/src/spyglass/decoding/v1/clusterless.py +++ b/src/spyglass/decoding/v1/clusterless.py @@ -331,6 +331,8 @@ def _run_decoder( [interval_idx] * len(interval_result.time) ) # Concatenate along time dimension instead of intervals dimension + if not results: + raise ValueError("All decoding 
intervals are empty") results = xr.concat(results, dim="time") # Add interval_labels as a coordinate for groupby/selection operations results = results.assign_coords( diff --git a/src/spyglass/decoding/v1/sorted_spikes.py b/src/spyglass/decoding/v1/sorted_spikes.py index 10dd95e6f..3f36956d7 100644 --- a/src/spyglass/decoding/v1/sorted_spikes.py +++ b/src/spyglass/decoding/v1/sorted_spikes.py @@ -285,6 +285,8 @@ def _run_decoder( [interval_idx] * len(interval_result.time) ) # Concatenate along time dimension instead of intervals dimension + if not results: + raise ValueError("All decoding intervals are empty") results = xr.concat(results, dim="time") # Add interval_labels as a coordinate for groupby/selection operations results = results.assign_coords( diff --git a/tests/decoding/test_intervals_removal.py b/tests/decoding/test_intervals_removal.py index b33cc40c6..5907b826f 100644 --- a/tests/decoding/test_intervals_removal.py +++ b/tests/decoding/test_intervals_removal.py @@ -6,8 +6,8 @@ """ import numpy as np -import pytest -import xarray as xr + + def test_no_intervals_dimension_clusterless( From 5bbb3fa270a5ab682b8570fbe294dbd87d92975e Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Fri, 21 Nov 2025 14:14:33 -0500 Subject: [PATCH 5/9] Add interval_labels coordinate to decoding results Introduces an interval_labels coordinate to the results in both ClusterlessDecodingV1 and SortedSpikesDecodingV1 for consistency with the predict branch. This uses scipy.ndimage.label to identify intervals, ensuring that results outside intervals are marked as -1 and intervals are 0-indexed. 
--- src/spyglass/decoding/v1/clusterless.py | 9 +++++++++ src/spyglass/decoding/v1/sorted_spikes.py | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/src/spyglass/decoding/v1/clusterless.py b/src/spyglass/decoding/v1/clusterless.py index a394029a6..b477bb253 100644 --- a/src/spyglass/decoding/v1/clusterless.py +++ b/src/spyglass/decoding/v1/clusterless.py @@ -19,6 +19,7 @@ import pandas as pd import xarray as xr from non_local_detector.models.base import ClusterlessDetector +from scipy.ndimage import label from track_linearization import get_linearized_position from spyglass.common.common_interval import IntervalList # noqa: F401 @@ -268,6 +269,14 @@ def _run_decoder( time=position_info.index.to_numpy(), **decoding_kwargs, ) + # Add interval_labels coordinate for consistency with predict branch + # label() returns 1-indexed labels; subtract 1 for 0-indexed intervals + # Result: -1 = outside intervals, 0, 1, 2... = interval index + labels, _ = label(~is_missing) + interval_labels = labels - 1 + results = results.assign_coords( + interval_labels=("time", interval_labels) + ) else: VALID_FIT_KWARGS = [ "is_training", diff --git a/src/spyglass/decoding/v1/sorted_spikes.py b/src/spyglass/decoding/v1/sorted_spikes.py index 3f36956d7..59a214881 100644 --- a/src/spyglass/decoding/v1/sorted_spikes.py +++ b/src/spyglass/decoding/v1/sorted_spikes.py @@ -18,6 +18,7 @@ import numpy as np import pandas as pd import xarray as xr +from scipy.ndimage import label from non_local_detector.models.base import SortedSpikesDetector from track_linearization import get_linearized_position @@ -225,6 +226,14 @@ def _run_decoder( time=position_info.index.to_numpy(), **decoding_kwargs, ) + # Add interval_labels coordinate for consistency with predict branch + # label() returns 1-indexed labels; subtract 1 for 0-indexed intervals + # Result: -1 = outside intervals, 0, 1, 2... 
= interval index + labels, _ = label(~is_missing) + interval_labels = labels - 1 + results = results.assign_coords( + interval_labels=("time", interval_labels) + ) else: VALID_FIT_KWARGS = [ "is_training", From e54d4dd7d9a428bc048a4a28c17d9c04bd9b0d53 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Fri, 21 Nov 2025 17:26:43 -0500 Subject: [PATCH 6/9] Add tests for interval_labels with estimate_decoding_params Expanded mock decoders and tests to cover both estimate_decoding_params=True and False branches. Tests now verify interval_labels behavior for all time points, including handling of -1 for points outside intervals and groupby/filtering operations. Mock decoders updated to match real logic for interval_labels assignment. --- tests/decoding/conftest.py | 205 +++++++++++++------ tests/decoding/test_intervals_removal.py | 250 ++++++++++++++++++++++- 2 files changed, 388 insertions(+), 67 deletions(-) diff --git a/tests/decoding/conftest.py b/tests/decoding/conftest.py index f003d1453..d51f58ee2 100644 --- a/tests/decoding/conftest.py +++ b/tests/decoding/conftest.py @@ -777,6 +777,7 @@ def mock_clusterless_decoder(): the real _run_decoder method. """ import xarray as xr + from scipy.ndimage import label def _mock_run_decoder( self, @@ -793,45 +794,81 @@ def _mock_run_decoder( This mocks the expensive non_local_detector operations (~220s) while preserving all the Spyglass logic in make(). + + Handles both estimate_decoding_params=True and False branches: + - True: Returns results for ALL time points, with interval_labels + using scipy.ndimage.label approach (-1 for outside intervals) + - False: Returns results only for interval time points, with + interval_labels using enumerate approach (0, 1, 2, ...) 
""" classifier = create_fake_classifier() - # Simulate multiple intervals to test the concatenation logic - results_list = [] - interval_labels = [] - - for interval_idx, (interval_start, interval_end) in enumerate( - decoding_interval - ): - # Get time points for this interval - interval_time = position_info.loc[ - interval_start:interval_end - ].index.to_numpy() + if key.get("estimate_decoding_params", False): + # estimate_decoding_params=True branch: + # Results span ALL time points in position_info + all_time = position_info.index.to_numpy() + + # Create is_missing mask (same as real code) + is_missing = np.ones(len(position_info), dtype=bool) + for interval_start, interval_end in decoding_interval: + is_missing[ + np.logical_and( + position_info.index >= interval_start, + position_info.index <= interval_end, + ) + ] = False + + # Create fake results for all time points + results = create_fake_decoding_results( + n_time=len(all_time), n_position_bins=50, n_states=2 + ) + results = results.assign_coords(time=all_time) - if interval_time.size == 0: - continue + # Create interval_labels using scipy.ndimage.label (same as real code) + labels_arr, _ = label(~is_missing) + interval_labels = labels_arr - 1 - # Create fake results for this interval - interval_results = create_fake_decoding_results( - n_time=len(interval_time), n_position_bins=50, n_states=2 + results = results.assign_coords( + interval_labels=("time", interval_labels) ) - # Update time coordinates to match actual interval times - interval_results = interval_results.assign_coords( - time=interval_time - ) - results_list.append(interval_results) - interval_labels.extend([interval_idx] * len(interval_time)) - - # Concatenate along time dimension (as the real code now does) - if len(results_list) == 1: - results = results_list[0] else: - results = xr.concat(results_list, dim="time") - - # Add interval_labels coordinate (as the real code now does) - results = results.assign_coords( - 
interval_labels=("time", interval_labels) - ) + # estimate_decoding_params=False branch: + # Results only for time points within intervals + results_list = [] + interval_labels = [] + + for interval_idx, (interval_start, interval_end) in enumerate( + decoding_interval + ): + # Get time points for this interval + interval_time = position_info.loc[ + interval_start:interval_end + ].index.to_numpy() + + if interval_time.size == 0: + continue + + # Create fake results for this interval + interval_results = create_fake_decoding_results( + n_time=len(interval_time), n_position_bins=50, n_states=2 + ) + # Update time coordinates to match actual interval times + interval_results = interval_results.assign_coords( + time=interval_time + ) + results_list.append(interval_results) + interval_labels.extend([interval_idx] * len(interval_time)) + + # Concatenate along time dimension (as the real code now does) + if len(results_list) == 1: + results = results_list[0] + else: + results = xr.concat(results_list, dim="time") + + # Add interval_labels coordinate (as the real code now does) + results = results.assign_coords( + interval_labels=("time", interval_labels) + ) # Add metadata (same as real implementation) # initial_conditions: shape (n_states,) with explicit dims @@ -966,6 +1003,7 @@ def _mock_load_model(filename): def mock_sorted_spikes_decoder(): """Mock the _run_decoder helper for SortedSpikesDecodingV1.""" import xarray as xr + from scipy.ndimage import label def _mock_run_decoder( self, @@ -977,45 +1015,82 @@ def _mock_run_decoder( spike_times, decoding_interval, ): - """Mocked version that returns fake results instantly.""" - classifier = create_fake_classifier() + """Mocked version that returns fake results instantly. 
- # Simulate multiple intervals to test the concatenation logic - results_list = [] - interval_labels = [] + Handles both estimate_decoding_params=True and False branches: + - True: Returns results for ALL time points, with interval_labels + using scipy.ndimage.label approach (-1 for outside intervals) + - False: Returns results only for interval time points, with + interval_labels using enumerate approach (0, 1, 2, ...) + """ + classifier = create_fake_classifier() - for interval_idx, (interval_start, interval_end) in enumerate( - decoding_interval - ): - # Get time points for this interval - interval_time = position_info.loc[ - interval_start:interval_end - ].index.to_numpy() + if key.get("estimate_decoding_params", False): + # estimate_decoding_params=True branch: + # Results span ALL time points in position_info + all_time = position_info.index.to_numpy() + + # Create is_missing mask (same as real code) + is_missing = np.ones(len(position_info), dtype=bool) + for interval_start, interval_end in decoding_interval: + is_missing[ + np.logical_and( + position_info.index >= interval_start, + position_info.index <= interval_end, + ) + ] = False + + # Create fake results for all time points + results = create_fake_decoding_results( + n_time=len(all_time), n_position_bins=50, n_states=2 + ) + results = results.assign_coords(time=all_time) - if interval_time.size == 0: - continue + # Create interval_labels using scipy.ndimage.label (same as real code) + labels_arr, _ = label(~is_missing) + interval_labels = labels_arr - 1 - # Create fake results for this interval - interval_results = create_fake_decoding_results( - n_time=len(interval_time), n_position_bins=50, n_states=2 + results = results.assign_coords( + interval_labels=("time", interval_labels) ) - # Update time coordinates to match actual interval times - interval_results = interval_results.assign_coords( - time=interval_time - ) - results_list.append(interval_results) - interval_labels.extend([interval_idx] * 
len(interval_time)) - - # Concatenate along time dimension (as the real code now does) - if len(results_list) == 1: - results = results_list[0] else: - results = xr.concat(results_list, dim="time") - - # Add interval_labels coordinate (as the real code now does) - results = results.assign_coords( - interval_labels=("time", interval_labels) - ) + # estimate_decoding_params=False branch: + # Results only for time points within intervals + results_list = [] + interval_labels = [] + + for interval_idx, (interval_start, interval_end) in enumerate( + decoding_interval + ): + # Get time points for this interval + interval_time = position_info.loc[ + interval_start:interval_end + ].index.to_numpy() + + if interval_time.size == 0: + continue + + # Create fake results for this interval + interval_results = create_fake_decoding_results( + n_time=len(interval_time), n_position_bins=50, n_states=2 + ) + # Update time coordinates to match actual interval times + interval_results = interval_results.assign_coords( + time=interval_time + ) + results_list.append(interval_results) + interval_labels.extend([interval_idx] * len(interval_time)) + + # Concatenate along time dimension (as the real code now does) + if len(results_list) == 1: + results = results_list[0] + else: + results = xr.concat(results_list, dim="time") + + # Add interval_labels coordinate (as the real code now does) + results = results.assign_coords( + interval_labels=("time", interval_labels) + ) # Add metadata (same as real implementation) # initial_conditions: shape (n_states,) with explicit dims diff --git a/tests/decoding/test_intervals_removal.py b/tests/decoding/test_intervals_removal.py index 5907b826f..0b72f9a4f 100644 --- a/tests/decoding/test_intervals_removal.py +++ b/tests/decoding/test_intervals_removal.py @@ -3,13 +3,14 @@ These tests verify that decoding results are stored as a single time series with interval tracking, rather than using the intervals dimension which causes padding and memory waste. 
+ +Tests cover both estimate_decoding_params=False (predict branch) and +estimate_decoding_params=True (estimate_parameters branch). """ import numpy as np - - def test_no_intervals_dimension_clusterless( decode_v1, monkeypatch, @@ -262,3 +263,248 @@ def test_no_intervals_dimension_sorted_spikes( "Results should have 'interval_labels' to track which interval " "each time point belongs to" ) + + +# ============================================================================ +# Tests for estimate_decoding_params=True branch +# ============================================================================ + + +def test_interval_labels_estimate_params_clusterless( + decode_v1, + monkeypatch, + mock_clusterless_decoder, + mock_decoder_save, + decode_sel_key, + group_name, + decode_clusterless_params_insert, + pop_pos_group, + group_unitwave, +): + """Test interval_labels when estimate_decoding_params=True (clusterless). + + When estimate_decoding_params=True, results span ALL time points and + interval_labels should be: + - -1 for time points outside any interval + - 0, 1, 2, ... 
for time points inside intervals + """ + _ = pop_pos_group, group_unitwave # ensure populated + + # Apply mocks + monkeypatch.setattr( + decode_v1.clusterless.ClusterlessDecodingV1, + "_run_decoder", + mock_clusterless_decoder, + ) + monkeypatch.setattr( + decode_v1.clusterless.ClusterlessDecodingV1, + "_save_decoder_results", + mock_decoder_save, + ) + + # Create selection key with estimate_decoding_params=True + selection_key = { + **decode_sel_key, + **decode_clusterless_params_insert, + "waveform_features_group_name": group_name, + "estimate_decoding_params": True, + } + + # Insert and populate + decode_v1.clusterless.ClusterlessDecodingSelection.insert1( + selection_key, + skip_duplicates=True, + ) + decode_v1.clusterless.ClusterlessDecodingV1.populate(selection_key) + + # Fetch results + table = decode_v1.clusterless.ClusterlessDecodingV1 & selection_key + results = table.fetch_results() + + # Verify interval_labels exists + assert "interval_labels" in results.coords, ( + "Results should have 'interval_labels' coordinate when " + "estimate_decoding_params=True" + ) + + # Get interval_labels + interval_labels = results.coords["interval_labels"].values + + # Verify interval_labels has same length as time + assert len(interval_labels) == len( + results.time + ), "interval_labels should have same length as time dimension" + + # Verify that -1 exists (for times outside intervals) + assert -1 in interval_labels, ( + "interval_labels should contain -1 for time points outside intervals " + "when estimate_decoding_params=True" + ) + + # Verify that non-negative labels exist (for times inside intervals) + assert np.any(interval_labels >= 0), ( + "interval_labels should contain non-negative values for time points " + "inside intervals" + ) + + # Verify labels are consecutive integers starting from 0 (excluding -1) + positive_labels = interval_labels[interval_labels >= 0] + unique_positive = np.unique(positive_labels) + expected_labels = np.arange(len(unique_positive)) + 
np.testing.assert_array_equal( + unique_positive, + expected_labels, + err_msg="Positive interval_labels should be consecutive integers from 0", + ) + + +def test_interval_labels_estimate_params_sorted_spikes( + decode_v1, + monkeypatch, + mock_sorted_spikes_decoder, + mock_decoder_save, + decode_sel_key, + group_name, + decode_spike_params_insert, + pop_pos_group, + pop_spikes_group, +): + """Test interval_labels when estimate_decoding_params=True (sorted spikes). + + When estimate_decoding_params=True, results span ALL time points and + interval_labels should be: + - -1 for time points outside any interval + - 0, 1, 2, ... for time points inside intervals + """ + _ = pop_pos_group, pop_spikes_group # ensure populated + + # Apply mocks + monkeypatch.setattr( + decode_v1.sorted_spikes.SortedSpikesDecodingV1, + "_run_decoder", + mock_sorted_spikes_decoder, + ) + monkeypatch.setattr( + decode_v1.sorted_spikes.SortedSpikesDecodingV1, + "_save_decoder_results", + mock_decoder_save, + ) + + # Create selection key with estimate_decoding_params=True + selection_key = { + **decode_sel_key, + **decode_spike_params_insert, + "sorted_spikes_group_name": group_name, + "unit_filter_params_name": "default_exclusion", + "estimate_decoding_params": True, + } + + # Insert and populate + decode_v1.sorted_spikes.SortedSpikesDecodingSelection.insert1( + selection_key, + skip_duplicates=True, + ) + decode_v1.sorted_spikes.SortedSpikesDecodingV1.populate(selection_key) + + # Fetch results + table = decode_v1.sorted_spikes.SortedSpikesDecodingV1 & selection_key + results = table.fetch_results() + + # Verify interval_labels exists + assert "interval_labels" in results.coords, ( + "Results should have 'interval_labels' coordinate when " + "estimate_decoding_params=True" + ) + + # Get interval_labels + interval_labels = results.coords["interval_labels"].values + + # Verify interval_labels has same length as time + assert len(interval_labels) == len( + results.time + ), "interval_labels 
should have same length as time dimension" + + # Verify that -1 exists (for times outside intervals) + assert -1 in interval_labels, ( + "interval_labels should contain -1 for time points outside intervals " + "when estimate_decoding_params=True" + ) + + # Verify that non-negative labels exist (for times inside intervals) + assert np.any(interval_labels >= 0), ( + "interval_labels should contain non-negative values for time points " + "inside intervals" + ) + + +def test_groupby_works_with_negative_labels( + decode_v1, + monkeypatch, + mock_clusterless_decoder, + mock_decoder_save, + decode_sel_key, + group_name, + decode_clusterless_params_insert, + pop_pos_group, + group_unitwave, +): + """Test that groupby works correctly with -1 labels. + + Users should be able to: + - Group by interval_labels to iterate through intervals + - Filter out -1 labels to get only interval data + """ + _ = pop_pos_group, group_unitwave # ensure populated + + # Apply mocks + monkeypatch.setattr( + decode_v1.clusterless.ClusterlessDecodingV1, + "_run_decoder", + mock_clusterless_decoder, + ) + monkeypatch.setattr( + decode_v1.clusterless.ClusterlessDecodingV1, + "_save_decoder_results", + mock_decoder_save, + ) + + # Create selection key with estimate_decoding_params=True + selection_key = { + **decode_sel_key, + **decode_clusterless_params_insert, + "waveform_features_group_name": group_name, + "estimate_decoding_params": True, + } + + # Insert and populate + decode_v1.clusterless.ClusterlessDecodingSelection.insert1( + selection_key, + skip_duplicates=True, + ) + decode_v1.clusterless.ClusterlessDecodingV1.populate(selection_key) + + # Fetch results + table = decode_v1.clusterless.ClusterlessDecodingV1 & selection_key + results = table.fetch_results() + + # Test groupby operation works + grouped = results.groupby("interval_labels") + assert grouped is not None, "Should be able to groupby interval_labels" + + # Verify we can iterate and get the -1 group + labels_seen = [] + for label, 
group in grouped: + labels_seen.append(label) + assert "time" in group.dims, "Each group should have time dimension" + + # Verify -1 is one of the groups + assert -1 in labels_seen, "Should have a group for -1 (outside intervals)" + + # Test filtering to only interval data + interval_data = results.where(results.interval_labels >= 0, drop=True) + assert len(interval_data.time) < len( + results.time + ), "Filtering to interval_labels >= 0 should reduce data size" + assert np.all( + interval_data.interval_labels >= 0 + ), "Filtered data should only have non-negative interval_labels" From b8b9ce0ba86fb13b52d99d1f5e1c6b8394e62122 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Fri, 21 Nov 2025 20:13:31 -0500 Subject: [PATCH 7/9] Fix test fixture to create time points outside decoding interval MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The tests for interval_labels with estimate_decoding_params=True expected -1 labels for time points outside the decoding interval, but the fixture created decode_interval covering the same range as the position data. Changed decode_interval from [raw_begin, raw_begin+15] to [raw_begin+2, raw_begin+13] so position_info has time points outside the decoding interval, which correctly get interval_labels=-1. 
šŸ¤– Generated with [Claude Code](https://claude.com/claude-code)
šŸ¤– Generated with [Claude Code](https://claude.com/claude-code)
šŸ¤– Generated with [Claude Code](https://claude.com/claude-code)