Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Computing Hessian products with image data objects and related stuff #1253

Merged
merged 21 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
6816fc0
implemented prior.accumulate_Hessian_times_input()
evgueni-ovtchinnikov May 2, 2024
956c498
implemented return of Hessian*x via out= in Prior
evgueni-ovtchinnikov May 9, 2024
2390735
merged master
evgueni-ovtchinnikov May 9, 2024
99ab3c6
implemented ObjectiveFunction.accumulate_Hessian_times_input (not tes…
evgueni-ovtchinnikov May 9, 2024
c14da39
implemented set_time_interval for listmode objective function (not te…
evgueni-ovtchinnikov May 10, 2024
f5ff1ba
used auto and corrected inheritance of listmode objective function in…
evgueni-ovtchinnikov May 10, 2024
66c35c7
fixed typo in listmode objective function class in STIR.py
evgueni-ovtchinnikov May 11, 2024
14bf4f6
use references in Hessian
KrisThielemans May 12, 2024
dbf292a
added set_subsensitivity_filenames method to listmode objective function
evgueni-ovtchinnikov May 13, 2024
8c9b1fe
Merge branch 'acc-hess' of https://github.com/SyneRBI/SIRF into acc-hess
evgueni-ovtchinnikov May 13, 2024
0bb6619
implemented multiply_with_Hessian, checked by simple regression test
evgueni-ovtchinnikov May 15, 2024
136add7
attended to Codacy issue
evgueni-ovtchinnikov May 15, 2024
7f24a44
added objective function method that checks that grad(x+dx)-grad(x)~H…
evgueni-ovtchinnikov May 15, 2024
4807871
attended to Codacy issues
evgueni-ovtchinnikov May 15, 2024
3b2dffe
resolved reviewer's issues
evgueni-ovtchinnikov May 16, 2024
608631c
updated Hessian multiplication for priors
evgueni-ovtchinnikov May 16, 2024
16547a5
corrected Prior.multiply_with_Hessian()
evgueni-ovtchinnikov May 16, 2024
32167f6
add test on Hessian to CI
KrisThielemans May 16, 2024
2dbcacc
add test for Hessian of prior
KrisThielemans May 16, 2024
091de32
attended to Codacy issues
evgueni-ovtchinnikov May 16, 2024
0531235
updated CHANGES.md
evgueni-ovtchinnikov May 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
- implemented basic-functionality listmode data class in C++ and Python
- added objective function type for lismode reconstruction
- added new demo script for the reconstruction from listmode data
- provided gradient-computing methods with return via optional argument out
in addition to the standard return, ensuring that no temorary copies of the
gradient data are created
- provided prior and objective function objects with methods for computing
the product of the Hessian and a vector

## v3.6.0

Expand Down
101 changes: 97 additions & 4 deletions src/xSTIR/cSTIR/cstir.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
SyneRBI Synergistic Image Reconstruction Framework (SIRF)
Copyright 2017 - 2023 Rutherford Appleton Laboratory STFC
Copyright 2017 - 2024 Rutherford Appleton Laboratory STFC
Copyright 2018 - 2024 University College London.

This is software developed for the Collaborative Computational
Expand Down Expand Up @@ -417,12 +417,25 @@ void* cSTIR_objectFromFile(const char* name, const char* filename)
CATCH;
}

typedef xSTIR_PoissonLLhLinModMeanListDataProjMatBin3DF LMObjFun;

extern "C"
void* cSTIR_objFunListModeSetInterval(void* ptr_f, size_t ptr_data)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really for a question for this PR, but why size_t ptr_data and not void * ptr_data? On some systems, they are not the same size, and this could therefore create trouble.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I vaguely remember having some SWIG trouble with void* but no longer remember what it was. Perhaps with the latest SWIG void* is now ok.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW SWIG actually still does not like passing the numpy array data pointer as void* - tried with cSTIR_getImageData, got this:

Traceback (most recent call last):
  File "/home/sirfuser/devel/buildVM/sources/SIRF/examples/Python/PET/acquisition_data.py", line 172, in <module>
    main()
  File "/home/sirfuser/devel/buildVM/sources/SIRF/examples/Python/PET/acquisition_data.py", line 141, in main
    image_array = image.as_array()
  File "/home/sirfuser/devel/install/python/sirf/STIR.py", line 611, in as_array
    try_calling(pystir.cSTIR_getImageData(self.handle, array.ctypes.data))
  File "/home/sirfuser/devel/install/python/sirf/pystir.py", line 306, in cSTIR_getImageData
    return _pystir.cSTIR_getImageData(ptr, ptr_data)
TypeError: in method 'cSTIR_getImageData', argument 2 of type 'void *'
sirfuser@vagrant:~/devel/buildVM/sources/SIRF/examples/Python/PET$ 

same story with float*:

Traceback (most recent call last):
  File "/home/sirfuser/devel/buildVM/sources/SIRF/examples/Python/PET/acquisition_data.py", line 172, in <module>
    main()
  File "/home/sirfuser/devel/buildVM/sources/SIRF/examples/Python/PET/acquisition_data.py", line 141, in main
    image_array = image.as_array()
  File "/home/sirfuser/devel/install/python/sirf/STIR.py", line 611, in as_array
    try_calling(pystir.cSTIR_getImageData(self.handle, array.ctypes.data))
  File "/home/sirfuser/devel/install/python/sirf/pystir.py", line 306, in cSTIR_getImageData
    return _pystir.cSTIR_getImageData(ptr, ptr_data)
TypeError: in method 'cSTIR_getImageData', argument 2 of type 'float *'

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmmm. weird. Here's an example that does that https://stackoverflow.com/a/37308401. Anyway, let's leave that for later!

{
try {
auto& objFun = objectFromHandle<LMObjFun>(ptr_f);
float* data = (float*)ptr_data;
objFun.set_time_interval((double)data[0], (double)data[1]);
return (void*)new DataHandle;
}
CATCH;
}

extern "C"
void* cSTIR_setListmodeToSinogramsInterval(void* ptr_lm2s, size_t ptr_data)
{
try {
ListmodeToSinograms& lm2s =
objectFromHandle<ListmodeToSinograms>(ptr_lm2s);
auto& lm2s = objectFromHandle<ListmodeToSinograms>(ptr_lm2s);
float *data = (float *)ptr_data;
lm2s.set_time_interval((double)data[0], (double)data[1]);
return (void*)new DataHandle;
Expand Down Expand Up @@ -1249,6 +1262,50 @@ cSTIR_computeObjectiveFunctionGradientNotDivided(void* ptr_f, void* ptr_i, int s
CATCH;
}

extern "C"
void*
cSTIR_objectiveFunctionAccumulateHessianTimesInput
(void* ptr_fun, void* ptr_est, void* ptr_inp, int subset, void* ptr_out)
{
try {
auto& fun = objectFromHandle<ObjectiveFunction3DF>(ptr_fun);
auto& est = objectFromHandle<STIRImageData>(ptr_est);
auto& inp = objectFromHandle<STIRImageData>(ptr_inp);
auto& out = objectFromHandle<STIRImageData>(ptr_out);
auto& curr_est = est.data();
auto& input = inp.data();
auto& output = out.data();
if (subset >= 0)
fun.accumulate_sub_Hessian_times_input(output, curr_est, input, subset);
else {
for (int s = 0; s < fun.get_num_subsets(); s++) {
fun.accumulate_sub_Hessian_times_input(output, curr_est, input, s);
}
}
return (void*) new DataHandle;
}
CATCH;
}

extern "C"
void*
cSTIR_objectiveFunctionComputeHessianTimesInput
(void* ptr_fun, void* ptr_est, void* ptr_inp, int subset, void* ptr_out)
{
try {
auto& fun = objectFromHandle<xSTIR_GeneralisedObjectiveFunction3DF>(ptr_fun);
auto& est = objectFromHandle<STIRImageData>(ptr_est);
auto& inp = objectFromHandle<STIRImageData>(ptr_inp);
auto& out = objectFromHandle<STIRImageData>(ptr_out);
auto& curr_est = est.data();
auto& input = inp.data();
auto& output = out.data();
fun.multiply_with_Hessian(output, curr_est, input, subset);
return (void*) new DataHandle;
}
CATCH;
}

extern "C"
void*
cSTIR_setupPrior(void* ptr_p, void* ptr_i)
Expand Down Expand Up @@ -1288,7 +1345,7 @@ void*
cSTIR_priorGradient(void* ptr_p, void* ptr_i)
{
try {
Prior3DF& prior = objectFromHandle<Prior3DF>(ptr_p);
Prior3DF& prior = objectFromHandle<stir::GeneralisedPrior <Image3DF> >(ptr_p);
STIRImageData& id = objectFromHandle<STIRImageData>(ptr_i);
Image3DF& image = id.data();
shared_ptr<STIRImageData> sptr(new STIRImageData(image));
Expand All @@ -1299,6 +1356,42 @@ cSTIR_priorGradient(void* ptr_p, void* ptr_i)
CATCH;
}

extern "C"
void*
cSTIR_priorAccumulateHessianTimesInput(void* ptr_prior, void* ptr_out, void* ptr_cur, void* ptr_inp)
{
try {
auto& prior = objectFromHandle<stir::GeneralisedPrior <Image3DF> >(ptr_prior);
auto& out = objectFromHandle<STIRImageData>(ptr_out);
auto& cur = objectFromHandle<STIRImageData>(ptr_cur);
auto& inp = objectFromHandle<STIRImageData>(ptr_inp);
auto& output = out.data();
auto& current = cur.data();
auto& input = inp.data();
prior.accumulate_Hessian_times_input(output, current, input);
return (void*) new DataHandle;
}
CATCH;
}

extern "C"
void*
cSTIR_priorComputeHessianTimesInput(void* ptr_prior, void* ptr_out, void* ptr_cur, void* ptr_inp)
{
try {
auto& prior = objectFromHandle<xSTIR_GeneralisedPrior3DF>(ptr_prior);
auto& out = objectFromHandle<STIRImageData>(ptr_out);
auto& cur = objectFromHandle<STIRImageData>(ptr_cur);
auto& inp = objectFromHandle<STIRImageData>(ptr_inp);
auto& output = out.data();
auto& current = cur.data();
auto& input = inp.data();
prior.multiply_with_Hessian(output, current, input);
return (void*) new DataHandle;
}
CATCH;
}

extern "C"
void*
cSTIR_computePriorGradient(void* ptr_p, void* ptr_i, void* ptr_g)
Expand Down
5 changes: 5 additions & 0 deletions src/xSTIR/cSTIR/cstir_p.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,11 @@ sirf::cSTIR_setPoissonLogLikelihoodWithLinearModelForMeanAndListModeDataWithProj
else if (sirf::iequals(name, "cache_max_size")) {
obj_fun.set_cache_max_size(dataFromHandle<int>(hv));
}
else if (sirf::iequals(name, "subsensitivity_filenames"))
{
std::string s(charDataFromDataHandle(hv));
obj_fun.set_subsensitivity_filenames(s.c_str());
}
else
return parameterNotFound(name, __FILE__, __LINE__);
return new DataHandle;
Expand Down
9 changes: 9 additions & 0 deletions src/xSTIR/cSTIR/include/sirf/STIR/cstir.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ extern "C" {
void* cSTIR_convertListmodeToSinograms(void* ptr);
void* cSTIR_computeRandoms(void* ptr);
void* cSTIR_lm_num_prompts_exceeds_threshold(void* ptr, const float threshold);
void* cSTIR_objFunListModeSetInterval(void* ptr_f, size_t ptr_data);

// Data processor methods
void* cSTIR_setupImageDataProcessor(const void* ptr_p, void* ptr_i);
Expand Down Expand Up @@ -153,11 +154,19 @@ extern "C" {
(void* ptr_f, void* ptr_i, int subset);
void* cSTIR_computeObjectiveFunctionGradientNotDivided
(void* ptr_f, void* ptr_i, int subset, void* ptr_g);
void* cSTIR_objectiveFunctionAccumulateHessianTimesInput
(void* ptr_fun, void* ptr_est, void* ptr_inp, int subset, void* ptr_out);
void* cSTIR_objectiveFunctionComputeHessianTimesInput
(void* ptr_fun, void* ptr_est, void* ptr_inp, int subset, void* ptr_out);

// Prior methods
void* cSTIR_setupPrior(void* ptr_p, void* ptr_i);
void* cSTIR_priorValue(void* ptr_p, void* ptr_i);
void* cSTIR_priorGradient(void* ptr_p, void* ptr_i);
void* cSTIR_priorAccumulateHessianTimesInput
(void* ptr_prior, void* ptr_out, void* ptr_curr, void* ptr_inp);
void* cSTIR_priorComputeHessianTimesInput
(void* ptr_prior, void* ptr_out, void* ptr_cur, void* ptr_inp);
void* cSTIR_computePriorGradient(void* ptr_p, void* ptr_i, void* ptr_g);
void* cSTIR_PLSPriorAnatomicalGradient(void* ptr_p, int dir);

Expand Down
28 changes: 28 additions & 0 deletions src/xSTIR/cSTIR/include/sirf/STIR/stir_x.h
Original file line number Diff line number Diff line change
Expand Up @@ -1031,6 +1031,12 @@ The actual algorithm is described in

class xSTIR_GeneralisedPrior3DF : public stir::GeneralisedPrior < Image3DF > {
public:
void multiply_with_Hessian(Image3DF& output, const Image3DF& curr_image_est,
const Image3DF& input) const
{
output.fill(0.0);
accumulate_Hessian_times_input(output, curr_image_est, input);
}
// bool post_process() {
// return post_processing();
// }
Expand Down Expand Up @@ -1067,6 +1073,19 @@ The actual algorithm is described in
class xSTIR_GeneralisedObjectiveFunction3DF :
public stir::GeneralisedObjectiveFunction < Image3DF > {
public:
void multiply_with_Hessian(Image3DF& output, const Image3DF& curr_image_est,
const Image3DF& input, const int subset) const
{
output.fill(0.0);
if (subset >= 0)
accumulate_sub_Hessian_times_input(output, curr_image_est, input, subset);
else {
for (int s = 0; s < get_num_subsets(); s++) {
accumulate_sub_Hessian_times_input(output, curr_image_est, input, s);
}
}
}

// bool post_process() {
// return post_processing();
// }
Expand Down Expand Up @@ -1118,6 +1137,15 @@ The actual algorithm is described in
set_cache_path(filepath);
}

void set_time_interval(double start, double stop)
{
std::pair<double, double> interval(start, stop);
std::vector < std::pair<double, double> > intervals;
intervals.push_back(interval);
frame_defs = stir::TimeFrameDefinitions(intervals);
do_time_frame = true;
}

private:
//std::shared_ptr<PETAcquisitionData> sptr_ad_;
std::shared_ptr<PETAcquisitionModelUsingMatrix> sptr_am_;
Expand Down
52 changes: 51 additions & 1 deletion src/xSTIR/pSTIR/STIR.py
Original file line number Diff line number Diff line change
Expand Up @@ -2333,6 +2333,23 @@ def set_up(self, image):
"""Sets up."""
try_calling(pystir.cSTIR_setupPrior(self.handle, image.handle))

def accumulate_Hessian_times_input(self, current_estimate, input_, out=None):
"""Computes the multiplication of the Hessian with a vector and adds it to output.
"""
if out is None or out.handle is None:
out = input_.get_uniform_copy(0.0)
try_calling(pystir.cSTIR_priorAccumulateHessianTimesInput
(self.handle, out.handle, current_estimate.handle, input_.handle))
return out

def multiply_with_Hessian(self, current_estimate, input_, out=None):
"""Computes the multiplication of the Hessian at current_estimate with a vector.
"""
if out is None or out.handle is None:
out = input_.get_uniform_copy(0.0)
try_calling(pystir.cSTIR_priorComputeHessianTimesInput
(self.handle, current_estimate.handle, input_.handle, out.handle))
return out

class QuadraticPrior(Prior):
r"""Class for the prior that is a quadratic function of the image values.
Expand Down Expand Up @@ -2732,6 +2749,24 @@ def get_subset_gradient(self, image, subset, out=None):
"""
return self.gradient(image, subset, out)

def accumulate_Hessian_times_input(self, current_estimate, input_, subset=-1, out=None):
"""Computes the multiplication of the Hessian at current_estimate with a vector and adds it to output.
"""
if out is None or out.handle is None:
out = input_.clone()
evgueni-ovtchinnikov marked this conversation as resolved.
Show resolved Hide resolved
try_calling(pystir.cSTIR_objectiveFunctionAccumulateHessianTimesInput
(self.handle, current_estimate.handle, input_.handle, subset, out.handle))
return out

def multiply_with_Hessian(self, current_estimate, input_, subset=-1, out=None):
"""Computes the multiplication of the Hessian at current_estimate with a vector.
"""
if out is None or out.handle is None:
out = input_.get_uniform_copy(0.0)
try_calling(pystir.cSTIR_objectiveFunctionComputeHessianTimesInput
(self.handle, current_estimate.handle, input_.handle, subset, out.handle))
return out

@abc.abstractmethod
def get_subset_sensitivity(self, subset):
#print('in base class ObjectiveFunction')
Expand Down Expand Up @@ -2854,7 +2889,8 @@ def set_acquisition_data(self, ad):
self.handle, self.name, 'acquisition_data', ad.handle)


class PoissonLogLikelihoodWithLinearModelForMeanAndListModeDataWithProjMatrixByBin(ObjectiveFunction):
class PoissonLogLikelihoodWithLinearModelForMeanAndListModeDataWithProjMatrixByBin(PoissonLogLikelihoodWithLinearModelForMean):
#(ObjectiveFunction):
"""Class for a STIR type of Poisson loglikelihood object for listmode data.

Specifically, PoissonLogLikelihoodWithLinearModelForMeanAndListModeDataWithProjMatrixByBin.
Expand All @@ -2878,6 +2914,17 @@ def set_cache_path(self, path):
def get_cache_path(self):
return parms.char_par(self.handle, self.name, 'cache_path')

def set_time_interval(self, start, stop):
"""Sets the time interval.

Only data scanned during this time interval will be converted.
"""
interval = numpy.ndarray((2,), dtype=numpy.float32)
interval[0] = start
interval[1] = stop
try_calling(pystir.cSTIR_objFunListModeSetInterval(
self.handle, interval.ctypes.data))

def set_acquisition_data(self, ad):
assert_validity(ad, ListmodeData)
parms.set_parameter(
Expand Down Expand Up @@ -2918,6 +2965,9 @@ def set_cache_max_size(self, diff):
def get_cache_max_size(self):
return parms.int_par(self.handle, self.name, 'cache_max_size')

def set_subsensitivity_filenames(self, names):
return parms.set_char_par(self.handle, self.name, 'subsensitivity_filenames', names)

def get_subsensitivity_filenames(self):
return parms.char_par(self.handle, self.name, 'subsensitivity_filenames')

Expand Down
21 changes: 20 additions & 1 deletion src/xSTIR/pSTIR/tests/test_ObjectiveFunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def setUp(self):
templ = pet.AcquisitionData(os.path.join(data_path,'template_sinogram.hs'))
am.set_up(templ,image)
acquired_data=am.forward(image)

am.set_background_term(acquired_data*0 + numpy.mean(acquired_data.as_array()))
am.set_up(templ,image)
obj_fun = pet.make_Poisson_loglikelihood(acquired_data)
obj_fun.set_acquisition_model(am)
obj_fun.set_up(image)
Expand All @@ -56,3 +57,21 @@ def test_Poisson_loglikelihood_call(self):

numpy.testing.assert_almost_equal(a,b)

def test_Hessian(self, subset=-1, eps=1e-3):
"""Checks that grad(x + dx) - grad(x) is close to H(x)*dx
"""
x = self.image
dx = x.clone()
dx *= eps/dx.norm()
dx += eps/2
y = x + dx
gx = self.obj_fun.gradient(x, subset)
gy = self.obj_fun.gradient(y, subset)
dg = gy - gx
Hdx = self.obj_fun.multiply_with_Hessian(x, dx, subset)
q = (dg - Hdx).norm()/dg.norm()
print('norm of grad(x + dx) - grad(x): %f' % dg.norm())
print('norm of H(x)*dx: %f' % Hdx.norm())
print('relative difference: %f' % q)
assert q <= .002

Loading
Loading