|
25 | 25 | "multi_interval_dataset", |
26 | 26 | "onset_duration_dataset", |
27 | 27 | "trial_based_dataset", |
| 28 | + # NDIndex benchmark dataset generators |
| 29 | + "create_trial_ndindex_dataset", |
| 30 | + "create_diagonal_dataset", |
| 31 | + "create_radial_dataset", |
| 32 | + "create_jittered_dataset", |
28 | 33 | ] |
29 | 34 |
|
30 | 35 |
|
@@ -680,3 +685,211 @@ def trial_based_dataset( |
680 | 685 | ) |
681 | 686 |
|
682 | 687 | return ds |
| 688 | + |
| 689 | + |
| 690 | +# ============================================================================= |
| 691 | +# NDIndex benchmark dataset generators |
| 692 | +# ============================================================================= |
| 693 | + |
| 694 | + |
def create_trial_ndindex_dataset(n_trials: int, n_times: int) -> "xr.Dataset":
    """
    Build a trial-structured dataset where abs_time = trial_onset + rel_time.

    Every trial shares the same ``rel_time`` axis, while the 2D ``abs_time``
    coordinate shifts that axis by a per-trial onset. The returned dataset
    already has an NDIndex attached to ``abs_time``.

    Parameters
    ----------
    n_trials : int
        Number of trials (length of the ``trial`` dimension).
    n_times : int
        Number of time points per trial (length of the ``rel_time`` dimension).

    Returns
    -------
    xr.Dataset
        Dataset with a ``data`` variable of shape ``(n_trials, n_times)``
        and an NDIndex set on the ``abs_time`` coordinate.

    Notes
    -----
    Trial onsets are spaced ``n_times * 0.01`` apart and ``rel_time`` spans
    ``[0, n_times * 0.01]`` inclusive, so the last sample of one trial
    nominally coincides with the onset of the next. ``data`` is drawn from
    NumPy's global random state and is therefore not reproducible unless the
    caller seeds it.

    Examples
    --------
    >>> from linked_indices.example_data import create_trial_ndindex_dataset
    >>> ds = create_trial_ndindex_dataset(10, 100)
    >>> ds.sel(abs_time=0.5, method="nearest")  # Select by absolute time
    """
    import xarray as xr

    from linked_indices import NDIndex

    dt = 0.01  # nominal sample period used to scale both axes
    dims = ["trial", "rel_time"]

    rel_time = np.linspace(0, n_times * dt, n_times)
    onsets = np.arange(n_trials) * n_times * dt

    # Outer sum: row i is the shared relative axis shifted by onset i.
    abs_time = onsets[:, None] + rel_time[None, :]
    values = np.random.randn(n_trials, n_times)

    coords = {
        "trial": np.arange(n_trials),
        "rel_time": rel_time,
        "abs_time": (dims, abs_time),
    }
    dataset = xr.Dataset({"data": (dims, values)}, coords=coords)
    return dataset.set_xindex(["abs_time"], NDIndex)
| 739 | + |
| 740 | + |
def create_diagonal_dataset(ny: int, nx: int) -> "xr.Dataset":
    """
    Build an image-like dataset whose 2D coordinate is a diagonal gradient.

    The ``derived`` coordinate is ``derived[y, x] = 2 * y + x``, i.e. a
    per-row offset added to the column position. The returned dataset
    already has an NDIndex attached to ``derived``.

    Parameters
    ----------
    ny : int
        Number of y (row) coordinates.
    nx : int
        Number of x (column) coordinates.

    Returns
    -------
    xr.Dataset
        Dataset with a ``data`` variable of shape ``(ny, nx)`` and an
        NDIndex set on the ``derived`` coordinate.

    Notes
    -----
    ``data`` is drawn from NumPy's global random state and is therefore not
    reproducible unless the caller seeds it.

    Examples
    --------
    >>> from linked_indices.example_data import create_diagonal_dataset
    >>> ds = create_diagonal_dataset(100, 100)
    >>> ds.sel(derived=50, method="nearest")
    """
    import xarray as xr

    from linked_indices import NDIndex

    dims = ["y", "x"]
    rows = np.arange(ny)
    cols = np.arange(nx)

    # Each successive row starts 2 units higher than the previous one.
    row_offset = rows * 2
    gradient = row_offset[:, None] + cols[None, :]
    values = np.random.randn(ny, nx)

    coords = {
        "y": rows,
        "x": cols,
        "derived": (dims, gradient),
    }
    dataset = xr.Dataset({"data": (dims, values)}, coords=coords)
    return dataset.set_xindex(["derived"], NDIndex)
