diff --git a/src/xSTIR/cSTIR/cstir.cpp b/src/xSTIR/cSTIR/cstir.cpp index 9a4f051d9..bd431e98b 100644 --- a/src/xSTIR/cSTIR/cstir.cpp +++ b/src/xSTIR/cSTIR/cstir.cpp @@ -1186,6 +1186,33 @@ cSTIR_objectiveFunctionGradient(void* ptr_f, void* ptr_i, int subset) CATCH; } +extern "C" +void* +cSTIR_computeObjectiveFunctionGradient(void* ptr_f, void* ptr_i, int subset, void* ptr_g) +{ + try { + ObjectiveFunction3DF& fun = objectFromHandle< ObjectiveFunction3DF>(ptr_f); + STIRImageData& id = objectFromHandle(ptr_i); + STIRImageData& gd = objectFromHandle(ptr_g); + Image3DF& image = id.data(); + Image3DF& grad = gd.data(); + if (subset >= 0) + fun.compute_sub_gradient(grad, image, subset); + else { + int nsub = fun.get_num_subsets(); + grad.fill(0.0); + shared_ptr sptr_sub(new STIRImageData(image)); + Image3DF& subgrad = sptr_sub->data(); + for (int sub = 0; sub < nsub; sub++) { + fun.compute_sub_gradient(subgrad, image, sub); + grad += subgrad; + } + } + return (void*) new DataHandle; + } + CATCH; +} + extern "C" void* cSTIR_objectiveFunctionGradientNotDivided(void* ptr_f, void* ptr_i, int subset) @@ -1204,6 +1231,24 @@ cSTIR_objectiveFunctionGradientNotDivided(void* ptr_f, void* ptr_i, int subset) CATCH; } +extern "C" +void* +cSTIR_computeObjectiveFunctionGradientNotDivided(void* ptr_f, void* ptr_i, int subset, void* ptr_g) +{ + try { + PoissonLogLhLinModMean3DF& fun = + objectFromHandle(ptr_f); + STIRImageData& id = objectFromHandle(ptr_i); + STIRImageData& gd = objectFromHandle(ptr_g); + Image3DF& image = id.data(); + Image3DF& grad = gd.data(); + fun.compute_sub_gradient_without_penalty_plus_sensitivity + (grad, image, subset); + return (void*) new DataHandle; + } + CATCH; +} + extern "C" void* cSTIR_setupPrior(void* ptr_p, void* ptr_i) @@ -1256,12 +1301,28 @@ cSTIR_priorGradient(void* ptr_p, void* ptr_i) extern "C" void* -cSTIR_PLSPriorGradient(void* ptr_p, int dir) +cSTIR_computePriorGradient(void* ptr_p, void* ptr_i, void* ptr_g) +{ + try { + Prior3DF& prior = objectFromHandle(ptr_p); + STIRImageData& id = objectFromHandle(ptr_i); + STIRImageData& gd = objectFromHandle(ptr_g); + Image3DF& image = id.data(); + Image3DF& grad = gd.data(); + prior.compute_gradient(grad, image); + return (void*) new DataHandle; + } + CATCH; +} + +extern "C" +void* +cSTIR_PLSPriorAnatomicalGradient(void* ptr_p, int dir) { try { PLSPrior& prior = objectFromHandle >(ptr_p); auto sptr_im = prior.get_anatomical_grad_sptr(dir); - auto sptr_id = std::make_shared(*sptr_im); + auto sptr_id = std::make_shared(*sptr_im); return newObjectHandle(sptr_id); } CATCH; diff --git a/src/xSTIR/cSTIR/include/sirf/STIR/cstir.h b/src/xSTIR/cSTIR/include/sirf/STIR/cstir.h index 86172ba83..e4530c93f 100644 --- a/src/xSTIR/cSTIR/include/sirf/STIR/cstir.h +++ b/src/xSTIR/cSTIR/include/sirf/STIR/cstir.h @@ -143,18 +143,23 @@ extern "C" { // Objective function methods void* cSTIR_setupObjectiveFunction(void* ptr_r, void* ptr_i); - void* cSTIR_subsetSensitivity(void* ptr_f, int subset); + void* cSTIR_subsetSensitivity(void* ptr_f, int subset); void* cSTIR_objectiveFunctionValue(void* ptr_f, void* ptr_i); void* cSTIR_objectiveFunctionGradient (void* ptr_f, void* ptr_i, int subset); + void* cSTIR_computeObjectiveFunctionGradient + (void* ptr_f, void* ptr_i, int subset, void* ptr_g); void* cSTIR_objectiveFunctionGradientNotDivided (void* ptr_f, void* ptr_i, int subset); + void* cSTIR_computeObjectiveFunctionGradientNotDivided + (void* ptr_f, void* ptr_i, int subset, void* ptr_g); // Prior methods void* cSTIR_setupPrior(void* ptr_p, void* ptr_i); void* cSTIR_priorValue(void* ptr_p, void* ptr_i); void* cSTIR_priorGradient(void* ptr_p, void* ptr_i); - void* cSTIR_PLSPriorGradient(void* ptr_p, int dir); + void* cSTIR_computePriorGradient(void* ptr_p, void* ptr_i, void* ptr_g); + void* cSTIR_PLSPriorAnatomicalGradient(void* ptr_p, int dir); // Image methods void* cSTIR_getImageDimensions(const void* ptr, PTR_INT ptr_data); diff --git a/src/xSTIR/pSTIR/STIR.py b/src/xSTIR/pSTIR/STIR.py index 7c02a62e8..7388f2aff 100644 --- a/src/xSTIR/pSTIR/STIR.py +++ b/src/xSTIR/pSTIR/STIR.py @@ -33,7 +33,7 @@ from deprecation import deprecated from sirf.Utilities import show_2D_array, show_3D_array, error, check_status, \ - try_calling, assert_validity, \ + try_calling, assert_validity, assert_validities, \ cpp_int_dtype, cpp_int_array, \ examples_data_path, existing_filepath, pTest from sirf import SIRF @@ -2307,22 +2307,27 @@ def value(self, image): """Returns the value of the prior (alias of get_value()).""" return self.get_value(image) - def get_gradient(self, image): + def get_gradient(self, image, out=None, **kwargs): """Returns gradient of the prior. Returns the value of the gradient of the prior for the specified image. image: ImageData object """ assert_validity(image, ImageData) - grad = ImageData() - grad.handle = pystir.cSTIR_priorGradient(self.handle, image.handle) - check_status(grad.handle) - return grad + if out is None: + out = ImageData() + if out.handle is None: + out.handle = pystir.cSTIR_priorGradient(self.handle, image.handle) + else: + assert_validities(image, out) + pystir.cSTIR_computePriorGradient(self.handle, image.handle, out.handle) + check_status(out.handle) + return out - def gradient(self, image): + def gradient(self, image, out=None, **kwargs): """Returns the gradient of the prior (alias of get_gradient()).""" - return self.get_gradient(image) + return self.get_gradient(image, out) def set_up(self, image): """Sets up.""" @@ -2566,10 +2571,13 @@ def get_anatomical_image(self): check_status(image.handle) return image - def get_anatomical_grad(self, direction): + def get_anatomical_grad(self, direction, out=None): """Returns anatomical gradient.""" - image = ImageData() - image.handle = pystir.cSTIR_PLSPriorGradient(self.handle, direction) + if out is None: + image = ImageData() + else: + image = out + image.handle = pystir.cSTIR_PLSPriorAnatomicalGradient(self.handle, direction) check_status(image.handle) return image @@ -2686,7 +2694,7 @@ def get_value(self, image): """ return self.value(image) - def gradient(self, image, subset=-1): + def gradient(self, image, subset=-1, out=None): """Returns the value of the additive component of the gradient of this objective function on the specified image corresponding to the @@ -2697,20 +2705,24 @@ def gradient(self, image, subset=-1): subset: Python integer scalar """ assert_validity(image, ImageData) - grad = ImageData() - grad.handle = pystir.cSTIR_objectiveFunctionGradient( - self.handle, image.handle, subset) - check_status(grad.handle) - return grad + if out is None: + out = ImageData() + if out.handle is None: + out.handle = pystir.cSTIR_objectiveFunctionGradient(self.handle, image.handle, subset) + else: + assert_validities(image, out) + pystir.cSTIR_computeObjectiveFunctionGradient(self.handle, image.handle, subset, out.handle) + check_status(out.handle) + return out - def get_gradient(self, image): + def get_gradient(self, image, out=None): """Returns the gradient of the objective function on specified image. image: ImageData object """ - return self.gradient(image) + return self.gradient(image, -1, out) - def get_subset_gradient(self, image, subset): + def get_subset_gradient(self, image, subset, out=None): """Returns the value of the additive component of the gradient of this objective function on corresponding to the specified @@ -2718,7 +2730,7 @@ def get_subset_gradient(self, image, subset): image: ImageData object subset: Python integer scalar """ - return self.gradient(image, subset) + return self.gradient(image, subset, out) @abc.abstractmethod def get_subset_sensitivity(self, subset): @@ -2767,18 +2779,24 @@ def get_subset_sensitivity(self, subset): check_status(ss.handle) return ss - def get_backprojection_of_acquisition_ratio(self, image, subset): + def get_backprojection_of_acquisition_ratio(self, image, subset, out=None): """Returns backprojection of measured to estimated acquisition ratio. Returns the back-projection of the ratio of the measured and estimated acquisition data. """ assert_validity(image, ImageData) - grad = ImageData() - grad.handle = pystir.cSTIR_objectiveFunctionGradientNotDivided( - self.handle, image.handle, subset) - check_status(grad.handle) - return grad + if out is None: + out = ImageData() + if out.handle is None: + out.handle = pystir.cSTIR_objectiveFunctionGradientNotDivided( + self.handle, image.handle, subset) + else: + assert_validities(image, out) + pystir.cSTIR_computeObjectiveFunctionGradientNotDivided( + self.handle, image.handle, subset, out.handle) + check_status(out.handle) + return out class PoissonLogLikelihoodWithLinearModelForMeanAndProjData(