@@ -889,16 +889,16 @@ def det_point_position(self, index, dparam):
889889 >>> geom_2d = odl.tomo.ParallelVecGeometry(det_shape_2d, vecs_2d)
890890 >>> # This is equal to d(0) = (0, 1)
891891 >>> geom_2d.det_point_position(0, 0)
892- array([ 0., 1.])
892+ array([ 0., 1.])
893893 >>> # d(0) + 2 * u(0) = (0, 1) + 2 * (1, 0)
894894 >>> geom_2d.det_point_position(0, 2)
895- array([ 2., 1.])
895+ array([ 2., 1.])
896896 >>> # d(1) + 2 * u(1) = (-1, 0) + 2 * (0, 1)
897897 >>> geom_2d.det_point_position(1, 2)
898- array([-1., 2.])
898+ array([-1., 2.])
899899 >>> # d(0.4) + 2 * u(0.4) = d(0) + 2 * u(0)
900900 >>> geom_2d.det_point_position(0.4, 2)
901- array([ 2., 1.])
901+ array([ 2., 1.])
902902
903903 Broadcasting of arguments:
904904
@@ -935,7 +935,7 @@ def det_point_position(self, index, dparam):
935935 Broadcasting of arguments:
936936
937937 >>> idcs = np.array([0.4, 0.6])[:, None]
938- >>> dpar = np.array([2.0, 1.0])[None, : ]
938+ >>> dpar = np.array([2.0, 1.0])[:, None ]
939939 >>> geom_3d.det_point_position(idcs, dpar)
940940
941941 """
@@ -952,10 +952,10 @@ def det_point_position(self, index, dparam):
952952 '' .format (dparam , self .det_params )
953953 )
954954
955- # TODO: broadcast correctly
956955 if self .ndim == 2 :
957956 det_shift = dparam * self .det_axis (index )
958957 elif self .ndim == 3 :
958+ axes = self .det_axes (index )
959959 det_shift = sum (
960960 di * ax for di , ax in zip (dparam , self .det_axes (index ))
961961 )
@@ -1038,9 +1038,15 @@ def det_axes(self, index):
10381038
10391039 Returns
10401040 -------
1041- axes : tuple of `numpy.ndarray`, shape ``(ndim,)``
1042- The detector axes at ``index``, 1 for ``ndim == 2`` and
1043- 2 for ``ndim == 3``.
1041+ axes : `numpy.ndarray`
1042+ Unit vectors along which the detector is aligned, an array
1043+ with following shape:
1044+
1045+ - In 2D: If ``index`` is a single parameter, the shape is
1046+ ``(2,)``, otherwise ``index.shape + (2,)``.
1047+
1048+ - In 3D: If ``index`` is a single parameter, the shape is
1049+ ``(2, 3)``, otherwise ``index.shape + (2, 3)``.
10441050
10451051 Examples
10461052 --------
@@ -1056,16 +1062,17 @@ def det_axes(self, index):
10561062 >>> det_shape_3d = (10, 20)
10571063 >>> geom_3d = odl.tomo.ParallelVecGeometry(det_shape_3d, vecs_3d)
10581064 >>> geom_3d.det_axes(0)
1059- (array([ 1., 0., 0.]), array([ 0., 0., 1.]))
1065+ array([[ 1., 0., 0.],
1066+ [ 0., 0., 1.]])
10601067 >>> geom_3d.det_axes(1)
1061- ( array([ 0., 1., 0.]), array([ 0., 0., 1.]))
1062- >>> axs = geom_3d.det_axes([0.4 , 0.6]) # values at closest indices
1063- >>> axs[0] # first axis
1064- array([[ 1., 0., 0.],
1065- [ 0., 1 ., 0 .]])
1066- >>> axs[1] # second axis
1067- array( [[ 0., 0 ., 1 .],
1068- [ 0., 0., 1.]])
1068+ array([[ 0., 1., 0.],
1069+ [ 0. , 0., 1.]])
1070+ >>> geom_3d.det_axes([0.4, 0.6]) # values at closest indices
1071+ array([[[ 1., 0., 0.],
1072+ [ 0., 0 ., 1 .]],
1073+ <BLANKLINE>
1074+ [[ 0., 1 ., 0 .],
1075+ [ 0., 0., 1.] ]])
10691076 """
10701077 if (
10711078 self .check_bounds
@@ -1085,19 +1092,18 @@ def det_axes(self, index):
10851092
10861093 vectors = self .vectors [index_int ]
10871094 if self .ndim == 2 :
1088- det_us = vectors [:, self . _slice_det_u ]
1089- retval_lst = [ det_us [ 0 ]] if squeeze_index else [ det_us ]
1095+ axes = np . empty ( index_int . shape + ( 2 ,))
1096+ axes [:] = vectors [:, self . _slice_det_u ]
10901097 elif self .ndim == 3 :
1091- det_us = vectors [:, self ._slice_det_u ]
1092- det_vs = vectors [:, self ._slice_det_v ]
1093- if squeeze_index :
1094- retval_lst = [det_us [0 ], det_vs [0 ]]
1095- else :
1096- retval_lst = [det_us , det_vs ]
1098+ axes = np .empty (index_int .shape + (2 , 3 ))
1099+ axes [..., 0 , :] = vectors [..., self ._slice_det_u ]
1100+ axes [..., 1 , :] = vectors [..., self ._slice_det_v ]
10971101 else :
10981102 raise RuntimeError ('invalid `ndim`' )
10991103
1100- return tuple (retval_lst )
1104+ if squeeze_index :
1105+ axes = axes [0 ]
1106+ return axes
11011107
11021108 def __getitem__ (self , indices ):
11031109 """Return ``self[indices]``.
0 commit comments