Skip to content

Commit

Permalink
Added optional argument out in all gradient-computing methods (#1246)
Browse files Browse the repository at this point in the history
* added optional argument out in all gradient-computing methods

* got rid of out.fill(grad)

* fixed the use of out in computing gradients (except PLSPriorGradient)

* corrected C interface to PLSPrior anatomical gradient
  • Loading branch information
evgueni-ovtchinnikov authored May 9, 2024
1 parent 2c66faf commit 1d53570
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 31 deletions.
65 changes: 63 additions & 2 deletions src/xSTIR/cSTIR/cstir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1186,6 +1186,33 @@ cSTIR_objectiveFunctionGradient(void* ptr_f, void* ptr_i, int subset)
CATCH;
}

extern "C"
void*
cSTIR_computeObjectiveFunctionGradient(void* ptr_f, void* ptr_i, int subset, void* ptr_g)
{
try {
ObjectiveFunction3DF& fun = objectFromHandle< ObjectiveFunction3DF>(ptr_f);
STIRImageData& id = objectFromHandle<STIRImageData>(ptr_i);
STIRImageData& gd = objectFromHandle<STIRImageData>(ptr_g);
Image3DF& image = id.data();
Image3DF& grad = gd.data();
if (subset >= 0)
fun.compute_sub_gradient(grad, image, subset);
else {
int nsub = fun.get_num_subsets();
grad.fill(0.0);
shared_ptr<STIRImageData> sptr_sub(new STIRImageData(image));
Image3DF& subgrad = sptr_sub->data();
for (int sub = 0; sub < nsub; sub++) {
fun.compute_sub_gradient(subgrad, image, sub);
grad += subgrad;
}
}
return (void*) new DataHandle;
}
CATCH;
}

extern "C"
void*
cSTIR_objectiveFunctionGradientNotDivided(void* ptr_f, void* ptr_i, int subset)
Expand All @@ -1204,6 +1231,24 @@ cSTIR_objectiveFunctionGradientNotDivided(void* ptr_f, void* ptr_i, int subset)
CATCH;
}

extern "C"
void*
cSTIR_computeObjectiveFunctionGradientNotDivided(void* ptr_f, void* ptr_i, int subset, void* ptr_g)
{
try {
PoissonLogLhLinModMean3DF& fun =
objectFromHandle<PoissonLogLhLinModMean3DF>(ptr_f);
STIRImageData& id = objectFromHandle<STIRImageData>(ptr_i);
STIRImageData& gd = objectFromHandle<STIRImageData>(ptr_g);
Image3DF& image = id.data();
Image3DF& grad = gd.data();
fun.compute_sub_gradient_without_penalty_plus_sensitivity
(grad, image, subset);
return (void*) new DataHandle;
}
CATCH;
}

extern "C"
void*
cSTIR_setupPrior(void* ptr_p, void* ptr_i)
Expand Down Expand Up @@ -1256,12 +1301,28 @@ cSTIR_priorGradient(void* ptr_p, void* ptr_i)

extern "C"
void*
cSTIR_PLSPriorGradient(void* ptr_p, int dir)
cSTIR_computePriorGradient(void* ptr_p, void* ptr_i, void* ptr_g)
{
try {
Prior3DF& prior = objectFromHandle<Prior3DF>(ptr_p);
STIRImageData& id = objectFromHandle<STIRImageData>(ptr_i);
STIRImageData& gd = objectFromHandle<STIRImageData>(ptr_g);
Image3DF& image = id.data();
Image3DF& grad = gd.data();
prior.compute_gradient(grad, image);
return (void*) new DataHandle;
}
CATCH;
}

