Skip to content

Commit 41aad6f

Browse files
Extends SIRF data algebra possible operands to numpy arrays (#1358)
* added numpy array to possible SIRF image data algebra operands (on the "right") * added to User Guide a subsection on SIRF/numpy data algebra peculiarities * updated CHANGES.md * added tests
1 parent 8107117 commit 41aad6f

File tree

5 files changed

+108
-3
lines changed

5 files changed

+108
-3
lines changed

CHANGES.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# ChangeLog
22

3+
## v3.9.1
4+
5+
* Python interface
6+
- Restored functionality for algebraic operations mixing STIR.ImageData and numpy arrays. (Note that sirf objects need to be on the "left" of the operation.)
7+
38
## v3.9.0
49

510
* Python interface

doc/UserGuide.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,10 @@ The mutators are also responsible for basic error checking.
108108

109109
Some classes are _derived_ from other classes, which means that they have (_inherit_) all the methods of the classes they are derived from. If class B derives from class A, then A is called its _base_ class. <!-- we say that these _derived_ class methods are _inherited_ from the _base_ class.--> For example, class `AcquisitionModelUsingRayTracingMatrix` is derived from `AcquisitionModelUsingMatrix`, which in turn is derived from `AcquisitionModel`, and so it inherits all the methods of the latter two base classes.
110110

111+
### SIRF data algebra <a name="SIRF_data_algebra"></a>
112+
113+
SIRF Python interface supports algebraic operations (`+`, `-`, `*` and `/`): e.g. elements of the data array stored in the object `a*b` are the products of the respective elements in `a` and `b`. Either or both `a` and `b` can be SIRF data objects of the same kind (either both `ImageData` or both `AcquisitionData`) or `numpy` arrays or scalars. One should be aware though that if `a` is a SIRF object then, just as one would expect, the product `a*b` will be a SIRF object of the same kind, but if `a` is a `numpy` object (array or scalar) then Python will try to convert `b` to a `numpy` object before computing `a*b`, and only if this fails it will compute `b*a` instead. To avoid confusion, the users are advised to check the type of `a*b` or, better still, always place a SIRF object on the left side of the product.
114+
111115
### Error handling <a name="error_handling"></a>
112116

113117
Error handling is via exceptions, i.e. functions do not return an error status, but throw an error if something did not work. The user can catch these exceptions if required as illustrated in the demos.

src/Registration/pReg/Reg.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
import sirf.Reg_params as parms
3434
import numpy
35+
from numbers import Number
3536
from sirf.config import SIRF_HAS_SPM
3637

3738
if sys.version_info[0] >= 3 and sys.version_info[1] >= 4:
@@ -264,6 +265,64 @@ def get_voxel_sizes(self):
264265
self.handle, out.ctypes.data))
265266
return out
266267

268+
def dot(self, other):
269+
'''
270+
Returns the dot product of the container data with another container
271+
data or numpy array viewed as vectors.
272+
other: NiftiImageData or numpy array.
273+
'''
274+
if not (issubclass(type(other), type(self))):
275+
self_copy = self.clone()
276+
self_copy.fill(other)
277+
other = self_copy
278+
return super().dot(other)
279+
280+
def add(self, other, out=None):
281+
'''
282+
Addition for NiftiImageData containers.
283+
284+
If other is a NiftiData or numpy array, returns the sum of data
285+
stored in self and other viewed as vectors.
286+
If other is a scalar, returns the same with the second vector filled
287+
with the value of other.
288+
other: NiftiImageData or numpy array or scalar.
289+
out: NiftiImageData to store the result to.
290+
'''
291+
if not (issubclass(type(other), type(self)) or isinstance(other, (Number, numpy.number))):
292+
self_copy = self.clone()
293+
self_copy.fill(other)
294+
other = self_copy
295+
return super().add(other, out)
296+
297+
def subtract(self, other, out=None):
298+
'''
299+
Subtraction for NiftiImageData containers.
300+
301+
If other is a NiftiData or numpy array, returns the sum of data
302+
stored in self and other viewed as vectors.
303+
If other is a scalar, returns the same with the second vector filled
304+
with the value of other.
305+
other: NiftiImageData or numpy array or scalar.
306+
out: NiftiImageData to store the result to.
307+
'''
308+
if not (issubclass(type(other), type(self)) or isinstance(other, (Number, numpy.number))):
309+
self_copy = self.clone()
310+
self_copy.fill(other)
311+
other = self_copy
312+
return super().subtract(other, out)
313+
314+
def binary(self, other, f, out=None):
315+
'''Applies function f(x,y) element-wise to self and other.
316+
317+
other: NiftiImageData or numpy array or Number
318+
f: the name of the function to apply, Python str.
319+
'''
320+
if not (issubclass(type(other), type(self)) or isinstance(other, (Number, numpy.number))):
321+
self_copy = self.clone()
322+
self_copy.fill(other)
323+
other = self_copy
324+
return super().binary(other, f, out)
325+
267326
def fill(self, val):
268327
"""Fill image with single value or numpy array."""
269328
if self.handle is None:

