Skip to content

Commit d307fa5

Browse files
committed
improve performance notebook
1 parent d812baa commit d307fa5

File tree

4 files changed

+1041
-375
lines changed

4 files changed

+1041
-375
lines changed

docs/ndindex_performance.ipynb

Lines changed: 748 additions & 371 deletions
Large diffs are not rendered by default.
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
"""Benchmark utilities for linked_indices performance testing.
2+
3+
This module provides helpers for rigorous performance measurement using
4+
Python's timeit module.
5+
"""
6+
7+
from __future__ import annotations

import timeit
from collections.abc import Callable
from typing import Any

import numpy as np
14+
__all__ = ["timeit_benchmark"]
15+
16+
17+
def timeit_benchmark(
    stmt: str | Callable[[], Any],
    setup: str = "pass",
    globals: dict[str, Any] | None = None,  # noqa: A002 — mirrors timeit.Timer's own parameter name
    repeat: int = 7,
) -> dict[str, float | int]:
    """
    Benchmark a statement or callable using timeit with automatic loop count.

    Uses timeit's autorange to determine an appropriate number of loops,
    then runs multiple trials to get reliable timing statistics.

    Parameters
    ----------
    stmt : str or callable
        The statement or zero-argument callable to benchmark.
    setup : str
        Setup code to run once before the benchmark. Default: "pass"
    globals : dict, optional
        Global namespace for the benchmark. Required when stmt uses
        variables from the calling scope.
    repeat : int
        Number of trials to run. Must be at least 1.
        Default: 7 (timeit default)

    Returns
    -------
    dict
        Dictionary with timing statistics:
        - best_ms: Minimum time per call in milliseconds (best measure of algorithm cost)
        - mean_ms: Mean time per call in milliseconds (typical real-world performance)
        - std_ms: Standard deviation of times in milliseconds
        - n_loops: Number of loops per trial (determined by autorange)

    Raises
    ------
    ValueError
        If ``repeat`` is less than 1 (statistics over zero trials are
        undefined).

    Examples
    --------
    >>> from linked_indices.benchmark_utils import timeit_benchmark
    >>> import numpy as np
    >>> arr = np.random.randn(1000)
    >>> result = timeit_benchmark(
    ...     lambda: np.sum(arr),
    ...     globals={"np": np, "arr": arr}
    ... )
    >>> result["best_ms"] < 1.0  # Should be fast
    True
    """
    if repeat < 1:
        raise ValueError(f"repeat must be >= 1, got {repeat}")

    timer = timeit.Timer(stmt, setup=setup, globals=globals)

    # Let timeit determine an appropriate number of loops
    # (autorange targets roughly 0.2 s of total work per trial)
    n_loops, _ = timer.autorange()

    # Run multiple trials; each entry is the total time for n_loops calls
    times = timer.repeat(repeat=repeat, number=n_loops)
    times_per_call = np.array(times) / n_loops

    return {
        "best_ms": float(times_per_call.min() * 1000),
        "mean_ms": float(times_per_call.mean() * 1000),
        "std_ms": float(times_per_call.std() * 1000),
        "n_loops": n_loops,
    }

src/linked_indices/example_data.py

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@
2525
"multi_interval_dataset",
2626
"onset_duration_dataset",
2727
"trial_based_dataset",
28+
# NDIndex benchmark dataset generators
29+
"create_trial_ndindex_dataset",
30+
"create_diagonal_dataset",
31+
"create_radial_dataset",
32+
"create_jittered_dataset",
2833
]
2934

3035

@@ -680,3 +685,211 @@ def trial_based_dataset(
680685
)
681686

