Skip to content

Commit 252b7d0

Browse files
committed
WIP: Fix broadcasting in vec geometry methods
1 parent 6b8eb8b commit 252b7d0

File tree

1 file changed

+33
-27
lines changed

1 file changed

+33
-27
lines changed

odl/tomo/geometry/geometry.py

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -889,16 +889,16 @@ def det_point_position(self, index, dparam):
889889
>>> geom_2d = odl.tomo.ParallelVecGeometry(det_shape_2d, vecs_2d)
890890
>>> # This is equal to d(0) = (0, 1)
891891
>>> geom_2d.det_point_position(0, 0)
892-
array([ 0., 1.])
892+
array([ 0., 1.])
893893
>>> # d(0) + 2 * u(0) = (0, 1) + 2 * (1, 0)
894894
>>> geom_2d.det_point_position(0, 2)
895-
array([ 2., 1.])
895+
array([ 2., 1.])
896896
>>> # d(1) + 2 * u(1) = (-1, 0) + 2 * (0, 1)
897897
>>> geom_2d.det_point_position(1, 2)
898-
array([-1., 2.])
898+
array([-1., 2.])
899899
>>> # d(0.4) + 2 * u(0.4) = d(0) + 2 * u(0)
900900
>>> geom_2d.det_point_position(0.4, 2)
901-
array([ 2., 1.])
901+
array([ 2., 1.])
902902
903903
Broadcasting of arguments:
904904
@@ -935,7 +935,7 @@ def det_point_position(self, index, dparam):
935935
Broadcasting of arguments:
936936
937937
>>> idcs = np.array([0.4, 0.6])[:, None]
938-
>>> dpar = np.array([2.0, 1.0])[None, :]
938+
>>> dpar = np.array([2.0, 1.0])[:, None]
939939
>>> geom_3d.det_point_position(idcs, dpar)
940940
941941
"""
@@ -952,10 +952,10 @@ def det_point_position(self, index, dparam):
952952
''.format(dparam, self.det_params)
953953
)
954954

955-
# TODO: broadcast correctly
956955
if self.ndim == 2:
957956
det_shift = dparam * self.det_axis(index)
958957
elif self.ndim == 3:
958+
axes = self.det_axes(index)
959959
det_shift = sum(
960960
di * ax for di, ax in zip(dparam, self.det_axes(index))
961961
)
@@ -1038,9 +1038,15 @@ def det_axes(self, index):
10381038
10391039
Returns
10401040
-------
1041-
axes : tuple of `numpy.ndarray`, shape ``(ndim,)``
1042-
The detector axes at ``index``, 1 for ``ndim == 2`` and
1043-
2 for ``ndim == 3``.
1041+
axes : `numpy.ndarray`
1042+
Unit vectors along which the detector is aligned, an array
1043+
with following shape:
1044+
1045+
- In 2D: If ``index`` is a single parameter, the shape is
1046+
``(2,)``, otherwise ``index.shape + (2,)``.
1047+
1048+
- In 3D: If ``index`` is a single parameter, the shape is
1049+
``(2, 3)``, otherwise ``index.shape + (2, 3)``.
10441050
10451051
Examples
10461052
--------
@@ -1056,16 +1062,17 @@ def det_axes(self, index):
10561062
>>> det_shape_3d = (10, 20)
10571063
>>> geom_3d = odl.tomo.ParallelVecGeometry(det_shape_3d, vecs_3d)
10581064
>>> geom_3d.det_axes(0)
1059-
(array([ 1., 0., 0.]), array([ 0., 0., 1.]))
1065+
array([[ 1., 0., 0.],
1066+
[ 0., 0., 1.]])
10601067
>>> geom_3d.det_axes(1)
1061-
(array([ 0., 1., 0.]), array([ 0., 0., 1.]))
1062-
>>> axs = geom_3d.det_axes([0.4, 0.6]) # values at closest indices
1063-
>>> axs[0] # first axis
1064-
array([[ 1., 0., 0.],
1065-
[ 0., 1., 0.]])
1066-
>>> axs[1] # second axis
1067-
array([[ 0., 0., 1.],
1068-
[ 0., 0., 1.]])
1068+
array([[ 0., 1., 0.],
1069+
[ 0., 0., 1.]])
1070+
>>> geom_3d.det_axes([0.4, 0.6]) # values at closest indices
1071+
array([[[ 1., 0., 0.],
1072+
[ 0., 0., 1.]],
1073+
<BLANKLINE>
1074+
[[ 0., 1., 0.],
1075+
[ 0., 0., 1.]]])
10691076
"""
10701077
if (
10711078
self.check_bounds
@@ -1085,19 +1092,18 @@ def det_axes(self, index):
10851092

10861093
vectors = self.vectors[index_int]
10871094
if self.ndim == 2:
1088-
det_us = vectors[:, self._slice_det_u]
1089-
retval_lst = [det_us[0]] if squeeze_index else [det_us]
1095+
axes = np.empty(index_int.shape + (2,))
1096+
axes[:] = vectors[:, self._slice_det_u]
10901097
elif self.ndim == 3:
1091-
det_us = vectors[:, self._slice_det_u]
1092-
det_vs = vectors[:, self._slice_det_v]
1093-
if squeeze_index:
1094-
retval_lst = [det_us[0], det_vs[0]]
1095-
else:
1096-
retval_lst = [det_us, det_vs]
1098+
axes = np.empty(index_int.shape + (2, 3))
1099+
axes[..., 0, :] = vectors[..., self._slice_det_u]
1100+
axes[..., 1, :] = vectors[..., self._slice_det_v]
10971101
else:
10981102
raise RuntimeError('invalid `ndim`')
10991103

1100-
return tuple(retval_lst)
1104+
if squeeze_index:
1105+
axes = axes[0]
1106+
return axes
11011107

11021108
def __getitem__(self, indices):
11031109
"""Return ``self[indices]``.

0 commit comments

Comments
 (0)