Skip to content

Commit 060cc77

Browse files
minor improvements to STIR Hessian code
- call accumulate_Hessian_times_input in the no-subset case, as this will now be a bit faster (it's also a tiny bit less code) - give correct name of parameters in cstir.h to avoid confusion - add test for objective function with prior - clarify RDP exclusion of test case
1 parent 63e9996 commit 060cc77

File tree

4 files changed

+26
-14
lines changed

4 files changed

+26
-14
lines changed

src/xSTIR/cSTIR/include/sirf/STIR/cstir.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,9 @@ extern "C" {
173173
void* cSTIR_priorValue(void* ptr_p, void* ptr_i);
174174
void* cSTIR_priorGradient(void* ptr_p, void* ptr_i);
175175
void* cSTIR_priorAccumulateHessianTimesInput
176-
(void* ptr_prior, void* ptr_out, void* ptr_curr, void* ptr_inp);
176+
(void* ptr_prior, void* ptr_curr, void* ptr_inp, void* ptr_out);
177177
void* cSTIR_priorComputeHessianTimesInput
178-
(void* ptr_prior, void* ptr_out, void* ptr_cur, void* ptr_inp);
178+
(void* ptr_prior, void* ptr_cur, void* ptr_inp, void* ptr_out);
179179
void* cSTIR_computePriorGradient(void* ptr_p, void* ptr_i, void* ptr_g);
180180
void* cSTIR_PLSPriorAnatomicalGradient(void* ptr_p, int dir);
181181

src/xSTIR/cSTIR/include/sirf/STIR/stir_x.h

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1188,11 +1188,8 @@ The actual algorithm is described in
11881188
output.fill(0.0);
11891189
if (subset >= 0)
11901190
accumulate_sub_Hessian_times_input(output, curr_image_est, input, subset);
1191-
else {
1192-
for (int s = 0; s < get_num_subsets(); s++) {
1193-
accumulate_sub_Hessian_times_input(output, curr_image_est, input, s);
1194-
}
1195-
}
1191+
else
1192+
accumulate_Hessian_times_input(output, curr_image_est, input);
11961193
}
11971194
};
11981195

src/xSTIR/pSTIR/tests/test_ObjectiveFunction.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ def setUp(self):
4242
am.set_up(templ,image)
4343
obj_fun = pet.make_Poisson_loglikelihood(acquired_data)
4444
obj_fun.set_acquisition_model(am)
45+
prior = pet.QuadraticPrior()
46+
prior.set_penalisation_factor(20)
47+
obj_fun.set_prior(prior)
4548
obj_fun.set_up(image)
4649

4750
self.obj_fun = obj_fun
@@ -62,14 +65,18 @@ def test_Hessian(self, subset=-1, eps=1e-3):
6265
"""
6366
x = self.image
6467
dx = x.clone()
65-
dx *= eps/dx.norm()
68+
dx *= eps
6669
dx += eps/2
6770
y = x + dx
6871
gx = self.obj_fun.gradient(x, subset)
6972
gy = self.obj_fun.gradient(y, subset)
7073
dg = gy - gx
7174
Hdx = self.obj_fun.multiply_with_Hessian(x, dx, subset)
7275
q = (dg - Hdx).norm()/dg.norm()
76+
print('norm of (x): %f' % x.norm())
77+
print('norm of (x + dx): %f' % y.norm())
78+
print('norm of grad(x): %f' % gx.norm())
79+
print('norm of grad(x + dx): %f' % gy.norm())
7380
print('norm of grad(x + dx) - grad(x): %f' % dg.norm())
7481
print('norm of H(x)*dx: %f' % Hdx.norm())
7582
print('relative difference: %f' % q)

src/xSTIR/pSTIR/tests/tests_qp_lc_rdp.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def Hessian_test(test, prior, x, eps=1e-3):
3333
"""
3434
if x.norm() > 0:
3535
dx = x.clone()
36-
dx *= eps/dx.norm()
36+
dx *= eps
3737
dx += eps/2
3838
else:
3939
dx = x + eps
@@ -47,11 +47,17 @@ def Hessian_test(test, prior, x, eps=1e-3):
4747
q = (dg - Hdx).norm()/gynorm
4848
else:
4949
q = (dg - Hdx).norm()
50-
#print('norm of grad(x + dx) - grad(x): %f' % dg.norm())
51-
#print('norm of H(x)*dx: %f' % Hdx.norm())
52-
#print('relative difference: %g' % q)
53-
if dg.norm() == 0:
54-
q = 0
50+
# print('norm of x: %f, dx: %f' % (x.norm(), dx.norm()))
51+
# print('norm of grad(x): %f, grad(x + dx): %f' % (gx.norm(), gy.norm()))
52+
# print('norm of grad(x + dx) - grad(x): %f' % dg.norm())
53+
# print('norm of H(x)*dx: %f' % Hdx.norm())
54+
# print('relative difference: %g' % q)
55+
if issubclass(type(prior), sirf.STIR.RelativeDifferencePrior):
56+
if x.min() == x.max() and dx.min() == dx.max():
57+
# skip test in this case, as grad(x+dx) = grad(x) = 0, but H dx is not (even analytically),
58+
# although it is small.
59+
# The difficult is knowing what the tolerance is for this test in that case
60+
q = 0
5561
test.check_if_less(q, .01*eps)
5662

5763

@@ -67,6 +73,7 @@ def test_main(rec=False, verb=False, throw=True, no_ret_val=True):
6773
im_2 = im_0.get_uniform_copy(2)
6874

6975
for im in [im_0, im_1, im_2]:
76+
# print('-------------- new image (see test source) -----------------')
7077
for penalisation_factor in [0,1,4]:
7178
for kappa in [True, False]:
7279
priors = [sirf.STIR.QuadraticPrior(), sirf.STIR.LogcoshPrior(), sirf.STIR.RelativeDifferencePrior()]
@@ -97,6 +104,7 @@ def test_main(rec=False, verb=False, throw=True, no_ret_val=True):
97104

98105
if isinstance(prior, sirf.STIR.RelativeDifferencePrior):
99106
prior.set_epsilon(im.max()*.01)
107+
# print(f"penalisation_factor {penalisation_factor}, kappa {kappa}, prior {type(prior).__name__}")
100108
Hessian_test(test, prior, im, 0.03)
101109

102110
numpy.testing.assert_equal(test.failed, 0)

0 commit comments

Comments
 (0)