682687
return ds
688+
689+
690+
# =============================================================================
691+
# NDIndex benchmark dataset generators
692+
# =============================================================================
693+
694+
695+
def create_trial_ndindex_dataset(n_trials: int, n_times: int) -> "xr.Dataset":
    """
    Build a trial-structured dataset where abs_time = trial_onset + rel_time.

    This models the typical neuroscience recording layout: every trial
    shares the same relative time axis, while each trial occupies its own
    range of absolute time. The returned dataset already has an NDIndex
    attached to the 2-D abs_time coordinate.

    Parameters
    ----------
    n_trials : int
        Number of trials.
    n_times : int
        Number of time points per trial.

    Returns
    -------
    xr.Dataset
        Dataset with NDIndex set on the abs_time coordinate.

    Examples
    --------
    >>> from linked_indices.example_data import create_trial_ndindex_dataset
    >>> ds = create_trial_ndindex_dataset(10, 100)
    >>> ds.sel(abs_time=0.5, method="nearest")  # Select by absolute time
    """
    import xarray as xr

    from linked_indices import NDIndex

    # Evenly spaced onsets tile absolute time trial-after-trial.
    onsets = np.arange(n_trials) * n_times * 0.01
    within_trial = np.linspace(0, n_times * 0.01, n_times)

    # Broadcast (n_trials, 1) + (1, n_times) -> (n_trials, n_times).
    absolute = onsets[:, np.newaxis] + within_trial[np.newaxis, :]
    values = np.random.randn(n_trials, n_times)

    coords = {
        "trial": np.arange(n_trials),
        "rel_time": within_trial,
        "abs_time": (["trial", "rel_time"], absolute),
    }
    dataset = xr.Dataset({"data": (["trial", "rel_time"], values)}, coords=coords)
    return dataset.set_xindex(["abs_time"], NDIndex)
739+
740+
741+
def create_diagonal_dataset(ny: int, nx: int) -> "xr.Dataset":
    """
    Build an image-like dataset carrying a diagonal-gradient coordinate.

    Mirrors the slicing-gallery example: derived[y, x] = y_offset[y] + x_coord[x],
    structurally similar to trial data but with image scale/semantics.
    The returned dataset already has an NDIndex attached to ``derived``.

    Parameters
    ----------
    ny : int
        Number of y (row) coordinates.
    nx : int
        Number of x (column) coordinates.

    Returns
    -------
    xr.Dataset
        Dataset with NDIndex set on the derived coordinate.

    Examples
    --------
    >>> from linked_indices.example_data import create_diagonal_dataset
    >>> ds = create_diagonal_dataset(100, 100)
    >>> ds.sel(derived=50, method="nearest")
    """
    import xarray as xr

    from linked_indices import NDIndex

    rows = np.arange(ny)
    cols = np.arange(nx)

    # Each successive row is offset by 2, producing a diagonal gradient.
    gradient = np.add.outer(rows * 2, cols)
    values = np.random.randn(ny, nx)

    dataset = xr.Dataset(
        {"data": (["y", "x"], values)},
        coords={
            "y": rows,
            "x": cols,
            "derived": (["y", "x"], gradient),
        },
    )
    return dataset.set_xindex(["derived"], NDIndex)