src/common/SIRF.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,7 @@ def __sub__(self, other):
139139
data viewed as vectors.
140140
other: DataContainer
141141
'''
142-
if isinstance(other, (DataContainer, Number) ):
143-
return self.subtract(other)
144-
return NotImplemented
142+
return self.subtract(other)
145143

146144
def __mul__(self, other):
147145
'''
@@ -331,6 +329,8 @@ def dot(self, other):
331329
data viewed as vectors.
332330
other: DataContainer
333331
'''
332+
if not (issubclass(type(other), type(self))):
333+
other = self.clone().fill(other)
334334
assert_validities(self, other)
335335
# Check if input are the same size
336336
if self.size != other.size:
@@ -383,6 +383,8 @@ def add(self, other, out=None):
383383
other: DataContainer or scalar.
384384
out: DataContainer to store the result to.
385385
'''
386+
if not (issubclass(type(other), type(self)) or isinstance(other, (Number, numpy.number))):
387+
other = self.clone().fill(other)
386388
if out is None:
387389
z = self.same_object()
388390
else:
@@ -416,10 +418,14 @@ def subtract(self, other, out=None):
416418
other: DataContainer or scalar.
417419
other: DataContainer
418420
'''
421+
if not (issubclass(type(other), type(self)) or isinstance(other, (Number, numpy.number))):
422+
other = self.clone().fill(other)
419423
if not isinstance (other, (DataContainer, Number)):
420424
return NotImplemented
421425
if isinstance(other, Number):
422426
return self.add(-other, out=out)
427+
if not (issubclass(type(other), type(self)) or isinstance(other, (Number, numpy.number))):
428+
other = self.clone().fill(other)
423429
assert_validities(self, other)
424430
pl_one = numpy.asarray([1.0, 0.0], dtype = numpy.float32)
425431
mn_one = numpy.asarray([-1.0, 0.0], dtype = numpy.float32)
@@ -625,6 +631,8 @@ def binary(self, other, f, out=None):
625631
try_calling(pysirf.cSIRF_compute_semibinary(self.handle, y.ctypes.data, \
626632
f, out.handle))
627633
else:
634+
if not (issubclass(type(other), type(self)) or isinstance(other, (Number, numpy.number))):
635+
other = self.clone().fill(other)
628636
assert_validities(self, other)
629637
if out.handle is None:
630638
out.handle = pysirf.cSIRF_binary(self.handle, other.handle, f)

src/common/Utilities.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -742,6 +742,11 @@ def data_container_algebra_tests(test, x, eps=1e-4):
742742
# needs increased tolerance for large data size
743743
test.check_if_equal_within_tolerance(t, s, 0, eps * 10);
744744

745+
s = x.dot(ay)
746+
t = numpy.vdot(ay, ax)
747+
# needs increased tolerance for large data size
748+
test.check_if_equal_within_tolerance(t, s, 0, eps * 10);
749+
745750
x2 = x.multiply(2)
746751
ax2 = x2.as_array()
747752
s = numpy.linalg.norm(ax2 - 2*ax)
@@ -779,12 +784,30 @@ def data_container_algebra_tests(test, x, eps=1e-4):
779784
t = numpy.linalg.norm(az)
780785
test.check_if_zero_within_tolerance(s, eps * t)
781786

787+
z = x*ay
788+
az = z.as_array()
789+
s = numpy.linalg.norm(az - ax * ay)
790+
t = numpy.linalg.norm(az)
791+
test.check_if_zero_within_tolerance(s, eps * t)
792+
782793
y = x + 1
783794
ay = y.as_array()
784795
s = numpy.linalg.norm(ay - (ax + 1))
785796
t = numpy.linalg.norm(ay)
786797
test.check_if_zero_within_tolerance(s, eps * t)
787798

799+
y = x + ax
800+
ay = y.as_array()
801+
s = numpy.linalg.norm(ay - (ax + ax))
802+
t = numpy.linalg.norm(ay)
803+
test.check_if_zero_within_tolerance(s, eps * t)
804+
805+
t = numpy.linalg.norm(ay)
806+
y = x - ax
807+
ay = y.as_array()
808+
s = numpy.linalg.norm(ay)
809+
test.check_if_zero_within_tolerance(s, eps * t)
810+
788811
y *= 0
789812
x.add(1, out=y)
790813
ay = y.as_array()
@@ -798,6 +821,12 @@ def data_container_algebra_tests(test, x, eps=1e-4):
798821
t = numpy.linalg.norm(az)
799822
test.check_if_zero_within_tolerance(s, eps * t)
800823

824+
z = x/ay
825+
az = z.as_array()
826+
s = numpy.linalg.norm(az - ax/ay)
827+
t = numpy.linalg.norm(az)
828+
test.check_if_zero_within_tolerance(s, eps * t)
829+
801830
z = x/2
802831
az = z.as_array()
803832
s = numpy.linalg.norm(az - ax/2)

0 commit comments

Comments
 (0)