Skip to content

Commit 6a6fddf

Browse files
committed
more slicing docs
1 parent f4e14ff commit 6a6fddf

File tree

4 files changed

+690
-164
lines changed

4 files changed

+690
-164
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,5 @@ htmlcov/
3939
# OS
4040
.DS_Store
4141
examples
42+
43+
docs/expt

docs/alt-ndpoint.ipynb

Lines changed: 348 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,348 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "cell-0",
6+
"metadata": {},
7+
"source": [
8+
"# NDPointIndex Approach\n",
9+
"\n",
10+
"xarray includes [`NDPointIndex`](https://xarray-indexes.readthedocs.io/blocks/ndpoint.html) for **unstructured point data** (e.g., irregular grids, scattered observations). It uses a KD-tree for spatial nearest-neighbor queries.\n",
11+
"\n",
12+
"This notebook explores whether `NDPointIndex` can solve the same problem as `NDIndex` for trial-based data with derived coordinates.\n",
13+
"\n",
14+
"## Setup"
15+
]
16+
},
17+
{
18+
"cell_type": "code",
19+
"execution_count": null,
20+
"id": "cell-1",
21+
"metadata": {},
22+
"outputs": [],
23+
"source": [
24+
"import numpy as np\n",
25+
"import xarray as xr\n",
26+
"from linked_indices.example_data import trial_based_dataset"
27+
]
28+
},
29+
{
30+
"cell_type": "markdown",
31+
"id": "cell-2",
32+
"metadata": {},
33+
"source": [
34+
"## What NDPointIndex is designed for\n",
35+
"\n",
36+
"`NDPointIndex` is designed for **curvilinear grids** and **unstructured point clouds** where you have multiple coordinate variables that together define a point in N-dimensional space.\n",
37+
"\n",
38+
"The classic example is a 2D grid with latitude and longitude coordinates that vary in both dimensions:"
39+
]
40+
},
41+
{
42+
"cell_type": "code",
43+
"execution_count": null,
44+
"id": "cell-3",
45+
"metadata": {},
46+
"outputs": [],
47+
"source": [
48+
"# Create a curvilinear grid (like ocean model output)\n",
49+
"# The lat/lon coordinates vary in BOTH dimensions\n",
50+
"shape = (5, 10)\n",
51+
"lon = xr.DataArray(np.random.uniform(-180, 180, size=shape), dims=(\"y\", \"x\"))\n",
52+
"lat = xr.DataArray(np.random.uniform(-90, 90, size=shape), dims=(\"y\", \"x\"))\n",
53+
"temperature = xr.DataArray(np.random.uniform(0, 30, size=shape), dims=(\"y\", \"x\"))\n",
54+
"\n",
55+
"ds_curvilinear = xr.Dataset(\n",
56+
" data_vars={\"temperature\": temperature}, coords={\"lon\": lon, \"lat\": lat}\n",
57+
")\n",
58+
"ds_curvilinear"
59+
]
60+
},
61+
{
62+
"cell_type": "code",
63+
"execution_count": null,
64+
"id": "cell-4",
65+
"metadata": {},
66+
"outputs": [],
67+
"source": [
68+
"# Apply NDPointIndex - requires BOTH lon and lat together\n",
69+
"ds_indexed = ds_curvilinear.set_xindex([\"lon\", \"lat\"], xr.indexes.NDPointIndex)\n",
70+
"ds_indexed"
71+
]
72+
},
73+
{
74+
"cell_type": "code",
75+
"execution_count": null,
76+
"id": "cell-5",
77+
"metadata": {},
78+
"outputs": [],
79+
"source": [
80+
"# Now we can query: \"Find the grid cell nearest to lat=45, lon=-120\"\n",
81+
"# This is a SPATIAL query - both coordinates together define a point\n",
82+
"ds_indexed.sel(lat=45.0, lon=-120.0, method=\"nearest\")"
83+
]
84+
},
85+
{
86+
"cell_type": "markdown",
87+
"id": "cell-6",
88+
"metadata": {},
89+
"source": [
90+
"## Trying NDPointIndex with trial-based data\n",
91+
"\n",
92+
"Now let's see what happens when we try to use `NDPointIndex` with our trial-based dataset where we have a single 2D `abs_time` coordinate."
93+
]
94+
},
95+
{
96+
"cell_type": "code",
97+
"execution_count": null,
98+
"id": "cell-7",
99+
"metadata": {},
100+
"outputs": [],
101+
"source": [
102+
"ds = trial_based_dataset(mode=\"stacked\").drop_vars(\"trial_onset\")\n",
103+
"print(ds)"
104+
]
105+
},
106+
{
107+
"cell_type": "markdown",
108+
"id": "cell-8",
109+
"metadata": {},
110+
"source": [
111+
"### Problem 1: NDPointIndex requires matching number of variables and dimensions\n",
112+
"\n",
113+
"`NDPointIndex` expects one coordinate variable per dimension. Our `abs_time` is a single 2D variable, not two 1D variables that define points in 2D space."
114+
]
115+
},
116+
{
117+
"cell_type": "code",
118+
"execution_count": null,
119+
"id": "cell-9",
120+
"metadata": {},
121+
"outputs": [],
122+
"source": [
123+
"# This fails! NDPointIndex expects 2 variables for 2 dimensions\n",
124+
"try:\n",
125+
" ds.set_xindex([\"abs_time\"], xr.indexes.NDPointIndex)\n",
126+
"except ValueError as e:\n",
127+
" print(f\"ValueError: {e}\")"
128+
]
129+
},
130+
{
131+
"cell_type": "markdown",
132+
"id": "cell-10",
133+
"metadata": {},
134+
"source": [
135+
"### Why this matters\n",
136+
"\n",
137+
"The fundamental difference is:\n",
138+
"\n",
139+
"| Aspect | NDPointIndex | NDIndex |\n",
140+
"|--------|--------------|----------|\n",
141+
"| **Coordinates** | Multiple 2D coords that together define position | Single N-D coord with derived values |\n",
142+
"| **Query type** | Spatial: \"find point at (x, y)\" | Value: \"find cell where value ≈ target\" |\n",
143+
"| **Use case** | Curvilinear grids, scattered observations | Structured arrays with computed coordinates |\n",
144+
"\n",
145+
"**NDPointIndex** answers: \"Which grid cell is nearest to coordinates (lat=45, lon=-120)?\"\n",
146+
"\n",
147+
"**NDIndex** answers: \"Which (trial, time) cell has `abs_time` closest to 7.5?\""
148+
]
149+
},
150+
{
151+
"cell_type": "markdown",
152+
"id": "cell-11",
153+
"metadata": {},
154+
"source": [
155+
"### Could we reshape the data to use NDPointIndex?\n",
156+
"\n",
157+
"One might try to flatten the data and treat `(trial, rel_time)` as coordinate dimensions for NDPointIndex. Let's see what that looks like:"
158+
]
159+
},
160+
{
161+
"cell_type": "code",
162+
"execution_count": null,
163+
"id": "cell-12",
164+
"metadata": {},
165+
"outputs": [],
166+
"source": [
167+
"# Flatten the dataset to 1D\n",
168+
"ds_flat = ds.stack(point=(\"trial\", \"rel_time\"))\n",
169+
"print(f\"Original shape: {dict(ds.sizes)}\")\n",
170+
"print(f\"Flattened shape: {dict(ds_flat.sizes)}\")\n",
171+
"ds_flat"
172+
]
173+
},
174+
{
175+
"cell_type": "code",
176+
"execution_count": null,
177+
"id": "cell-13",
178+
"metadata": {},
179+
"outputs": [],
180+
"source": [
181+
"# Create separate coordinate arrays for trial index and rel_time\n",
182+
"# to use with NDPointIndex\n",
183+
"trial_idx = xr.DataArray(np.repeat(np.arange(3), 500), dims=[\"point\"])\n",
184+
"rel_time_flat = xr.DataArray(np.tile(ds.rel_time.values, 3), dims=[\"point\"])\n",
185+
"\n",
186+
"ds_for_ndpoint = xr.Dataset(\n",
187+
" data_vars={\"data\": ([\"point\"], ds_flat.data.values)},\n",
188+
" coords={\n",
189+
" \"trial_idx\": trial_idx,\n",
190+
" \"rel_time_flat\": rel_time_flat,\n",
191+
" \"abs_time\": ([\"point\"], ds_flat.abs_time.values),\n",
192+
" },\n",
193+
")\n",
194+
"ds_for_ndpoint"
195+
]
196+
},
197+
{
198+
"cell_type": "code",
199+
"execution_count": null,
200+
"id": "cell-14",
201+
"metadata": {},
202+
"outputs": [],
203+
"source": [
204+
"# Now we could apply NDPointIndex with trial_idx and rel_time_flat\n",
205+
"ds_ndpoint = ds_for_ndpoint.set_xindex(\n",
206+
" [\"trial_idx\", \"rel_time_flat\"], xr.indexes.NDPointIndex\n",
207+
")\n",
208+
"ds_ndpoint"
209+
]
210+
},
211+
{
212+
"cell_type": "code",
213+
"execution_count": null,
214+
"id": "cell-15",
215+
"metadata": {},
216+
"outputs": [],
217+
"source": [
218+
"# Query: find point nearest to trial_idx=1, rel_time=2.5\n",
219+
"result = ds_ndpoint.sel(trial_idx=1, rel_time_flat=2.5, method=\"nearest\")\n",
220+
"print(\n",
221+
" f\"Found point at trial_idx={result.trial_idx.item()}, rel_time={result.rel_time_flat.item():.2f}\"\n",
222+
")\n",
223+
"print(f\"abs_time at this point: {result.abs_time.item():.2f}\")"
224+
]
225+
},
226+
{
227+
"cell_type": "markdown",
228+
"id": "cell-16",
229+
"metadata": {},
230+
"source": [
231+
"### But this doesn't solve our problem!\n",
232+
"\n",
233+
"With this approach:\n",
234+
"1. **We can't select by `abs_time` directly** - NDPointIndex uses the indexed coordinates (trial_idx, rel_time_flat), not derived values like abs_time\n",
235+
"2. **We lose the structured array** - the data is now 1D instead of (trial, rel_time)\n",
236+
"3. **We lose trial labels** - trial_idx is numeric, not the original string labels"
237+
]
238+
},
239+
{
240+
"cell_type": "code",
241+
"execution_count": null,
242+
"id": "cell-17",
243+
"metadata": {},
244+
"outputs": [],
245+
"source": [
246+
"# We CANNOT do this - abs_time is not an indexed coordinate:\n",
247+
"try:\n",
248+
" ds_ndpoint.sel(abs_time=7.5, method=\"nearest\")\n",
249+
"except KeyError as e:\n",
250+
" print(f\"KeyError: {e}\")"
251+
]
252+
},
253+
{
254+
"cell_type": "markdown",
255+
"id": "cell-18",
256+
"metadata": {},
257+
"source": [
258+
"### Could we use abs_time with KDTree directly?\n",
259+
"\n",
260+
"Another approach might be to build a KDTree on abs_time values directly. But scipy's KDTree expects points in N-dimensional space, not scalar lookups:"
261+
]
262+
},
263+
{
264+
"cell_type": "code",
265+
"execution_count": null,
266+
"id": "cell-19",
267+
"metadata": {},
268+
"outputs": [],
269+
"source": [
270+
"from scipy.spatial import KDTree\n",
271+
"\n",
272+
"# KDTree expects (n_points, n_dims) array\n",
273+
"# Our abs_time is shape (3, 500) = 1500 scalar values\n",
274+
"# Reshaping to (1500, 1) treats each value as a 1D point\n",
275+
"abs_time_flat = ds.abs_time.values.ravel().reshape(-1, 1)\n",
276+
"tree = KDTree(abs_time_flat)\n",
277+
"\n",
278+
"# Query for abs_time ≈ 7.5\n",
279+
"distance, flat_idx = tree.query([[7.5]])\n",
280+
"trial_idx = flat_idx[0] // 500\n",
281+
"time_idx = flat_idx[0] % 500\n",
282+
"\n",
283+
"print(\n",
284+
" f\"Found: trial={ds.trial.values[trial_idx]}, rel_time={ds.rel_time.values[time_idx]:.2f}\"\n",
285+
")\n",
286+
"print(f\"abs_time at this point: {ds.abs_time.values[trial_idx, time_idx]:.2f}\")"
287+
]
288+
},
289+
{
290+
"cell_type": "markdown",
291+
"id": "cell-20",
292+
"metadata": {},
293+
"source": [
294+
"This works, but:\n",
295+
"1. It's not integrated with xarray's indexing system\n",
296+
"2. You have to manually convert between flat indices and (trial, time) indices\n",
297+
"3. It doesn't support slices or other advanced indexing\n",
298+
"4. The data structure is lost\n",
299+
"\n",
300+
"**This is essentially what `NDIndex` does internally, but with proper xarray integration.**"
301+
]
302+
},
303+
{
304+
"cell_type": "markdown",
305+
"id": "cell-21",
306+
"metadata": {},
307+
"source": [
308+
"## Summary\n",
309+
"\n",
310+
"| Feature | NDPointIndex | NDIndex |\n",
311+
"|---------|--------------|----------|\n",
312+
"| **Use case** | Unstructured point clouds, curvilinear grids | Structured arrays with derived coordinates |\n",
313+
"| **Query type** | Spatial: find nearest (x, y) point | Value: find cell where `abs_time ≈ 7.5` |\n",
314+
"| **Coordinates** | Multiple N-D coords (one per dimension) | Single N-D coord with computed values |\n",
315+
"| **Data structure** | Points in N-D coordinate space | N-D array of scalar values |\n",
316+
"| **Returns** | Single nearest point | Dimensional slices |\n",
317+
"| **Slice support** | No | Yes (bounding box) |\n",
318+
"\n",
319+
"`NDPointIndex` and `NDIndex` solve different problems:\n",
320+
"\n",
321+
"```python\n",
322+
"# NDPointIndex: \"Find the grid cell nearest to lat=45.2, lon=-122.5\"\n",
323+
"ds.sel(lat=45.2, lon=-122.5, method=\"nearest\") # Spatial query\n",
324+
"\n",
325+
"# NDIndex: \"Find which (trial, time) has abs_time closest to 7.5\"\n",
326+
"ds.sel(abs_time=7.5, method=\"nearest\") # Value lookup in N-D array\n",
327+
"```\n",
328+
"\n",
329+
"Use `NDPointIndex` when your coordinates define positions in space (or similar multi-dimensional coordinate systems).\n",
330+
"\n",
331+
"Use `NDIndex` when you have derived coordinates computed from dimension coordinates (like `abs_time = trial_onset + rel_time`)."
332+
]
333+
}
334+
],
335+
"metadata": {
336+
"kernelspec": {
337+
"display_name": "Python 3",
338+
"language": "python",
339+
"name": "python3"
340+
},
341+
"language_info": {
342+
"name": "python",
343+
"version": "3.11.0"
344+
}
345+
},
346+
"nbformat": 4,
347+
"nbformat_minor": 5
348+
}

0 commit comments

Comments
 (0)