extern "C"
void*
cSTIR_PLSPriorAnatomicalGradient(void* ptr_p, int dir)
{
try {
PLSPrior<float>& prior = objectFromHandle<PLSPrior<float> >(ptr_p);
auto sptr_im = prior.get_anatomical_grad_sptr(dir);
auto sptr_id = std::make_shared<STIRImageData>(*sptr_im);
auto sptr_id = std::make_shared<STIRImageData>(*sptr_im);
return newObjectHandle(sptr_id);
}
CATCH;
Expand Down
9 changes: 7 additions & 2 deletions src/xSTIR/cSTIR/include/sirf/STIR/cstir.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,18 +143,23 @@ extern "C" {

// Objective function methods
void* cSTIR_setupObjectiveFunction(void* ptr_r, void* ptr_i);
void* cSTIR_subsetSensitivity(void* ptr_f, int subset);
void* cSTIR_subsetSensitivity(void* ptr_f, int subset);
void* cSTIR_objectiveFunctionValue(void* ptr_f, void* ptr_i);
void* cSTIR_objectiveFunctionGradient
(void* ptr_f, void* ptr_i, int subset);
void* cSTIR_computeObjectiveFunctionGradient
(void* ptr_f, void* ptr_i, int subset, void* ptr_g);
void* cSTIR_objectiveFunctionGradientNotDivided
(void* ptr_f, void* ptr_i, int subset);
void* cSTIR_computeObjectiveFunctionGradientNotDivided
(void* ptr_f, void* ptr_i, int subset, void* ptr_g);

// Prior methods
void* cSTIR_setupPrior(void* ptr_p, void* ptr_i);
void* cSTIR_priorValue(void* ptr_p, void* ptr_i);
void* cSTIR_priorGradient(void* ptr_p, void* ptr_i);
void* cSTIR_PLSPriorGradient(void* ptr_p, int dir);
void* cSTIR_computePriorGradient(void* ptr_p, void* ptr_i, void* ptr_g);
void* cSTIR_PLSPriorAnatomicalGradient(void* ptr_p, int dir);

// Image methods
void* cSTIR_getImageDimensions(const void* ptr, PTR_INT ptr_data);
Expand Down
72 changes: 45 additions & 27 deletions src/xSTIR/pSTIR/STIR.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from deprecation import deprecated

from sirf.Utilities import show_2D_array, show_3D_array, error, check_status, \
try_calling, assert_validity, \
try_calling, assert_validity, assert_validities, \
cpp_int_dtype, cpp_int_array, \
examples_data_path, existing_filepath, pTest
from sirf import SIRF
Expand Down Expand Up @@ -2307,22 +2307,27 @@ def value(self, image):
"""Returns the value of the prior (alias of get_value())."""
return self.get_value(image)

def get_gradient(self, image):
def get_gradient(self, image, out=None, **kwargs):
"""Returns gradient of the prior.
Returns the value of the gradient of the prior for the specified image.
image: ImageData object
"""
assert_validity(image, ImageData)
grad = ImageData()
grad.handle = pystir.cSTIR_priorGradient(self.handle, image.handle)
check_status(grad.handle)
return grad
if out is None:
out = ImageData()
if out.handle is None:
out.handle = pystir.cSTIR_priorGradient(self.handle, image.handle)
else:
assert_validities(image, out)
pystir.cSTIR_computePriorGradient(self.handle, image.handle, out.handle)
check_status(out.handle)
return out

def gradient(self, image):
def gradient(self, image, out=None, **kwargs):
"""Returns the gradient of the prior (alias of get_gradient())."""

return self.get_gradient(image)
return self.get_gradient(image, out)

def set_up(self, image):
"""Sets up."""
Expand Down Expand Up @@ -2566,10 +2571,13 @@ def get_anatomical_image(self):
check_status(image.handle)
return image

def get_anatomical_grad(self, direction):
def get_anatomical_grad(self, direction, out=None):
"""Returns anatomical gradient."""
image = ImageData()
image.handle = pystir.cSTIR_PLSPriorGradient(self.handle, direction)
if out is None:
image = ImageData()
else:
image = out
image.handle = pystir.cSTIR_PLSPriorAnatomicalGradient(self.handle, direction)
check_status(image.handle)
return image

Expand Down Expand Up @@ -2686,7 +2694,7 @@ def get_value(self, image):
"""
return self.value(image)

def gradient(self, image, subset=-1):
def gradient(self, image, subset=-1, out=None):
"""Returns the value of the additive component of the gradient
of this objective function on the specified image corresponding to the
Expand All @@ -2697,28 +2705,32 @@ def gradient(self, image, subset=-1):
subset: Python integer scalar
"""
assert_validity(image, ImageData)
grad = ImageData()
grad.handle = pystir.cSTIR_objectiveFunctionGradient(
self.handle, image.handle, subset)
check_status(grad.handle)
return grad
if out is None:
out = ImageData()
if out.handle is None:
out.handle = pystir.cSTIR_objectiveFunctionGradient(self.handle, image.handle, subset)
else:
assert_validities(image, out)
pystir.cSTIR_computeObjectiveFunctionGradient(self.handle, image.handle, subset, out.handle)
check_status(out.handle)
return out

def get_gradient(self, image):
def get_gradient(self, image, out=None):
"""Returns the gradient of the objective function on specified image.
image: ImageData object
"""
return self.gradient(image)
return self.gradient(image, -1, out)

def get_subset_gradient(self, image, subset):
def get_subset_gradient(self, image, subset, out=None):
"""Returns the value of the additive component of the gradient
of this objective function on <image> corresponding to the specified
subset (see set_num_subsets() method).
image: ImageData object
subset: Python integer scalar
"""
return self.gradient(image, subset)
return self.gradient(image, subset, out)

@abc.abstractmethod
def get_subset_sensitivity(self, subset):
Expand Down Expand Up @@ -2767,18 +2779,24 @@ def get_subset_sensitivity(self, subset):
check_status(ss.handle)
return ss

def get_backprojection_of_acquisition_ratio(self, image, subset):
def get_backprojection_of_acquisition_ratio(self, image, subset, out=None):
"""Returns backprojection of measured to estimated acquisition ratio.
Returns the back-projection of the ratio of the measured and estimated
acquisition data.
"""
assert_validity(image, ImageData)
grad = ImageData()
grad.handle = pystir.cSTIR_objectiveFunctionGradientNotDivided(
self.handle, image.handle, subset)
check_status(grad.handle)
return grad
if out is None:
out = ImageData()
if out.handle is None:
out.handle = pystir.cSTIR_objectiveFunctionGradientNotDivided(
self.handle, image.handle, subset)
else:
assert_validities(image, out)
pystir.cSTIR_computeObjectiveFunctionGradientNotDivided(
self.handle, image.handle, subset, out.handle)
check_status(out.handle)
return out


class PoissonLogLikelihoodWithLinearModelForMeanAndProjData(
Expand Down

0 comments on commit 1d53570

Please sign in to comment.