|
7 | 7 |
|
8 | 8 | from collections.abc import Iterable |
9 | 9 | from configparser import NoOptionError |
10 | | -from typing import Dict, List, Mapping, Optional, Tuple |
| 10 | +from typing import Dict, List, Mapping, Optional, Tuple, Union |
11 | 11 |
|
12 | 12 | import numpy as np |
13 | 13 | import torch |
14 | 14 | from scipy.stats import norm |
15 | 15 | from torch.quasirandom import SobolEngine |
16 | 16 |
|
| 17 | +from aepsych.config import Config |
17 | 18 |
|
18 | 19 | def make_scaled_sobol(lb, ub, size, seed=None): |
19 | 20 | lb, ub, ndim = _process_bounds(lb, ub, None) |
@@ -59,11 +60,11 @@ def dim_grid( |
59 | 60 |
|
60 | 61 | for i in range(dim): |
61 | 62 | if i in slice_dims.keys(): |
62 | | - mesh_vals.append(slice(slice_dims[i] - 1e-10, slice_dims[i] + 1e-10, 1)) |
| 63 | + mesh_vals.append(torch.tensor([slice_dims[i] - 1e-10, slice_dims[i] + 1e-10])) |
63 | 64 | else: |
64 | | - mesh_vals.append(slice(lower[i].item(), upper[i].item(), gridsize * 1j)) |
| 65 | + mesh_vals.append(torch.linspace(lower[i].item(), upper[i].item(), gridsize)) |
65 | 66 |
|
66 | | - return torch.Tensor(np.mgrid[mesh_vals].reshape(dim, -1).T) |
| 67 | + return torch.stack(torch.meshgrid(*mesh_vals, indexing='ij'), dim=-1).reshape(-1, dim) |
67 | 68 |
|
68 | 69 |
|
69 | 70 | def _process_bounds(lb, ub, dim) -> Tuple[torch.Tensor, torch.Tensor, int]: |
@@ -98,95 +99,119 @@ def _process_bounds(lb, ub, dim) -> Tuple[torch.Tensor, torch.Tensor, int]: |
98 | 99 | return lb, ub, dim |
99 | 100 |
|
100 | 101 |
|
def interpolate_monotonic(
    x: torch.Tensor,
    y: torch.Tensor,
    z: Union[torch.Tensor, float],
    min_x: Union[torch.Tensor, float] = -float("inf"),
    max_x: Union[torch.Tensor, float] = float("inf"),
) -> torch.Tensor:
    """Invert a monotonic curve by linear interpolation along the last axis.

    Ben Letham's 1d interpolation code, assuming monotonicity: find the two
    grid points whose ``y`` values bracket ``z`` and linearly interpolate
    ``x`` between them.

    Args:
        x: Grid locations. Either 1-D with the same length as the last axis
            of ``y``, or broadcastable to the shape of ``y``.
        y: Curve values, sorted ascending along the last axis. May carry
            leading batch dimensions.
        z: Target value(s). A scalar is applied to every row of ``y``; a
            tensor must be accepted by ``torch.searchsorted`` for this ``y``
            (any shape for 1-D ``y``, matching batch dimensions otherwise).
        min_x: Value returned where ``z`` is strictly below the first ``y``.
        max_x: Value returned where ``z`` is strictly above the last ``y``.

    Returns:
        The interpolated ``x`` value(s). For a scalar ``z`` the last axis is
        dropped (a 0-d tensor for 1-D ``y``); otherwise the result has the
        shape of ``z``.

    Raises:
        ValueError: If ``y`` has no entries along its last axis.
    """
    x = torch.as_tensor(x)
    y = torch.as_tensor(y)
    if not y.is_floating_point():
        y = y.to(torch.get_default_dtype())
    z_t = torch.as_tensor(z, device=y.device)
    scalar_z = z_t.dim() == 0

    n = y.shape[-1]
    if n == 0:
        raise ValueError("interpolate_monotonic requires a non-empty y")

    # searchsorted needs both operands in one dtype and contiguous memory.
    dtype = torch.promote_types(y.dtype, z_t.dtype)
    y = y.to(dtype).contiguous()
    z_t = z_t.to(dtype)
    if scalar_z:
        # One query per row of y; the helper axis is squeezed away below.
        z_t = z_t.expand(*y.shape[:-1], 1)
    z_t = z_t.contiguous()

    idx = torch.searchsorted(y, z_t, right=False)
    # Clamp both bracket ends into range; out-of-range queries collapse to a
    # zero-width bracket and are overwritten by min_x / max_x further down.
    lo = (idx - 1).clamp(0, n - 1)
    hi = idx.clamp(0, n - 1)

    batched = y.dim() > 1
    x_src = torch.broadcast_to(x, y.shape) if batched else x

    def _pick(values: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
        # Plain indexing handles arbitrary query shapes against a 1-D curve;
        # gather selects along the last axis independently for each row.
        return torch.gather(values, -1, index) if batched else values[index]

    x0, x1 = _pick(x_src, lo), _pick(x_src, hi)
    y0, y1 = _pick(y, lo), _pick(y, hi)

    dy = y1 - y0
    flat = dy == 0
    # Avoid 0/0 on flat segments; those positions fall back to x0.
    frac = (z_t - y0) / torch.where(flat, torch.ones_like(dy), dy)
    interp = x0 + (x1 - x0) * frac
    x_star = torch.where(flat, x0.to(interp.dtype), interp)

    below = z_t < y[..., :1]
    above = z_t > y[..., -1:]
    if scalar_z:
        x_star = x_star.squeeze(-1)
        below = below.squeeze(-1)
        above = above.squeeze(-1)

    lo_fill = torch.as_tensor(min_x, dtype=x_star.dtype, device=x_star.device)
    hi_fill = torch.as_tensor(max_x, dtype=x_star.dtype, device=x_star.device)
    x_star = torch.where(below, lo_fill, x_star)
    x_star = torch.where(above, hi_fill, x_star)
    return x_star
118 | 123 |
|
119 | 124 |
|
def get_lse_interval(
    model,
    mono_grid: Union[torch.Tensor, np.ndarray],
    target_level: float,
    cred_level: Optional[float] = None,
    mono_dim: int = -1,
    n_samps: int = 500,
    lb: float = -float("inf"),
    ub: float = float("inf"),
    gridsize: int = 30,
    **kwargs,
) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
    """Summarize sampled level-set contours of a model over a dense grid.

    A regular grid with ``gridsize`` points per dimension is built between
    ``model.lb`` and ``model.ub``. ``n_samps`` samples are drawn from the
    model on that grid, passed through the standard normal CDF, and each one
    is converted to a contour at ``target_level`` with ``get_lse_contour``.

    Args:
        model: Object exposing ``lb``, ``ub``, ``dim`` and
            ``sample(x, num_samples=..., **kwargs)``.
        mono_grid: Grid along the monotonic dimension, forwarded to
            ``get_lse_contour``.
        target_level: Level (on the CDF scale) at which contours are taken.
        cred_level: Width of the central interval to report. ``None``
            returns only the median contour.
        mono_dim: Index of the monotonic dimension.
        n_samps: Number of samples drawn from the model.
        lb: Contour value used where the target lies below the sampled curve.
        ub: Contour value used where the target lies above the sampled curve.
        gridsize: Number of grid points per dimension.
        **kwargs: Extra keyword arguments forwarded to ``model.sample``.

    Returns:
        The median contour when ``cred_level`` is ``None``; otherwise a tuple
        ``(median, lower, upper)`` of contours.
    """
    axes = [
        torch.linspace(model.lb[i].item(), model.ub[i].item(), gridsize)
        for i in range(model.dim)
    ]
    # Explicit "ij" indexing keeps the axis order of the grid and avoids the
    # deprecated implicit-indexing call form of torch.meshgrid.
    xgrid = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1).reshape(
        -1, model.dim
    )

    # Only the sample values are needed below, so drop any autograd graph.
    samps = model.sample(xgrid, num_samples=n_samps, **kwargs).detach()
    grid_shape = (gridsize,) * model.dim

    normal_dist = torch.distributions.Normal(0, 1)
    contours = torch.stack(
        [
            get_lse_contour(
                normal_dist.cdf(s.reshape(grid_shape)),
                mono_grid,
                target_level,
                mono_dim,
                lb,
                ub,
            )
            for s in samps
        ]
    )

    if cred_level is None:
        # NOTE: torch.median picks an actual sample (the lower middle one for
        # an even count), so it stays finite-safe when contours contain
        # +/-inf; the interpolating torch.quantile below does not.
        return torch.median(contours, dim=0).values

    alpha = 1 - cred_level
    qlower = alpha / 2
    qupper = 1 - alpha / 2

    lower = torch.quantile(contours, qlower, dim=0)
    upper = torch.quantile(contours, qupper, dim=0)
    median = torch.quantile(contours, 0.5, dim=0)

    return median, lower, upper
164 | 171 |
|
165 | 172 |
|
def get_lse_contour(
    post_mean: torch.Tensor,
    mono_grid: Union[torch.Tensor, np.ndarray],
    level: float,
    mono_dim: int = -1,
    lb: Union[torch.Tensor, float] = -float("inf"),
    ub: Union[torch.Tensor, float] = float("inf"),
) -> torch.Tensor:
    """Locate where ``post_mean`` crosses ``level`` along one dimension.

    Every 1-D slice of ``post_mean`` taken along ``mono_dim`` is inverted
    with ``interpolate_monotonic`` against ``mono_grid``.

    Args:
        post_mean: Values on a grid; tensors and arrays are both accepted.
        mono_grid: Grid coordinates along ``mono_dim``.
        level: Target value to locate in each slice.
        mono_dim: Dimension of ``post_mean`` to interpolate along.
        lb: Value used where ``level`` lies below a slice.
        ub: Value used where ``level`` lies above a slice.

    Returns:
        A float32 tensor shaped like ``post_mean`` with ``mono_dim`` removed;
        the remaining dimensions keep their original order.
    """
    # as_tensor avoids the copy-construct warning torch.tensor() raises for
    # tensor inputs; detach keeps the result free of any autograd graph.
    post_mean = torch.as_tensor(post_mean, dtype=torch.float32).detach()
    mono_grid = torch.as_tensor(mono_grid, dtype=torch.float32).detach()

    # movedim (unlike a swap) preserves the order of the other dimensions,
    # so the reduced result needs no transposing back.
    moved = post_mean.movedim(mono_dim, -1)
    out_shape = moved.shape[:-1]
    if post_mean.numel() == 0:
        return moved.new_empty(out_shape)

    rows = moved.reshape(-1, moved.shape[-1])
    # Row-wise evaluation only relies on the 1-D contract of
    # interpolate_monotonic.
    crossings = [
        torch.as_tensor(
            interpolate_monotonic(mono_grid, row, level, lb, ub),
            dtype=torch.float32,
        )
        for row in rows
    ]
    return torch.stack(crossings).reshape(out_shape)
| 189 | + |
| 190 | + |
def get_jnd_1d(
    post_mean: torch.Tensor,
    mono_grid: torch.Tensor,
    df: int = 1,
    mono_dim: int = -1,
    lb: Union[torch.Tensor, float] = -float("inf"),
    ub: Union[torch.Tensor, float] = float("inf"),
) -> torch.Tensor:
    """Distance along ``mono_grid`` needed to raise ``post_mean`` by ``df``.

    For every entry of ``post_mean`` the target ``post_mean + df`` is located
    on the curve with ``interpolate_monotonic``, and the grid position is
    subtracted from the location found.

    Args:
        post_mean: Curve values on ``mono_grid``.
        mono_grid: Grid coordinates matching ``post_mean``.
        df: Increment added to ``post_mean`` to form the targets.
        mono_dim: Accepted for signature parity; not used here.
        lb: Location used for targets below the curve.
        ub: Location used for targets above the curve.

    Returns:
        Located positions minus ``mono_grid``.
    """
    targets = post_mean + df
    located = interpolate_monotonic(mono_grid, post_mean, targets, lb, ub)
    return located - mono_grid
| 200 | + |
def get_jnd_multid(
    post_mean: torch.Tensor,
    mono_grid: torch.Tensor,
    df: int = 1,
    mono_dim: int = -1,
    lb: Union[torch.Tensor, float] = -float("inf"),
    ub: Union[torch.Tensor, float] = float("inf"),
) -> torch.Tensor:
    """Apply ``get_jnd_1d`` to every slice of ``post_mean`` along ``mono_dim``.

    Args:
        post_mean: Values on a grid with one or more dimensions.
        mono_grid: Grid coordinates along ``mono_dim``.
        df: Increment forwarded to ``get_jnd_1d``.
        mono_dim: Dimension of ``post_mean`` to operate along.
        lb: Lower fill value forwarded to ``get_jnd_1d``.
        ub: Upper fill value forwarded to ``get_jnd_1d``.

    Returns:
        A tensor with the same shape as ``post_mean`` holding the per-slice
        results of ``get_jnd_1d``.
    """
    post_mean = torch.as_tensor(post_mean)
    if post_mean.numel() == 0:
        return torch.empty_like(post_mean)

    # Bring the working dimension last and flatten the rest into rows, so
    # each call to get_jnd_1d sees a genuinely 1-D curve.
    moved = post_mean.movedim(mono_dim, -1)
    rows = moved.reshape(-1, moved.shape[-1])
    jnds = torch.stack(
        [
            get_jnd_1d(row, mono_grid, df=df, mono_dim=-1, lb=lb, ub=ub)
            for row in rows
        ]
    )
    # Restore the original layout.
    return jnds.reshape(moved.shape).movedim(-1, mono_dim)
190 | 215 |
|
191 | 216 |
|
192 | 217 | def _get_ax_parameters(config): |
|
0 commit comments