Skip to content

Commit 0e02abb

Browse files
committed
fully revert tensor changes in dim_grid (#460)
Summary: Pull Request resolved: #460 dim_grid tensor changes were only half reverted for some reason. This fully reverts it and fixes a bug where dim_grid was producing double the number of points when slicing. Reviewed By: crasanders Differential Revision: D66374663 fbshipit-source-id: c2226b59c6320bedf74c3374f6ae1347f47e4e26
1 parent ba83b92 commit 0e02abb

2 files changed

Lines changed: 11 additions & 8 deletions

File tree

aepsych/utils.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,11 @@ def dim_grid(
6363

6464
for i in range(dim):
6565
if i in slice_dims.keys():
66-
mesh_vals.append(
67-
torch.tensor([slice_dims[i] - 1e-10, slice_dims[i] + 1e-10])
68-
)
66+
mesh_vals.append(slice(slice_dims[i] - 1e-10, slice_dims[i] + 1e-10, 1))
6967
else:
70-
mesh_vals.append(torch.linspace(lower[i].item(), upper[i].item(), gridsize))
68+
mesh_vals.append(slice(lower[i].item(), upper[i].item(), gridsize * 1j))
7169

72-
return torch.stack(torch.meshgrid(*mesh_vals, indexing="ij"), dim=-1).reshape(
73-
-1, dim
74-
)
70+
return torch.Tensor(np.mgrid[mesh_vals].reshape(dim, -1).T)
7571

7672

7773
def _process_bounds(

tests/test_utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import numpy as np
1111
import torch
1212
from aepsych.models import GPClassificationModel
13-
from aepsych.utils import _process_bounds, make_scaled_sobol
13+
from aepsych.utils import _process_bounds, dim_grid, make_scaled_sobol
1414

1515

1616
class UtilsTestCase(unittest.TestCase):
@@ -35,6 +35,13 @@ def test_dim_grid_model_size(self):
3535
grid = GPClassificationModel.dim_grid(mb, gridsize=gridsize)
3636
self.assertEqual(grid.shape, torch.Size([10, 1]))
3737

38+
def test_dim_grid_slice(self):
39+
lb = torch.tensor([0, 0, 0])
40+
ub = torch.tensor([1, 1, 1])
41+
grid = dim_grid(lb, ub, slice_dims={1: 0.5})
42+
43+
self.assertTrue(np.all(grid.shape == (900, 3)))
44+
3845
def test_process_bounds(self):
3946
lb, ub, dim = _process_bounds(np.r_[0, 1], np.r_[2, 3], None)
4047
self.assertTrue(torch.all(lb == torch.tensor([0.0, 1.0])))

0 commit comments

Comments
 (0)