Skip to content

Commit a3d3a33

Browse files
yalsaffar authored and facebook-github-bot committed
Refactor to Use Vectorized Operations and Transition from NumPy to PyTorch Tensors (#406)
Summary: This PR partially solves #365, with changes as follows:

- **Grid Generation and Meshgrid Handling (`dim_grid` and `get_lse_interval`)**: Transitioned from `np.mgrid` to `torch.meshgrid` and `torch.linspace`, simplifying setup and ensuring full compatibility with PyTorch, reducing conversion steps.
- **Interpolation (`interpolate_monotonic`)**: Switched from `np.searchsorted` to `torch.searchsorted` and used `torch.where` for interpolation, enabling efficient, single-pass processing and maintaining overall consistency.
- **Probability and Quantile Calculations (`get_lse_interval`)**: Updated to use `torch.distributions.Normal`, `torch.median`, and `torch.quantile`.
- **Generalized Vectorization (`get_jnd_1d` and `get_jnd_multid`)**: Functions are now fully vectorized using PyTorch’s capabilities, avoiding element-wise iteration.

Pull Request resolved: #406

Reviewed By: crasanders

Differential Revision: D64563850

Pulled By: JasonKChow

fbshipit-source-id: 867b86b6822eb3380a2f1e0535849b3ea44a5a05
1 parent bbda7fe commit a3d3a33

4 files changed

Lines changed: 13700 additions & 126 deletions

File tree

aepsych/models/base.py

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -270,12 +270,10 @@ def get_jnd(
270270
return torch.tensor(1 / np.gradient(fmean, coords, axis=intensity_dim))
271271
elif method == "step":
272272
return torch.clip(
273-
torch.tensor(
274-
get_jnd_multid(
275-
fmean.detach().numpy(),
276-
coords.detach().numpy(),
277-
mono_dim=intensity_dim,
278-
)
273+
get_jnd_multid(
274+
fmean,
275+
coords,
276+
mono_dim=intensity_dim,
279277
),
280278
0,
281279
np.inf,

aepsych/plotting.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -179,8 +179,8 @@ def _plot_strat_1d(
179179

180180
threshold_samps = [
181181
interpolate_monotonic(
182-
grid.squeeze().numpy(), s, target_level, strat.lb[0], strat.ub[0]
183-
)
182+
grid, s, target_level, strat.lb[0], strat.ub[0]
183+
).cpu().numpy()
184184
for s in samps
185185
]
186186
thresh_med = np.mean(threshold_samps)
@@ -201,12 +201,12 @@ def _plot_strat_1d(
201201
ax.plot(grid, true_f.squeeze(), label="True function")
202202
if target_level is not None:
203203
true_thresh = interpolate_monotonic(
204-
grid.squeeze().numpy(),
204+
grid,
205205
true_f.squeeze(),
206206
target_level,
207207
strat.lb[0],
208208
strat.ub[0],
209-
)
209+
).cpu().numpy()
210210

211211
ax.plot(
212212
true_thresh,
@@ -305,18 +305,18 @@ def _plot_strat_2d(
305305
)
306306
ax.plot(
307307
context_grid,
308-
thresh_75,
308+
thresh_75.cpu().numpy(),
309309
label=f"Est. {target_level*100:.0f}% threshold \n(with {cred_level*100:.0f}% posterior \nmass shaded)",
310310
)
311311
ax.fill_between(
312-
context_grid, lower, upper, alpha=0.3, hatch="///", edgecolor="gray"
312+
context_grid, lower.cpu().numpy(), upper.cpu().numpy(), alpha=0.3, hatch="///", edgecolor="gray"
313313
)
314314

315315
if true_testfun is not None:
316316
true_f = true_testfun(grid).reshape(gridsize, gridsize)
317317
true_thresh = get_lse_contour(
318318
true_f, mono_grid, level=target_level, lb=strat.lb[-1], ub=strat.ub[-1]
319-
)
319+
).cpu().numpy()
320320
ax.plot(context_grid, true_thresh, label="Ground truth threshold")
321321

322322
ax.set_xlabel(xlabel)

aepsych/utils.py

Lines changed: 84 additions & 59 deletions
Original file line number | Diff line number | Diff line change
@@ -7,13 +7,14 @@
77

88
from collections.abc import Iterable
99
from configparser import NoOptionError
10-
from typing import Dict, List, Mapping, Optional, Tuple
10+
from typing import Dict, List, Mapping, Optional, Tuple, Union
1111

1212
import numpy as np
1313
import torch
1414
from scipy.stats import norm
1515
from torch.quasirandom import SobolEngine
1616

17+
from aepsych.config import Config
1718

1819
def make_scaled_sobol(lb, ub, size, seed=None):
1920
lb, ub, ndim = _process_bounds(lb, ub, None)
@@ -59,11 +60,11 @@ def dim_grid(
5960

6061
for i in range(dim):
6162
if i in slice_dims.keys():
62-
mesh_vals.append(slice(slice_dims[i] - 1e-10, slice_dims[i] + 1e-10, 1))
63+
mesh_vals.append(torch.tensor([slice_dims[i] - 1e-10, slice_dims[i] + 1e-10]))
6364
else:
64-
mesh_vals.append(slice(lower[i].item(), upper[i].item(), gridsize * 1j))
65+
mesh_vals.append(torch.linspace(lower[i].item(), upper[i].item(), gridsize))
6566

66-
return torch.Tensor(np.mgrid[mesh_vals].reshape(dim, -1).T)
67+
return torch.stack(torch.meshgrid(*mesh_vals, indexing='ij'), dim=-1).reshape(-1, dim)
6768

6869

6970
def _process_bounds(lb, ub, dim) -> Tuple[torch.Tensor, torch.Tensor, int]:
@@ -98,95 +99,119 @@ def _process_bounds(lb, ub, dim) -> Tuple[torch.Tensor, torch.Tensor, int]:
9899
return lb, ub, dim
99100

100101

101-
def interpolate_monotonic(x, y, z, min_x=-np.inf, max_x=np.inf):
102+
def interpolate_monotonic(x: torch.Tensor, y: torch.Tensor, z: Union[torch.Tensor, float], min_x: Union[torch.Tensor, float] =-float('inf'), max_x: Union[torch.Tensor, float] =float('inf')) -> torch.Tensor:
102103
# Ben Letham's 1d interpolation code, assuming monotonicity.
103104
# basic idea is find the nearest two points to the LSE and
104105
# linearly interpolate between them (I think this is bisection
105106
# root-finding)
106-
idx = np.searchsorted(y, z)
107-
if idx == len(y):
108-
return float(max_x)
109-
elif idx == 0:
110-
return float(min_x)
107+
idx = torch.searchsorted(y, z, right=False)
108+
109+
# Handle edge cases where idx is 0 or at the end
110+
idx = torch.clamp(idx, 1, len(y) - 1)
111+
111112
x0 = x[idx - 1]
112113
x1 = x[idx]
113114
y0 = y[idx - 1]
114115
y1 = y[idx]
115116

116117
x_star = x0 + (x1 - x0) * (z - y0) / (y1 - y0)
118+
# Apply min and max boundaries
119+
x_star = torch.where(z < y[0], min_x, x_star)
120+
x_star = torch.where(z > y[-1], max_x, x_star)
121+
117122
return x_star
118123

119124

120125
def get_lse_interval(
121126
model,
122-
mono_grid,
123-
target_level,
124-
cred_level=None,
125-
mono_dim=-1,
126-
n_samps=500,
127-
lb=-np.inf,
128-
ub=np.inf,
129-
gridsize=30,
127+
mono_grid: Union[torch.Tensor, np.ndarray],
128+
target_level: float,
129+
cred_level: Optional[float]=None,
130+
mono_dim: int =-1,
131+
n_samps: int =500,
132+
lb: float =-float('inf'),
133+
ub: float =float('inf'),
134+
gridsize: int =30,
130135
**kwargs,
131-
):
132-
xgrid = torch.Tensor(
133-
np.mgrid[
134-
[
135-
slice(model.lb[i].item(), model.ub[i].item(), gridsize * 1j)
136-
for i in range(model.dim)
137-
]
138-
]
139-
.reshape(model.dim, -1)
140-
.T
141-
)
136+
) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
137+
# Create a meshgrid using torch.linspace
138+
xgrid = torch.stack(
139+
torch.meshgrid(
140+
[torch.linspace(model.lb[i].item(), model.ub[i].item(), gridsize) for i in range(model.dim)]
141+
),
142+
dim=-1
143+
).reshape(-1, model.dim)
142144

143145
samps = model.sample(xgrid, num_samples=n_samps, **kwargs)
144-
samps = [s.reshape((gridsize,) * model.dim) for s in samps.detach().numpy()]
145-
contours = np.stack(
146+
samps = [s.reshape((gridsize,) * model.dim) for s in samps]
147+
148+
# Define the normal distribution for the CDF
149+
normal_dist = torch.distributions.Normal(0, 1)
150+
151+
# Calculate contours using torch.stack and the torch CDF for each sample
152+
contours = torch.stack(
146153
[
147-
get_lse_contour(norm.cdf(s), mono_grid, target_level, mono_dim, lb, ub)
154+
get_lse_contour(normal_dist.cdf(s), mono_grid, target_level, mono_dim, lb, ub)
148155
for s in samps
149156
]
150157
)
151158

152159
if cred_level is None:
153-
return np.mean(contours, 0.5, axis=0)
160+
return torch.median(contours, dim=0).values
154161
else:
155162
alpha = 1 - cred_level
156163
qlower = alpha / 2
157164
qupper = 1 - alpha / 2
158165

159-
upper = np.quantile(contours, qupper, axis=0)
160-
lower = np.quantile(contours, qlower, axis=0)
161-
median = np.quantile(contours, 0.5, axis=0)
166+
lower = torch.quantile(contours, qlower, dim=0)
167+
upper = torch.quantile(contours, qupper, dim=0)
168+
median = torch.quantile(contours, 0.5, dim=0)
162169

163170
return median, lower, upper
164171

165172

166-
def get_lse_contour(post_mean, mono_grid, level, mono_dim=-1, lb=-np.inf, ub=np.inf):
167-
return np.apply_along_axis(
168-
lambda p: interpolate_monotonic(mono_grid, p, level, lb, ub),
169-
mono_dim,
170-
post_mean,
171-
)
172-
173-
174-
def get_jnd_1d(post_mean, mono_grid, df=1, mono_dim=-1, lb=-np.inf, ub=np.inf):
173+
def get_lse_contour(post_mean: torch.Tensor, mono_grid: Union[torch.Tensor, np.ndarray], level: float, mono_dim: int =-1, lb: Union[torch.Tensor, float] =-float('inf'), ub: Union[torch.Tensor, float] =float('inf')) -> torch.Tensor:
174+
post_mean = torch.tensor(post_mean, dtype=torch.float32)
175+
mono_grid = torch.tensor(mono_grid, dtype=torch.float32)
176+
177+
# Move mono_dim to the last dimension if it isn't already
178+
if mono_dim != -1:
179+
post_mean = post_mean.transpose(mono_dim, -1)
180+
181+
# Apply interpolation across all rows at once
182+
result = interpolate_monotonic(mono_grid, post_mean, level, lb, ub)
183+
184+
# Transpose back if necessary
185+
if mono_dim != -1:
186+
result = result.transpose(-1, mono_dim)
187+
188+
return result
189+
190+
191+
def get_jnd_1d(post_mean: torch.Tensor, mono_grid: torch.Tensor, df: int =1, mono_dim: int =-1, lb: Union[torch.Tensor, float] =-float('inf'), ub: Union[torch.Tensor, float] =float('inf')) -> torch.Tensor:
192+
193+
# Calculate interpolate_to in a vectorized way
175194
interpolate_to = post_mean + df
176-
return (
177-
np.array(
178-
[interpolate_monotonic(mono_grid, post_mean, ito) for ito in interpolate_to]
179-
)
180-
- mono_grid
181-
)
182-
183-
184-
def get_jnd_multid(post_mean, mono_grid, df=1, mono_dim=-1, lb=-np.inf, ub=np.inf):
185-
return np.apply_along_axis(
186-
lambda p: get_jnd_1d(p, mono_grid, df=df, mono_dim=mono_dim, lb=lb, ub=ub),
187-
mono_dim,
188-
post_mean,
189-
)
195+
196+
# Apply interpolation to the entire tensor
197+
interpolated_values = interpolate_monotonic(mono_grid, post_mean, interpolate_to, lb, ub)
198+
199+
return interpolated_values - mono_grid
200+
201+
def get_jnd_multid(post_mean: torch.Tensor, mono_grid: torch.Tensor, df: int =1, mono_dim: int =-1, lb: Union[torch.Tensor, float] =-float('inf'), ub: Union[torch.Tensor, float] =float('inf')) -> torch.Tensor:
202+
203+
# Move mono_dim to the last dimension if it isn't already
204+
if mono_dim != -1:
205+
post_mean = post_mean.transpose(mono_dim, -1)
206+
207+
# Apply get_jnd_1d in a vectorized way
208+
result = get_jnd_1d(post_mean, mono_grid, df=df, mono_dim=-1, lb=lb, ub=ub)
209+
210+
# Transpose back if necessary
211+
if mono_dim != -1:
212+
result = result.transpose(-1, mono_dim)
213+
214+
return result
190215

191216

192217
def _get_ax_parameters(config):

pubs/owenetal/code/test_functions.ipynb

Lines changed: 13605 additions & 54 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)