| 788 | + |
| 789 | + |
def create_radial_dataset(ny: int, nx: int) -> "xr.Dataset":
    """
    Build an image-like dataset with a radial (non-monotonic) 2D coordinate.

    The ``radius`` coordinate holds the Euclidean distance, in index units,
    of each pixel from the pixel at ``(ny // 2, nx // 2)``. The returned
    dataset already has an NDIndex attached to ``radius``.

    Parameters
    ----------
    ny : int
        Number of y (row) coordinates.
    nx : int
        Number of x (column) coordinates.

    Returns
    -------
    xr.Dataset
        Dataset with a ``data`` variable of shape ``(ny, nx)`` and an
        NDIndex set on the ``radius`` coordinate.

    Notes
    -----
    ``data`` is drawn from NumPy's global random state and is therefore not
    reproducible unless the caller seeds it.

    Examples
    --------
    >>> from linked_indices.example_data import create_radial_dataset
    >>> ds = create_radial_dataset(100, 100)
    >>> ds.sel(radius=slice(10, 20))  # Select an annulus
    """
    import xarray as xr

    from linked_indices import NDIndex

    dims = ["y", "x"]
    rows = np.arange(ny)
    cols = np.arange(nx)

    # Signed offsets from the centre pixel, shaped for broadcasting to (ny, nx).
    dy = (rows - ny // 2)[:, None]
    dx = (cols - nx // 2)[None, :]
    radius = np.sqrt(dx**2 + dy**2)
    values = np.random.randn(ny, nx)

    coords = {
        "y": rows,
        "x": cols,
        "radius": (dims, radius),
    }
    dataset = xr.Dataset({"data": (dims, values)}, coords=coords)
    return dataset.set_xindex(["radius"], NDIndex)
| 834 | + |
| 835 | + |
def create_jittered_dataset(
    n_trials: int, n_times: int, jitter_std: float = 0.1, seed: int = 42
) -> "xr.Dataset":
    """
    Create a trial dataset with jittered trial onsets and sample times.

    Trial onsets are nominally spaced ``n_times * 0.01`` apart and perturbed
    by Gaussian noise; each sample additionally receives a much smaller
    Gaussian perturbation. The first trial's onset is forced to exactly 0.
    The returned dataset already has an NDIndex attached to ``abs_time``.

    Parameters
    ----------
    n_trials : int
        Number of trials. Must be at least 1, because the onset of the
        first trial is assigned explicitly.
    n_times : int
        Number of time points per trial.
    jitter_std : float
        Standard deviation of the trial-onset jitter. The per-sample jitter
        uses ``jitter_std * 0.01``. Default: 0.1
    seed : int
        Seed for the private random generator used for both the jitter and
        the data values. Default: 42

    Returns
    -------
    xr.Dataset
        Dataset with a ``data`` variable of shape ``(n_trials, n_times)``
        and an NDIndex set on the ``abs_time`` coordinate. The ``rel_time``
        coordinate holds the un-jittered relative times.

    Notes
    -----
    Output is reproducible for a given ``seed``. A private
    ``np.random.RandomState`` is used so NumPy's global random state is left
    untouched; with the default seed the draws match what seeding the global
    generator with 42 would produce.

    Examples
    --------
    >>> from linked_indices.example_data import create_jittered_dataset
    >>> ds = create_jittered_dataset(10, 100, jitter_std=0.2)
    >>> ds.sel(abs_time=0.5, method="nearest")
    """
    import xarray as xr

    from linked_indices import NDIndex

    # Local generator: reproducible without reseeding the caller's global RNG.
    rng = np.random.RandomState(seed)

    # Trial onsets with jitter
    base_onsets = np.arange(n_trials) * n_times * 0.01
    trial_onsets = base_onsets + rng.randn(n_trials) * jitter_std
    trial_onsets[0] = 0  # First trial starts at 0

    # Per-sample timing jitter within each trial (100x smaller than onset jitter)
    base_rel_time = np.linspace(0, n_times * 0.01, n_times)
    rel_time_jitter = rng.randn(n_trials, n_times) * (jitter_std * 0.01)

    # 2D absolute time with jitter
    abs_time = (
        trial_onsets[:, np.newaxis] + base_rel_time[np.newaxis, :] + rel_time_jitter
    )
    data = rng.randn(n_trials, n_times)

    ds = xr.Dataset(
        {"data": (["trial", "rel_time"], data)},
        coords={
            "trial": np.arange(n_trials),
            "rel_time": base_rel_time,
            "abs_time": (["trial", "rel_time"], abs_time),
        },
    )
    return ds.set_xindex(["abs_time"], NDIndex)
0 commit comments