788+
789+
790+
def create_radial_dataset(ny: int, nx: int) -> "xr.Dataset":
    """
    Build an image-like dataset carrying a radial (non-linear 2-D) coordinate.

    Exercises performance with a non-monotonic coordinate pattern: the
    ``radius`` coordinate is each pixel's distance from the array center.
    The returned dataset already has an NDIndex attached to ``radius``.

    Parameters
    ----------
    ny : int
        Number of y (row) coordinates.
    nx : int
        Number of x (column) coordinates.

    Returns
    -------
    xr.Dataset
        Dataset with NDIndex set on the radius coordinate.

    Examples
    --------
    >>> from linked_indices.example_data import create_radial_dataset
    >>> ds = create_radial_dataset(100, 100)
    >>> ds.sel(radius=slice(10, 20))  # Select an annulus
    """
    import xarray as xr

    from linked_indices import NDIndex

    # Signed offsets from the array center, shaped for broadcasting
    # (rows as a column vector, columns as a row vector).
    row_off = (np.arange(ny) - ny // 2)[:, np.newaxis]
    col_off = (np.arange(nx) - nx // 2)[np.newaxis, :]
    distance = np.sqrt(col_off**2 + row_off**2)
    values = np.random.randn(ny, nx)

    dataset = xr.Dataset(
        {"data": (["y", "x"], values)},
        coords={
            "y": np.arange(ny),
            "x": np.arange(nx),
            "radius": (["y", "x"], distance),
        },
    )
    return dataset.set_xindex(["radius"], NDIndex)
834+
835+
836+
def create_jittered_dataset(
    n_trials: int, n_times: int, jitter_std: float = 0.1
) -> "xr.Dataset":
    """
    Create trial dataset with per-trial timing jitter.

    More realistic: trial onsets have random variation, and sampling
    times have small per-sample jitter (like real physiological recordings).
    Returns a dataset with NDIndex already set on abs_time.

    The dataset is reproducible: a fixed-seed local random generator is
    used, so repeated calls with the same arguments yield the same data.

    Parameters
    ----------
    n_trials : int
        Number of trials.
    n_times : int
        Number of time points per trial.
    jitter_std : float
        Standard deviation of timing jitter. Default: 0.1

    Returns
    -------
    xr.Dataset
        Dataset with NDIndex set on abs_time coordinate.

    Examples
    --------
    >>> from linked_indices.example_data import create_jittered_dataset
    >>> ds = create_jittered_dataset(10, 100, jitter_std=0.2)
    >>> ds.sel(abs_time=0.5, method="nearest")
    """
    import xarray as xr

    from linked_indices import NDIndex

    # Local seeded Generator instead of np.random.seed(42): results stay
    # reproducible without clobbering the *global* NumPy RNG state, which
    # would silently affect every later np.random call made by the caller.
    rng = np.random.default_rng(42)

    # Trial onsets with jitter
    base_onsets = np.arange(n_trials) * n_times * 0.01
    trial_onsets = base_onsets + rng.standard_normal(n_trials) * jitter_std
    trial_onsets[0] = 0  # First trial starts at 0

    # Per-sample timing jitter within each trial
    base_rel_time = np.linspace(0, n_times * 0.01, n_times)
    rel_time_jitter = rng.standard_normal((n_trials, n_times)) * (jitter_std * 0.01)

    # 2D absolute time with jitter
    abs_time = (
        trial_onsets[:, np.newaxis] + base_rel_time[np.newaxis, :] + rel_time_jitter
    )
    data = rng.standard_normal((n_trials, n_times))

    ds = xr.Dataset(
        {"data": (["trial", "rel_time"], data)},
        coords={
            "trial": np.arange(n_trials),
            "rel_time": base_rel_time,
            "abs_time": (["trial", "rel_time"], abs_time),
        },
    )
    return ds.set_xindex(["abs_time"], NDIndex)

src/linked_indices/nd_index.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -192,15 +192,15 @@ def _find_indices_for_value(
192192
# Find flat index of closest value
193193
flat_idx = np.argmin(np.abs(values - value))
194194
else:
195-
# Exact match required
196-
matches = np.where(values == value)
197-
if len(matches[0]) == 0:
195+
# Exact match required - use flatnonzero for efficiency
196+
flat_matches = np.flatnonzero(values == value)
197+
if len(flat_matches) == 0:
198198
raise KeyError(
199199
f"Value {value!r} not found in coordinate {coord_name!r}. "
200200
f"Use method='nearest' for approximate matching."
201201
)
202202
# Use the first match
203-
flat_idx = np.ravel_multi_index(tuple(m[0] for m in matches), values.shape)
203+
flat_idx = flat_matches[0]
204204

205205
# Convert to multi-dimensional indices
206206
indices = np.unravel_index(flat_idx, values.shape)

0 commit comments

Comments
 (0)