From 4fee4ab7ec6088dc5a53a6dfc9ce4ab9518b7d30 Mon Sep 17 00:00:00 2001 From: richard Date: Wed, 1 Jul 2020 15:04:57 +0100 Subject: [PATCH 01/42] implementation --- src/Registration/cReg/CMakeLists.txt | 4 +- src/Registration/cReg/NiftiImageData.cpp | 7 +- .../cReg/NiftiImageData3DBSpline.cpp | 92 +++++++++++++ .../cReg/include/sirf/Reg/NiftiImageData.h | 2 +- .../sirf/Reg/NiftiImageData3DBSpline.h | 121 ++++++++++++++++++ 5 files changed, 223 insertions(+), 3 deletions(-) create mode 100644 src/Registration/cReg/NiftiImageData3DBSpline.cpp create mode 100644 src/Registration/cReg/include/sirf/Reg/NiftiImageData3DBSpline.h diff --git a/src/Registration/cReg/CMakeLists.txt b/src/Registration/cReg/CMakeLists.txt index 5bd1a329c..6facaa51b 100644 --- a/src/Registration/cReg/CMakeLists.txt +++ b/src/Registration/cReg/CMakeLists.txt @@ -43,7 +43,9 @@ SET(SOURCES "NiftiImageData3D.cpp" "NiftiImageData3DTensor.cpp" "NiftiImageData3DDeformation.cpp" - "NiftiImageData3DDisplacement.cpp") + "NiftiImageData3DDisplacement.cpp" + "NiftiImageData3DBSPline.cpp" + ) # If we're also wrapping to python or matlab, include the c-files IF(BUILD_PYTHON OR BUILD_MATLAB) diff --git a/src/Registration/cReg/NiftiImageData.cpp b/src/Registration/cReg/NiftiImageData.cpp index 2dd1caffc..76731e3e0 100644 --- a/src/Registration/cReg/NiftiImageData.cpp +++ b/src/Registration/cReg/NiftiImageData.cpp @@ -36,6 +36,7 @@ limitations under the License. #include "sirf/Reg/NiftiImageData3DTensor.h" #include "sirf/Reg/NiftiImageData3DDeformation.h" #include "sirf/Reg/NiftiImageData3DDisplacement.h" +#include "sirf/Reg/NiftiImageData3DBSpline.h" #include "sirf/Reg/AffineTransformation.h" #include "sirf/Reg/NiftyResample.h" #include @@ -492,7 +493,10 @@ void NiftiImageData::check_dimensions(const NiftiImageDataType image_t else if (image_type == _3D) { ndim= 3; nt= 1; nu= 1; intent_code = NIFTI_INTENT_NONE; intent_p1=-1; } else if (image_type == _3DTensor) { ndim= 5; nt= 1; nu= 3; intent_code = NIFTI_INTENT_VECTOR; intent_p1=-1; } else if (image_type == _3DDisp) { ndim= 5; nt= 1; nu= 3; intent_code = NIFTI_INTENT_VECTOR; intent_p1=DISP_FIELD; } - else /*if (image_type == _3DDef)*/ { ndim= 5; nt= 1; nu= 3; intent_code = NIFTI_INTENT_VECTOR; intent_p1=DEF_FIELD; } + else if (image_type == _3DDef) { ndim= 5; nt= 1; nu= 3; intent_code = NIFTI_INTENT_VECTOR; intent_p1=DEF_FIELD; } + else if (image_type == _3DBSpl) { ndim= 5; nt= 1; nu= 3; intent_code = NIFTI_INTENT_VECTOR; intent_p1=SPLINE_VEL_GRID; } + else + throw std::runtime_error("NiftiImageData::check_dimensions: Unknown image type"); // Check everthing is as it should be. -1 means we don't care about it // (e.g., NiftiImageData3D doesn't care about intent_p1, which is used by NiftyReg for Disp/Def fields) @@ -513,6 +517,7 @@ void NiftiImageData::check_dimensions(const NiftiImageDataType image_t else if (typeid(*this) == typeid(NiftiImageData3DTensor)) ss << "NiftiImageData3DTensor"; else if (typeid(*this) == typeid(NiftiImageData3DDisplacement)) ss << "NiftiImageData3DDisplacement"; else if (typeid(*this) == typeid(NiftiImageData3DDeformation)) ss << "NiftiImageData3DDeformation"; + else if (typeid(*this) == typeid(NiftiImageData3DBSpline)) ss << "NiftiImageData3DDeformation"; ss << ".\n\t\tExpected params: ndim = " << ndim << ", nu = " << nu << ", nt = " << nt; if (intent_code == NIFTI_INTENT_NONE) ss << ", intent_code = None"; else if (intent_code == NIFTI_INTENT_VECTOR) ss << ", intent_code = Vector"; diff --git a/src/Registration/cReg/NiftiImageData3DBSpline.cpp b/src/Registration/cReg/NiftiImageData3DBSpline.cpp new file mode 100644 index 000000000..1e34060cf --- /dev/null +++ b/src/Registration/cReg/NiftiImageData3DBSpline.cpp @@ -0,0 +1,92 @@ +/* +SyneRBI Synergistic Image Reconstruction Framework (SIRF) +Copyright 2017 - 2020 University College London + +This is software developed for the Collaborative Computational +Project in Synergistic Reconstruction for Biomedical Imaging (formerly CCP PETMR) +(http://www.ccpsynerbi.ac.uk/). + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +*/ + +/*! +\file +\ingroup Registration +\brief Class for deformation field transformations. + +\author Richard Brown +\author SyneRBI +*/ + +#include "sirf/Reg/NiftiImageData3DBSpline.h" +#include "sirf/Reg/NiftiImageData3DDeformation.h" +#include "sirf/NiftyMoMo/BSplineTransformation.h" + +using namespace sirf; + +template +NiftiImageData3DBSpline::NiftiImageData3DBSpline(const NiftiImageData3DDeformation &def, float spacing[]) +{ + // Get any of the tensor components as a 3d image + nifti_image *ref_ptr = def.get_tensor_component(0)->get_raw_nifti_sptr().get(); + // Create the NiftyMoMo bspline transformation class + NiftyMoMo::BSplineTransformation bspline(ref_ptr, 1, spacing); + // Convert DVF to CPG + bspline.GetDVFGradientWRTTransformationParameters(def.clone()->get_raw_nifti_sptr().get(), ref_ptr); + // Get output + nifti_image *cpg_ptr = bspline.GetTransformationAsImage(); + *this = NiftiImageData3DBSpline(*cpg_ptr); +} + +template +void NiftiImageData3DBSpline::create_from_3D_image(const NiftiImageData &image) +{ + NiftiImageData3DTensor::create_from_3D_image(image); + this->_nifti_image->intent_p1 = SPLINE_VEL_GRID; +} + +template +NiftiImageData3DDeformation NiftiImageData3DBSpline::get_as_deformation_field(const NiftiImageData &ref) const +{ + // Get spacing of reference image + float spacing[3]; + for (unsigned i=0; i<3; ++i) + spacing[i] = this->_nifti_image->pixdim[i+1]; + // Create the NiftyMoMo bspline transformation class + NiftyMoMo::BSplineTransformation bspline(ref.clone()->get_raw_nifti_sptr().get(), 1, spacing); + // Set the CPG + bspline.SetParameters(static_cast(this->_nifti_image->data), false); + // Get the DVF + nifti_image *output_def_ptr = bspline.GetDeformationVectorField(ref.get_raw_nifti_sptr().get()); + return NiftiImageData3DDeformation(*output_def_ptr); +} + +template +NiftiImageData3DBSpline* +NiftiImageData3DBSpline::get_inverse_impl_nr(const std::shared_ptr >) const +{ + throw std::runtime_error("NiftiImageData3DBSpline::get_inverse_impl_nr not yet implemented."); +} + +template +NiftiImageData3DBSpline* +NiftiImageData3DBSpline::get_inverse_impl_vtk(const std::shared_ptr >) const +{ + throw std::runtime_error("NiftiImageData3DBSpline::get_inverse_impl_vtk not yet implemented."); +#ifndef SIRF_VTK + throw std::runtime_error("Build SIRF with VTK support for this functionality"); +#endif +} + +namespace sirf { +template class NiftiImageData3DBSpline; +} diff --git a/src/Registration/cReg/include/sirf/Reg/NiftiImageData.h b/src/Registration/cReg/include/sirf/Reg/NiftiImageData.h index 6edd52bcb..013a3c1b1 100644 --- a/src/Registration/cReg/include/sirf/Reg/NiftiImageData.h +++ b/src/Registration/cReg/include/sirf/Reg/NiftiImageData.h @@ -420,7 +420,7 @@ class NiftiImageData : public ImageData protected: - enum NiftiImageDataType { _general, _3D, _3DTensor, _3DDisp, _3DDef}; + enum NiftiImageDataType { _general, _3D, _3DTensor, _3DDisp, _3DDef, _3DBSpl}; enum MathsType { add, sub, mul }; diff --git a/src/Registration/cReg/include/sirf/Reg/NiftiImageData3DBSpline.h b/src/Registration/cReg/include/sirf/Reg/NiftiImageData3DBSpline.h new file mode 100644 index 000000000..6e6c85728 --- /dev/null +++ b/src/Registration/cReg/include/sirf/Reg/NiftiImageData3DBSpline.h @@ -0,0 +1,121 @@ +/* +SyneRBI Synergistic Image Reconstruction Framework (SIRF) +Copyright 2020 University College London + +This is software developed for the Collaborative Computational +Project in Synergistic Reconstruction for Biomedical Imaging (formerly CCP PETMR) +(http://www.ccpsynerbi.ac.uk/). + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +*/ + +/*! +\file +\ingroup Registration +\brief Class for b-spline control point grid SIRF image data. + +\author Richard Brown +\author SyneRBI +*/ + +#pragma once + +#include "sirf/Reg/NiftiImageData3DTensor.h" +#include "sirf/Reg/NonRigidTransformation.h" + +namespace sirf { + +/*! +\ingroup Registration +\brief Class for b-spline control point grid SIRF image data. + +\author Richard Brown +\author SyneRBI +*/ +template +class NiftiImageData3DBSpline : public NiftiImageData3DTensor, public NonRigidTransformation +{ +public: + /// Constructor + NiftiImageData3DBSpline() {} + + /// Filename constructor + NiftiImageData3DBSpline(const std::string &filename) + : NiftiImageData3DTensor(filename) { this->check_dimensions(this->_3DBSpl); } + + /// Nifti constructor + NiftiImageData3DBSpline(const nifti_image &image_nifti) + : NiftiImageData3DTensor(image_nifti) { this->check_dimensions(this->_3DBSpl); } + + /// Construct from general tensor + NiftiImageData3DBSpline(const NiftiImageData& tensor) + : NiftiImageData3DTensor(tensor) { this->check_dimensions(this->_3DBSpl); } + + /// Construct from array + template + NiftiImageData3DBSpline(const inputType * const data, const VoxelisedGeometricalInfo3D &geom) + : NiftiImageData3DTensor(data, geom) { this->_nifti_image->intent_code = NIFTI_INTENT_VECTOR; this->_nifti_image->intent_p1=SPLINE_VEL_GRID; } + + /// Create from 3 individual components + NiftiImageData3DBSpline(const NiftiImageData3D &x, const NiftiImageData3D &y, const NiftiImageData3D &z) + : NiftiImageData3DTensor(x,y,z) { this->check_dimensions(this->_3DBSpl); } + + /// Create from deformation field image + NiftiImageData3DBSpline(const NiftiImageData3DDeformation &def, float spacing[]); + + /// Create from 3D image + void create_from_3D_image(const NiftiImageData &image); + + /// Get as deformation field + virtual NiftiImageData3DDeformation get_as_deformation_field(const NiftiImageData &ref) const; + + /// New data handle + virtual ObjectHandle* new_data_container_handle() const + { + return new ObjectHandle + (std::shared_ptr(new NiftiImageData3DBSpline)); + } + /// Write + virtual void write(const std::string &filename) const { this->NiftiImageData::write(filename); } + /// Clone and return as unique pointer. + std::unique_ptr clone() const + { + return std::unique_ptr(this->clone_impl()); + } + + /*! \brief Get inverse as unique pointer (potentially based on another image). + * + * Why would you want to base it on another image? Well, we might have a deformation + * that takes us from image A to B. We'll probably want the inverse to take us from + * image B back to A. In this case, use get_inverse(A). This is because the the deformation + * field is defined for the reference image. In the second case, A is the reference, + * and B is the floating image.*/ + std::unique_ptr get_inverse(const std::shared_ptr > image_sptr = nullptr, const bool use_vtk=false) const + { + throw std::runtime_error("NiftiImageData3DBSpline::get_inverse: not yet implemented"); + } + + +protected: + /// Clone helper function. Don't use. + virtual NiftiImageData3DBSpline* clone_impl() const + { + return new NiftiImageData3DBSpline(*this); + } + + /// Helper function for get_inverse (NiftyReg). Don't use. + virtual NiftiImageData3DBSpline* get_inverse_impl_nr(const std::shared_ptr > image_sptr = nullptr) const; + + /// Helper function for get_inverse (VTK). Don't use. + virtual NiftiImageData3DBSpline* get_inverse_impl_vtk(const std::shared_ptr > image_sptr = nullptr) const; +}; +} From 1686c4db62c48364668da1502e05711c81448cab Mon Sep 17 00:00:00 2001 From: richard Date: Thu, 2 Jul 2020 09:35:14 +0100 Subject: [PATCH 02/42] current progress --- .../cReg/NiftiImageData3DBSpline.cpp | 2 ++ src/Registration/cReg/tests/test_cReg.cpp | 36 +++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/src/Registration/cReg/NiftiImageData3DBSpline.cpp b/src/Registration/cReg/NiftiImageData3DBSpline.cpp index 1e34060cf..e61d08d4b 100644 --- a/src/Registration/cReg/NiftiImageData3DBSpline.cpp +++ b/src/Registration/cReg/NiftiImageData3DBSpline.cpp @@ -44,7 +44,9 @@ NiftiImageData3DBSpline::NiftiImageData3DBSpline(const NiftiImageData3 bspline.GetDVFGradientWRTTransformationParameters(def.clone()->get_raw_nifti_sptr().get(), ref_ptr); // Get output nifti_image *cpg_ptr = bspline.GetTransformationAsImage(); + cpg_ptr->intent_p1 = SPLINE_VEL_GRID; *this = NiftiImageData3DBSpline(*cpg_ptr); + this->check_dimensions(NiftiImageData::_3DBSpl); } template diff --git a/src/Registration/cReg/tests/test_cReg.cpp b/src/Registration/cReg/tests/test_cReg.cpp index b1fa09744..64e07ed76 100644 --- a/src/Registration/cReg/tests/test_cReg.cpp +++ b/src/Registration/cReg/tests/test_cReg.cpp @@ -37,6 +37,7 @@ limitations under the License. #include "sirf/Reg/NiftiImageData3DDisplacement.h" #include "sirf/Reg/AffineTransformation.h" #include "sirf/Reg/Quaternion.h" +#include "sirf/Reg/NiftiImageData3DBSpline.h" #include #include #ifdef SIRF_SPM @@ -1144,6 +1145,41 @@ int main(int argc, char* argv[]) std::cout << "// Finished weighted mean test.\n"; std::cout << "//------------------------------------------------------------------------ //\n"; } + { + + std::cout << "// ----------------------------------------------------------------------- //\n"; + std::cout << "// Starting CGP<->DVF test...\n"; + std::cout << "//------------------------------------------------------------------------ //\n"; + + auto dvf_sptr = std::dynamic_pointer_cast >( + NA.get_deformation_field_forward_sptr()); + + // DVF->CPG + float spacing[3]; + for (unsigned i=0; i<3; ++i) + spacing[i] = dvf_sptr->get_raw_nifti_sptr()->pixdim[i+1] * 2.f; + NiftiImageData3DBSpline dvf_to_cpg(*dvf_sptr, spacing); + NiftiImageData::print_headers({dvf_sptr.get(), &dvf_to_cpg}); + exit(0); + if (std::abs(dvf_to_cpg.get_max()) < 1.e-4f || std::abs(dvf_to_cpg.get_min()) < 1.e-4f) + throw std::runtime_error("NiftiImageData3DBSpline::NiftiImageData3DBSpline(DVF): contains only zeroes."); + + // DVF->CPG->DVF + auto dvf_to_cpg_to_dvf = dvf_to_cpg.get_as_deformation_field(*dvf_sptr->get_tensor_component(0)); + + NiftiImageData::print_headers({ref_aladin.get(), dvf_sptr.get(), + &dvf_to_cpg, &dvf_to_cpg_to_dvf}); + + // Compare + if (*dvf_sptr != dvf_to_cpg_to_dvf) + throw std::runtime_error("DVF->CPG->DVF != DVF."); + +exit(0); + + std::cout << "// ----------------------------------------------------------------------- //\n"; + std::cout << "// Finished CGP<->DVF test.\n"; + std::cout << "//------------------------------------------------------------------------ //\n"; + } /* TODO UNCOMMENT WHEN GEOMETRICAL INFO IS IMPLEMENTED { std::cout << "// ----------------------------------------------------------------------- //\n"; From b62b75ee097607860f6bab5377487de467b9f0c9 Mon Sep 17 00:00:00 2001 From: Richard Date: Thu, 2 Jul 2020 09:27:08 +0000 Subject: [PATCH 03/42] works --- src/Registration/cReg/NiftiImageData3DBSpline.cpp | 9 +++++---- src/Registration/cReg/tests/test_cReg.cpp | 4 ---- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/Registration/cReg/NiftiImageData3DBSpline.cpp b/src/Registration/cReg/NiftiImageData3DBSpline.cpp index e61d08d4b..bc9b2069b 100644 --- a/src/Registration/cReg/NiftiImageData3DBSpline.cpp +++ b/src/Registration/cReg/NiftiImageData3DBSpline.cpp @@ -37,13 +37,14 @@ template NiftiImageData3DBSpline::NiftiImageData3DBSpline(const NiftiImageData3DDeformation &def, float spacing[]) { // Get any of the tensor components as a 3d image - nifti_image *ref_ptr = def.get_tensor_component(0)->get_raw_nifti_sptr().get(); + auto ref_sptr = def.get_tensor_component(0); + nifti_image *ref_ptr = ref_sptr->get_raw_nifti_sptr().get(); // Create the NiftyMoMo bspline transformation class NiftyMoMo::BSplineTransformation bspline(ref_ptr, 1, spacing); - // Convert DVF to CPG - bspline.GetDVFGradientWRTTransformationParameters(def.clone()->get_raw_nifti_sptr().get(), ref_ptr); - // Get output + // Get cpg_ptr nifti_image *cpg_ptr = bspline.GetTransformationAsImage(); + // Convert DVF to CPG + cpg_ptr->data = bspline.GetDVFGradientWRTTransformationParameters(def.clone()->get_raw_nifti_sptr().get(), ref_ptr); cpg_ptr->intent_p1 = SPLINE_VEL_GRID; *this = NiftiImageData3DBSpline(*cpg_ptr); this->check_dimensions(NiftiImageData::_3DBSpl); diff --git a/src/Registration/cReg/tests/test_cReg.cpp b/src/Registration/cReg/tests/test_cReg.cpp index 64e07ed76..3715b52f5 100644 --- a/src/Registration/cReg/tests/test_cReg.cpp +++ b/src/Registration/cReg/tests/test_cReg.cpp @@ -1159,8 +1159,6 @@ int main(int argc, char* argv[]) for (unsigned i=0; i<3; ++i) spacing[i] = dvf_sptr->get_raw_nifti_sptr()->pixdim[i+1] * 2.f; NiftiImageData3DBSpline dvf_to_cpg(*dvf_sptr, spacing); - NiftiImageData::print_headers({dvf_sptr.get(), &dvf_to_cpg}); - exit(0); if (std::abs(dvf_to_cpg.get_max()) < 1.e-4f || std::abs(dvf_to_cpg.get_min()) < 1.e-4f) throw std::runtime_error("NiftiImageData3DBSpline::NiftiImageData3DBSpline(DVF): contains only zeroes."); @@ -1174,8 +1172,6 @@ int main(int argc, char* argv[]) if (*dvf_sptr != dvf_to_cpg_to_dvf) throw std::runtime_error("DVF->CPG->DVF != DVF."); -exit(0); - std::cout << "// ----------------------------------------------------------------------- //\n"; std::cout << "// Finished CGP<->DVF test.\n"; std::cout << "//------------------------------------------------------------------------ //\n"; From 6c69d09a04fafc8faf50a99116b610f88a999c8e Mon Sep 17 00:00:00 2001 From: richard Date: Thu, 2 Jul 2020 11:37:47 +0100 Subject: [PATCH 04/42] add converter --- src/Registration/cReg/CMakeLists.txt | 1 + ...ControlPointGridToDeformationConverter.cpp | 95 +++++++++++++++++++ .../ControlPointGridToDeformationConverter.h | 76 +++++++++++++++ src/Registration/cReg/tests/test_cReg.cpp | 12 +++ 4 files changed, 184 insertions(+) create mode 100644 src/Registration/cReg/ControlPointGridToDeformationConverter.cpp create mode 100644 src/Registration/cReg/include/sirf/Reg/ControlPointGridToDeformationConverter.h diff --git a/src/Registration/cReg/CMakeLists.txt b/src/Registration/cReg/CMakeLists.txt index 6facaa51b..8ad57bf57 100644 --- a/src/Registration/cReg/CMakeLists.txt +++ b/src/Registration/cReg/CMakeLists.txt @@ -45,6 +45,7 @@ SET(SOURCES "NiftiImageData3DDeformation.cpp" "NiftiImageData3DDisplacement.cpp" "NiftiImageData3DBSPline.cpp" + "ControlPointGridToDeformationConverter.cpp" ) # If we're also wrapping to python or matlab, include the c-files diff --git a/src/Registration/cReg/ControlPointGridToDeformationConverter.cpp b/src/Registration/cReg/ControlPointGridToDeformationConverter.cpp new file mode 100644 index 000000000..97532c92c --- /dev/null +++ b/src/Registration/cReg/ControlPointGridToDeformationConverter.cpp @@ -0,0 +1,95 @@ +/* +SyneRBI Synergistic Image Reconstruction Framework (SIRF) +Copyright 2020 University College London + +This is software developed for the Collaborative Computational +Project in Synergistic Reconstruction for Biomedical Imaging (formerly CCP PETMR) +(http://www.ccpsynerbi.ac.uk/). + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +*/ + +/*! +\file +\ingroup Registration +\brief Class for converting control point grids to deformation field transformations. + +\author Richard Brown +\author SyneRBI +*/ + +#include "sirf/Reg/ControlPointGridToDeformationConverter.h" +#include "sirf/Reg/NiftiImageData3DDeformation.h" +#include "sirf/Reg/NiftiImageData3DBSpline.h" + +using namespace sirf; + +template +ControlPointGridToDeformationConverter:: +ControlPointGridToDeformationConverter() +{ + for (unsigned i=0; i<3; ++i) + _spacing[i] = std::numeric_limits::quiet_NaN(); +} + +template +void +ControlPointGridToDeformationConverter:: +set_cpg_spacing(const float spacing[3]) +{ + for (unsigned i=0; i<3; ++i) + _spacing[i] = spacing[i]; +} + +template +void +ControlPointGridToDeformationConverter:: +set_reference_image(const NiftiImageData &ref) +{ + _template_ref_sptr = ref.clone(); +} + +template +NiftiImageData3DDeformation +ControlPointGridToDeformationConverter:: +forward(const NiftiImageData3DBSpline &cpg) +{ + check_is_set_up(); + return cpg.get_as_deformation_field(*_template_ref_sptr); +} + +template +NiftiImageData3DBSpline +ControlPointGridToDeformationConverter:: +backward(const NiftiImageData3DDeformation &dvf) +{ + check_is_set_up(); + return NiftiImageData3DBSpline(dvf); +} + +template +void ControlPointGridToDeformationConverter:: +check_is_set_up() const +{ + // Has spacing been set? + for (unsigned i=0; i<3; ++i) + if (std::isnan(_spacing[i])) + throw std::runtime_error("ControlPointGridToDeformationConverter: Set CPG spacing."); + + // Has template deformation been set? + if (!_template_ref_sptr) + throw std::runtime_error("ControlPointGridToDeformationConverter: Set template DVF."); +} + +namespace sirf { +template class ControlPointGridToDeformationConverter; +} diff --git a/src/Registration/cReg/include/sirf/Reg/ControlPointGridToDeformationConverter.h b/src/Registration/cReg/include/sirf/Reg/ControlPointGridToDeformationConverter.h new file mode 100644 index 000000000..729385175 --- /dev/null +++ b/src/Registration/cReg/include/sirf/Reg/ControlPointGridToDeformationConverter.h @@ -0,0 +1,76 @@ +/* +SyneRBI Synergistic Image Reconstruction Framework (SIRF) +Copyright 2020 University College London + +This is software developed for the Collaborative Computational +Project in Synergistic Reconstruction for Biomedical Imaging (formerly CCP PETMR) +(http://www.ccpsynerbi.ac.uk/). + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +*/ + +/*! +\file +\ingroup Registration +\brief Class for converting control point grids to deformation field transformations. + +\author Richard Brown +\author SyneRBI +*/ + +#pragma once + +#include + +namespace sirf { + +// Forward declarations +template class NiftiImageData; +template class NiftiImageData3DDeformation; +template class NiftiImageData3DBSpline; + +/*! +\ingroup Registration +\brief Class for converting control point grids to deformation field transformations. + +\author Richard Brown +\author SyneRBI +*/ +template +class ControlPointGridToDeformationConverter +{ +public: + + /// Constructor + ControlPointGridToDeformationConverter(); + + /// Set CPG spacing + void set_cpg_spacing(const float spacing[3]); + + /// Set reference image for generating dvfs + void set_reference_image(const NiftiImageData &ref); + + /// CPG to DVF + NiftiImageData3DDeformation forward(const NiftiImageData3DBSpline &cpg); + + /// DVF to CPG + NiftiImageData3DBSpline backward(const NiftiImageData3DDeformation &dvf); + +private: + + /// Check is set up + void check_is_set_up() const; + + float _spacing[3]; + std::shared_ptr > _template_ref_sptr; +}; +} diff --git a/src/Registration/cReg/tests/test_cReg.cpp b/src/Registration/cReg/tests/test_cReg.cpp index 3715b52f5..aa02633c0 100644 --- a/src/Registration/cReg/tests/test_cReg.cpp +++ b/src/Registration/cReg/tests/test_cReg.cpp @@ -38,6 +38,7 @@ limitations under the License. #include "sirf/Reg/AffineTransformation.h" #include "sirf/Reg/Quaternion.h" #include "sirf/Reg/NiftiImageData3DBSpline.h" +#include "sirf/Reg/ControlPointGridToDeformationConverter.h" #include #include #ifdef SIRF_SPM @@ -1172,6 +1173,17 @@ int main(int argc, char* argv[]) if (*dvf_sptr != dvf_to_cpg_to_dvf) throw std::runtime_error("DVF->CPG->DVF != DVF."); + // Do the same, using the converter + ControlPointGridToDeformationConverter cpg_2_dvf_converter; + cpg_2_dvf_converter.set_cpg_spacing(spacing); + cpg_2_dvf_converter.set_reference_image(*dvf_sptr->get_tensor_component(0)); + auto dvf_to_cpg_w_converter = cpg_2_dvf_converter.backward(*dvf_sptr); + auto dvf_to_cpg_to_dvf_w_converter = cpg_2_dvf_converter.forward(dvf_to_cpg_w_converter); + + // Compare + if (dvf_to_cpg_to_dvf != dvf_to_cpg_to_dvf_w_converter) + throw std::runtime_error("ControlPointGridToDeformationConverter DVF->CPG->DVF failed."); + std::cout << "// ----------------------------------------------------------------------- //\n"; std::cout << "// Finished CGP<->DVF test.\n"; std::cout << "//------------------------------------------------------------------------ //\n"; From c38877bfb59874cfabd79709b92aa9497e867915 Mon Sep 17 00:00:00 2001 From: Richard Date: Thu, 2 Jul 2020 10:38:54 +0000 Subject: [PATCH 05/42] correct case --- src/Registration/cReg/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Registration/cReg/CMakeLists.txt b/src/Registration/cReg/CMakeLists.txt index 8ad57bf57..d8d4069ca 100644 --- a/src/Registration/cReg/CMakeLists.txt +++ b/src/Registration/cReg/CMakeLists.txt @@ -44,7 +44,7 @@ SET(SOURCES "NiftiImageData3DTensor.cpp" "NiftiImageData3DDeformation.cpp" "NiftiImageData3DDisplacement.cpp" - "NiftiImageData3DBSPline.cpp" + "NiftiImageData3DBSpline.cpp" "ControlPointGridToDeformationConverter.cpp" ) From 52d3e5bc737366d19fbf3fe113a22ed9de9b7d4b Mon Sep 17 00:00:00 2001 From: Richard Date: Thu, 2 Jul 2020 10:41:47 +0000 Subject: [PATCH 06/42] converter works --- .../cReg/ControlPointGridToDeformationConverter.cpp | 2 +- src/Registration/cReg/NiftiImageData.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Registration/cReg/ControlPointGridToDeformationConverter.cpp b/src/Registration/cReg/ControlPointGridToDeformationConverter.cpp index 97532c92c..024d2e762 100644 --- a/src/Registration/cReg/ControlPointGridToDeformationConverter.cpp +++ b/src/Registration/cReg/ControlPointGridToDeformationConverter.cpp @@ -73,7 +73,7 @@ ControlPointGridToDeformationConverter:: backward(const NiftiImageData3DDeformation &dvf) { check_is_set_up(); - return NiftiImageData3DBSpline(dvf); + return NiftiImageData3DBSpline(dvf, _spacing); } template diff --git a/src/Registration/cReg/NiftiImageData.cpp b/src/Registration/cReg/NiftiImageData.cpp index 76731e3e0..a5af16718 100644 --- a/src/Registration/cReg/NiftiImageData.cpp +++ b/src/Registration/cReg/NiftiImageData.cpp @@ -517,7 +517,7 @@ void NiftiImageData::check_dimensions(const NiftiImageDataType image_t else if (typeid(*this) == typeid(NiftiImageData3DTensor)) ss << "NiftiImageData3DTensor"; else if (typeid(*this) == typeid(NiftiImageData3DDisplacement)) ss << "NiftiImageData3DDisplacement"; else if (typeid(*this) == typeid(NiftiImageData3DDeformation)) ss << "NiftiImageData3DDeformation"; - else if (typeid(*this) == typeid(NiftiImageData3DBSpline)) ss << "NiftiImageData3DDeformation"; + else if (typeid(*this) == typeid(NiftiImageData3DBSpline)) ss << "NiftiImageData3DBSpline"; ss << ".\n\t\tExpected params: ndim = " << ndim << ", nu = " << nu << ", nt = " << nt; if (intent_code == NIFTI_INTENT_NONE) ss << ", intent_code = None"; else if (intent_code == NIFTI_INTENT_VECTOR) ss << ", intent_code = Vector"; From 0d2137ce2ca8f90388e2d6f1c4934ac5556c19ae Mon Sep 17 00:00:00 2001 From: richard Date: Thu, 2 Jul 2020 17:57:13 +0100 Subject: [PATCH 07/42] exported to python. tests pass --- .../cReg/NiftiImageData3DBSpline.cpp | 6 +- src/Registration/cReg/cReg.cpp | 99 +++++++++++++++++++ .../sirf/Reg/NiftiImageData3DBSpline.h | 2 +- src/Registration/cReg/include/sirf/Reg/cReg.h | 10 ++ src/Registration/pReg/Reg.py.in | 95 +++++++++++++++++- src/Registration/pReg/tests/test_pReg.py | 43 ++++++++ 6 files changed, 251 insertions(+), 4 deletions(-) diff --git a/src/Registration/cReg/NiftiImageData3DBSpline.cpp b/src/Registration/cReg/NiftiImageData3DBSpline.cpp index bc9b2069b..c43fb624c 100644 --- a/src/Registration/cReg/NiftiImageData3DBSpline.cpp +++ b/src/Registration/cReg/NiftiImageData3DBSpline.cpp @@ -34,13 +34,15 @@ limitations under the License. using namespace sirf; template -NiftiImageData3DBSpline::NiftiImageData3DBSpline(const NiftiImageData3DDeformation &def, float spacing[]) +NiftiImageData3DBSpline::NiftiImageData3DBSpline(const NiftiImageData3DDeformation &def, const float spacing[]) { + // not marked const, so copy + float spacing_nonconst[3] = {spacing[0], spacing[1], spacing[2]}; // Get any of the tensor components as a 3d image auto ref_sptr = def.get_tensor_component(0); nifti_image *ref_ptr = ref_sptr->get_raw_nifti_sptr().get(); // Create the NiftyMoMo bspline transformation class - NiftyMoMo::BSplineTransformation bspline(ref_ptr, 1, spacing); + NiftyMoMo::BSplineTransformation bspline(ref_ptr, 1, spacing_nonconst); // Get cpg_ptr nifti_image *cpg_ptr = bspline.GetTransformationAsImage(); // Convert DVF to CPG diff --git a/src/Registration/cReg/cReg.cpp b/src/Registration/cReg/cReg.cpp index a06fff828..ba9822161 100644 --- a/src/Registration/cReg/cReg.cpp +++ b/src/Registration/cReg/cReg.cpp @@ -26,6 +26,8 @@ limitations under the License. #include "sirf/Reg/NiftiImageData3DTensor.h" #include "sirf/Reg/NiftiImageData3DDisplacement.h" #include "sirf/Reg/NiftiImageData3DDeformation.h" +#include "sirf/Reg/NiftiImageData3DBSpline.h" +#include "sirf/Reg/ControlPointGridToDeformationConverter.h" #include "sirf/Reg/NiftyAladinSym.h" #include "sirf/Reg/NiftyF3dSym.h" #include "sirf/Reg/NiftyResample.h" @@ -68,6 +70,10 @@ void* cReg_newObject(const char* name) return newObjectHandle(std::shared_ptr >(new NiftiImageData3DDisplacement)); if (strcmp(name, "NiftiImageData3DDeformation") == 0) return newObjectHandle(std::shared_ptr >(new NiftiImageData3DDeformation)); + if (strcmp(name, "NiftiImageData3DBSpline") == 0) + return newObjectHandle(std::shared_ptr >(new NiftiImageData3DBSpline)); + if (strcmp(name, "ControlPointGridToDeformationConverter") == 0) + return newObjectHandle(std::shared_ptr >(new ControlPointGridToDeformationConverter)); if (strcmp(name, "NiftyAladinSym") == 0) return newObjectHandle(std::shared_ptr >(new NiftyAladinSym)); if (strcmp(name, "NiftyF3dSym") == 0) @@ -159,6 +165,11 @@ void* cReg_objectFromFile(const char* name, const char* filename) sptr(new NiftiImageData3DDeformation(filename)); return newObjectHandle(sptr); } + if (strcmp(name, "NiftiImageData3DBSpline") == 0) { + std::shared_ptr > + sptr(new NiftiImageData3DBSpline(filename)); + return newObjectHandle(sptr); + } if (strcmp(name, "AffineTransformation") == 0) { std::shared_ptr > sptr(new AffineTransformation(filename)); @@ -533,6 +544,12 @@ void* cReg_NiftiImageData3DTensor_construct_from_3_components(const char* obj, c sptr.reset(new NiftiImageData3DDisplacement(x,y,z)); else if (strcmp(obj,"NiftiImageData3DDeformation") == 0) sptr.reset(new NiftiImageData3DDeformation(x,y,z)); + else if (strcmp(obj,"NiftiImageData3DBSpline") == 0) + sptr.reset(new NiftiImageData3DBSpline(x,y,z)); + else + throw std::runtime_error( + "cReg_NiftiImageData3DTensor_construct_from_3_components, unknown type:" + + std::string(obj)); return newObjectHandle(sptr); } CATCH; @@ -547,6 +564,17 @@ void* cReg_NiftiImageData3DTensor_flip_component(const void *ptr, const int dim) } CATCH; } +extern "C" +void* cReg_NiftiImageData3DTensor_get_tensor_component(const void *ptr, const int dim) +{ + try { + NiftiImageData3DTensor& im = objectFromHandle >(ptr); + std::shared_ptr > im_sptr = im.get_tensor_component(dim); + std::shared_ptr > im3D_sptr = std::make_shared >(*im_sptr); + return newObjectHandle(im3D_sptr); + } + CATCH; +} // -------------------------------------------------------------------------------- // // NiftiImageData3DDeformation // -------------------------------------------------------------------------------- // @@ -610,6 +638,75 @@ void* cReg_NiftiImageData3DDisplacement_create_from_def(const void* def_ptr) CATCH; } +// -------------------------------------------------------------------------------- // +// NiftiImageData3DBSpline +// -------------------------------------------------------------------------------- // +extern "C" +void* cReg_NiftiImageData3DBSpline_create_from_def(const void* def_ptr, const float spacing_x, const float spacing_y, const float spacing_z) +{ + try { + NiftiImageData3DDeformation& def = objectFromHandle >(def_ptr); + const float spacing[3] = {spacing_x, spacing_y, spacing_z}; + return newObjectHandle(std::make_shared >(def, spacing)); + } + CATCH; +} + +// -------------------------------------------------------------------------------- // +// ControlPointGridToDeformationConverter +// -------------------------------------------------------------------------------- // +extern "C" +void* cReg_CPG2DVF_set_cpg_spacing(const void* converter_ptr, const float spacing_x, const float spacing_y, const float spacing_z) +{ + try { + ControlPointGridToDeformationConverter& cpg_2_dvf_converter = + objectFromHandle >(converter_ptr); + const float spacing[3] = {spacing_x, spacing_y, spacing_z}; + cpg_2_dvf_converter.set_cpg_spacing(spacing); + return new DataHandle; + } + CATCH; +} +extern "C" +void* cReg_CPG2DVF_set_ref_im(const void* converter_ptr, const void* ref_im_ptr) +{ + try { + ControlPointGridToDeformationConverter& cpg_2_dvf_converter = + objectFromHandle >(converter_ptr); + NiftiImageData& ref_im = + objectFromHandle >(ref_im_ptr); + cpg_2_dvf_converter.set_reference_image(ref_im); + return new DataHandle; + } + CATCH; +} +extern "C" +void* cReg_CPG2DVF_forward(const void* converter_ptr, const void* cpg_ptr) +{ + try { + ControlPointGridToDeformationConverter& cpg_2_dvf_converter = + objectFromHandle >(converter_ptr); + NiftiImageData3DBSpline& cpg = + objectFromHandle >(cpg_ptr); + NiftiImageData3DDeformation def = cpg_2_dvf_converter.forward(cpg); + return newObjectHandle(std::make_shared >(def)); + } + CATCH; +} +extern "C" +void* cReg_CPG2DVF_backward(const void* converter_ptr, const void* dvf_ptr) +{ + try { + ControlPointGridToDeformationConverter& cpg_2_dvf_converter = + objectFromHandle >(converter_ptr); + NiftiImageData3DDeformation& dvf = + objectFromHandle >(dvf_ptr); + NiftiImageData3DBSpline cpg = cpg_2_dvf_converter.backward(dvf); + return newObjectHandle(std::make_shared >(cpg)); + } + CATCH; +} + // -------------------------------------------------------------------------------- // // Registration // -------------------------------------------------------------------------------- // @@ -903,6 +1000,8 @@ void* cReg_Transformation_get_as_deformation_field(const void* ptr, const char* trans = &objectFromHandle >(ptr); else if (strcmp(name,"NiftiImageData3DDeformation") == 0) trans = &objectFromHandle >(ptr); + else if (strcmp(name,"NiftiImageData3DBSpline") == 0) + trans = &objectFromHandle >(ptr); else throw std::runtime_error("cReg_Transformation_get_as_deformation_field: type should be affine, disp or def."); diff --git a/src/Registration/cReg/include/sirf/Reg/NiftiImageData3DBSpline.h b/src/Registration/cReg/include/sirf/Reg/NiftiImageData3DBSpline.h index 6e6c85728..f41e0f978 100644 --- a/src/Registration/cReg/include/sirf/Reg/NiftiImageData3DBSpline.h +++ b/src/Registration/cReg/include/sirf/Reg/NiftiImageData3DBSpline.h @@ -70,7 +70,7 @@ class NiftiImageData3DBSpline : public NiftiImageData3DTensor, public : NiftiImageData3DTensor(x,y,z) { this->check_dimensions(this->_3DBSpl); } /// Create from deformation field image - NiftiImageData3DBSpline(const NiftiImageData3DDeformation &def, float spacing[]); + NiftiImageData3DBSpline(const NiftiImageData3DDeformation &def, const float spacing[]); /// Create from 3D image void create_from_3D_image(const NiftiImageData &image); diff --git a/src/Registration/cReg/include/sirf/Reg/cReg.h b/src/Registration/cReg/include/sirf/Reg/cReg.h index 2172ff5d0..931a7ff0f 100644 --- a/src/Registration/cReg/include/sirf/Reg/cReg.h +++ b/src/Registration/cReg/include/sirf/Reg/cReg.h @@ -71,6 +71,7 @@ extern "C" { void* cReg_NiftiImageData3DTensor_create_from_3D_image(const void *ptr, const void* obj); void* cReg_NiftiImageData3DTensor_construct_from_3_components(const char* obj, const void *x_ptr, const void *y_ptr, const void *z_ptr); void* cReg_NiftiImageData3DTensor_flip_component(const void *ptr, const int dim); + void* cReg_NiftiImageData3DTensor_get_tensor_component(const void *ptr, const int dim); // NiftiImageData3DDeformation void* cReg_NiftiImageData3DDeformation_compose_single_deformation(const void* im, const char* types, const void* trans_vector_ptr); @@ -80,6 +81,15 @@ extern "C" { // NiftiImageData3DDisplacement void* cReg_NiftiImageData3DDisplacement_create_from_def(const void* def_ptr); + // NiftiImageData3DBSpline + void* cReg_NiftiImageData3DBSpline_create_from_def(const void* def_ptr, const float spacing_x, const float spacing_y, const float spacing_z); + + // ControlPointGridToDeformationConverter + void* cReg_CPG2DVF_set_cpg_spacing(const void* converter_ptr, const float spacing_x, const float spacing_y, const float spacing_z); + void* cReg_CPG2DVF_set_ref_im(const void* converter_ptr, const void* ref_im_ptr); + void* cReg_CPG2DVF_forward(const void* converter_ptr, const void* cpg_ptr); + void* cReg_CPG2DVF_backward(const void* converter_ptr, const void* dvf_ptr); + // Registration void* cReg_Registration_process(void* ptr); void* cReg_Registration_get_deformation_displacement_image(const void* ptr, const char *transform_type, const int idx); diff --git a/src/Registration/pReg/Reg.py.in b/src/Registration/pReg/Reg.py.in index 14d9dc65a..b840ed0e9 100644 --- a/src/Registration/pReg/Reg.py.in +++ b/src/Registration/pReg/Reg.py.in @@ -22,7 +22,8 @@ Object-Oriented wrap for the cReg-to-Python interface pyreg.py import abc import sys -from pUtilities import * +from sirf.Utilities import assert_validity, \ + check_status, try_calling, inspect from sirf import SIRF import pyiutilities as pyiutil import pyreg @@ -519,6 +520,17 @@ class NiftiImageData3DTensor(NiftiImageData): try_calling(pyreg.cReg_NiftiImageData3DTensor_flip_component(self.handle, dim)) check_status(self.handle) + def get_tensor_component(self, dim): + """Get tensor component (i.e., nu=3 -> nu=1).""" + if 0 < dim or dim > 2: + raise AssertionError( + "Tensor component to extract should be between 0 and 2.") + output = NiftiImageData3D() + output.handle = pyreg.cReg_NiftiImageData3DTensor_get_tensor_component( + self.handle, dim) + check_status(output.handle) + return output + class NiftiImageData3DDisplacement(NiftiImageData3DTensor, _Transformation): """ @@ -627,6 +639,87 @@ class NiftiImageData3DDeformation(NiftiImageData3DTensor, _Transformation): return z +class NiftiImageData3DBSpline(NiftiImageData3DTensor, _Transformation): + """ + Class for 3D b-spline nifti image data. + """ + + def __init__(self, src1=None, src2=None, src3=None): + self.handle = None + self.name = 'NiftiImageData3DBSpline' + if src1 is None: + self.handle = pyreg.cReg_newObject(self.name) + # filename + elif isinstance(src1, str): + self.handle = pyreg.cReg_objectFromFile(self.name, src1) + # 3 x scalar images + elif isinstance(src1, NiftiImageData3D) and \ + isinstance(src2, NiftiImageData3D) and \ + isinstance(src3, NiftiImageData3D): + self.handle = pyreg.\ + cReg_NiftiImageData3DTensor_construct_from_3_components( + self.name, src1.handle, src2.handle, src3.handle) + # from deformation + elif isinstance(src1, NiftiImageData3DDeformation) and len(src2) == 3: + spacing = src2 + self.handle = pyreg.\ + cReg_NiftiImageData3DBSpline_create_from_def(src1.handle, + float(spacing[0]), + float(spacing[1]), + float(spacing[2])) + else: + raise error('Wrong source in NiftiImageData3DBSpline constructor') + check_status(self.handle) + + def __del__(self): + if self.handle is not None: + pyiutil.deleteDataHandle(self.handle) + + +class ControlPointGridToDeformationConverter(object): + """ + Class for converting from control points grids to deformations and vice + versa. + """ + def __init__(self): + self.handle = None + self.name = 'ControlPointGridToDeformationConverter' + self.handle = pyreg.cReg_newObject(self.name) + check_status(self.handle) + + def __del__(self): + if self.handle is not None: + pyiutil.deleteDataHandle(self.handle) + + def set_cpg_spacing(self, spacing): + """Set CPG spacing.""" + if len(spacing) != 3: + raise AssertionError("Spacing should be array of 3 numbers.") + try_calling(pyreg.cReg_CPG2DVF_set_cpg_spacing(self.handle, + float(spacing[0]), float(spacing[1]), float(spacing[2]))) + + def set_reference_image(self, ref_im): + """Set reference image for generating dvfs.""" + assert_validity(ref_im, NiftiImageData3D) + try_calling(pyreg.cReg_CPG2DVF_set_ref_im(self.handle, ref_im.handle)) + + def forward(self, cpg): + """CPG to DVF.""" + assert_validity(cpg, NiftiImageData3DBSpline) + output = NiftiImageData3DDeformation() + output.handle = pyreg.cReg_CPG2DVF_forward(self.handle, cpg.handle) + check_status(output.handle) + return output + + def backward(self, dvf): + """DVF to CPG""" + assert_validity(dvf, NiftiImageData3DDeformation) + output = NiftiImageData3DBSpline() + output.handle = pyreg.cReg_CPG2DVF_backward(self.handle, dvf.handle) + check_status(output.handle) + return output + + class _Registration(ABC): """ Abstract base class for registration. diff --git a/src/Registration/pReg/tests/test_pReg.py b/src/Registration/pReg/tests/test_pReg.py index 8c954fa03..f5ba370a1 100644 --- a/src/Registration/pReg/tests/test_pReg.py +++ b/src/Registration/pReg/tests/test_pReg.py @@ -1063,6 +1063,48 @@ def try_weighted_mean(na): time.sleep(0.5) +# CGP<->DVF conversion +def try_cgp_dvf_conversion(na): + time.sleep(0.5) + sys.stderr.write('\n# --------------------------------------------------------------------------------- #\n') + sys.stderr.write('# Starting CGP<->DVF test...\n') + sys.stderr.write('# --------------------------------------------------------------------------------- #\n') + time.sleep(0.5) + + dvf = na.get_deformation_field_forward() + + # DVF->CPG + spacing = dvf.get_voxel_sizes()[1:4] * 2.0 + dvf_to_cpg = sirf.Reg.NiftiImageData3DBSpline(dvf, spacing) + + if abs(dvf_to_cpg.get_max()) < 1.e-4 or abs(dvf_to_cpg.get_min()) < 1.e-4: + raise AssertionError("NiftiImageData3DBSpline::NiftiImageData3DBSpline(DVF): contains only zeroes.") + + # DVF->CPG->DVF + dvf_to_cpg_to_dvf = dvf_to_cpg.get_as_deformation_field(dvf.get_tensor_component(0)) + + # Compare + if dvf != dvf_to_cpg_to_dvf: + raise AssertionError("DVF->CPG->DVF != DVF.") + + # Do the same, using the converter + cpg_2_dvf_converter = sirf.Reg.ControlPointGridToDeformationConverter() + cpg_2_dvf_converter.set_cpg_spacing(spacing) + cpg_2_dvf_converter.set_reference_image(dvf.get_tensor_component(0)) + dvf_to_cpg_w_converter = cpg_2_dvf_converter.backward(dvf) + dvf_to_cpg_to_dvf_w_converter = cpg_2_dvf_converter.forward(dvf_to_cpg_w_converter) + + # Compare + if dvf_to_cpg_to_dvf != dvf_to_cpg_to_dvf_w_converter: + raise AssertionError("ControlPointGridToDeformationConverter DVF->CPG->DVF failed.") + + time.sleep(0.5) + sys.stderr.write('\n# --------------------------------------------------------------------------------- #\n') + sys.stderr.write('# Finished CGP<->DVF test.\n') + sys.stderr.write('# --------------------------------------------------------------------------------- #\n') + time.sleep(0.5) + + # AffineTransformation def try_affinetransformation(na): time.sleep(0.5) @@ -1221,6 +1263,7 @@ def test(): try_resample(na) try_niftymomo(na) try_weighted_mean(na) + try_cgp_dvf_conversion(na) try_affinetransformation(na) try_quaternion() From d97ec68558d30dcecdc9191f35e379ba6c55dc6f Mon Sep 17 00:00:00 2001 From: richard Date: Thu, 2 Jul 2020 19:19:48 +0100 Subject: [PATCH 08/42] add adjoint test --- .gitignore | 2 ++ src/Registration/pReg/Reg.py.in | 33 ++++++++++++++++++++++++ src/Registration/pReg/tests/test_pReg.py | 5 ++++ 3 files changed, 40 insertions(+) diff --git a/.gitignore b/.gitignore index 6607b09d6..19cd698f2 100644 --- a/.gitignore +++ b/.gitignore @@ -281,3 +281,5 @@ coverage.xml # Ignore files created by VS Code .vscode/ + +results/ diff --git a/src/Registration/pReg/Reg.py.in b/src/Registration/pReg/Reg.py.in index b840ed0e9..2e62ad59a 100644 --- a/src/Registration/pReg/Reg.py.in +++ b/src/Registration/pReg/Reg.py.in @@ -293,6 +293,10 @@ class NiftiImageData(SIRF.ImageData): image = NiftiImageData3DDeformation() elif self.name == 'NiftiImageData3DDisplacement': image = NiftiImageData3DDisplacement() + elif self.name == 'NiftiImageData3DBSpline': + image = NiftiImageData3DBSpline() + else: + raise error("unknown object name: " + self.name) try_calling(pyreg.cReg_NiftiImageData_deep_copy(image.handle, self.handle)) return image @@ -683,6 +687,8 @@ class ControlPointGridToDeformationConverter(object): """ def __init__(self): self.handle = None + self.dvf_template = None # only used for testing + self.cpg_template = None # only used for testing self.name = 'ControlPointGridToDeformationConverter' self.handle = pyreg.cReg_newObject(self.name) check_status(self.handle) @@ -719,6 +725,33 @@ class ControlPointGridToDeformationConverter(object): check_status(output.handle) return output + def _set_up_for_adjoint_test(self, dvf_template, cpg_template): + """Set template dvf and cpg to be used for testing.""" + assert_validity(dvf_template, NiftiImageData3DDeformation) + assert_validity(cpg_template, NiftiImageData3DBSpline) + self.dvf_template = dvf_template + self.cpg_template = cpg_template + + def direct(self, cpg): + """Alias of forward.""" + return self.forward(cpg) + + def adjoint(self, dvf): + """Alias of backward.""" + return self.backward(dvf) + + def is_linear(self): + """Returns whether the transformation is linear""" + return True + + def domain_geometry(self): + """Get domain geometry (only used for testing).""" + return self.cpg_template + + def range_geometry(self): + """Get range geometry (only used for testing).""" + return self.dvf_template + class _Registration(ABC): """ diff --git a/src/Registration/pReg/tests/test_pReg.py b/src/Registration/pReg/tests/test_pReg.py index f5ba370a1..d64b20810 100644 --- a/src/Registration/pReg/tests/test_pReg.py +++ b/src/Registration/pReg/tests/test_pReg.py @@ -1098,6 +1098,11 @@ def try_cgp_dvf_conversion(na): if dvf_to_cpg_to_dvf != dvf_to_cpg_to_dvf_w_converter: raise AssertionError("ControlPointGridToDeformationConverter DVF->CPG->DVF failed.") + # Check the adjoint is truly the adjoint with: | - | / 0.5*(||+||) < epsilon + cpg_2_dvf_converter._set_up_for_adjoint_test(dvf, dvf_to_cpg) + if not is_operator_adjoint(cpg_2_dvf_converter): + raise AssertionError("ControlPointGridToDeformationConverter::adjoint() failed") + time.sleep(0.5) sys.stderr.write('\n# --------------------------------------------------------------------------------- #\n') sys.stderr.write('# Finished CGP<->DVF test.\n') From 8e328dc993141e084aef550a99d522645164d68d Mon Sep 17 00:00:00 2001 From: richard Date: Fri, 3 Jul 2020 12:09:32 +0100 Subject: [PATCH 09/42] remove bspline from deformation constructor --- ...ControlPointGridToDeformationConverter.cpp | 15 ++++++++++++++- .../cReg/NiftiImageData3DBSpline.cpp | 19 ------------------- src/Registration/cReg/cReg.cpp | 14 -------------- .../sirf/Reg/NiftiImageData3DBSpline.h | 3 --- src/Registration/cReg/include/sirf/Reg/cReg.h | 3 --- src/Registration/cReg/tests/test_cReg.cpp | 19 ++++--------------- src/Registration/pReg/Reg.py.in | 8 -------- src/Registration/pReg/tests/test_pReg.py | 19 ++++--------------- 8 files changed, 22 insertions(+), 78 deletions(-) diff --git a/src/Registration/cReg/ControlPointGridToDeformationConverter.cpp b/src/Registration/cReg/ControlPointGridToDeformationConverter.cpp index 024d2e762..d66ce4840 100644 --- a/src/Registration/cReg/ControlPointGridToDeformationConverter.cpp +++ b/src/Registration/cReg/ControlPointGridToDeformationConverter.cpp @@ -30,6 +30,7 @@ limitations under the License. #include "sirf/Reg/ControlPointGridToDeformationConverter.h" #include "sirf/Reg/NiftiImageData3DDeformation.h" #include "sirf/Reg/NiftiImageData3DBSpline.h" +#include "sirf/NiftyMoMo/BSplineTransformation.h" using namespace sirf; @@ -73,7 +74,19 @@ ControlPointGridToDeformationConverter:: backward(const NiftiImageData3DDeformation &dvf) { check_is_set_up(); - return NiftiImageData3DBSpline(dvf, _spacing); + // not marked const, so copy + float spacing_nonconst[3] = {_spacing[0], _spacing[1], _spacing[2]}; + // Get any of the tensor components as a 3d image + auto ref_sptr = dvf.get_tensor_component(0); + nifti_image *ref_ptr = ref_sptr->get_raw_nifti_sptr().get(); + // Create the NiftyMoMo bspline transformation class + NiftyMoMo::BSplineTransformation bspline(ref_ptr, 1, spacing_nonconst); + // Get cpg_ptr + nifti_image *cpg_ptr = bspline.GetTransformationAsImage(); + // Convert DVF to CPG + cpg_ptr->data = bspline.GetDVFGradientWRTTransformationParameters(dvf.clone()->get_raw_nifti_sptr().get(), ref_ptr); + cpg_ptr->intent_p1 = SPLINE_VEL_GRID; + return NiftiImageData3DBSpline(*cpg_ptr); } template diff --git a/src/Registration/cReg/NiftiImageData3DBSpline.cpp b/src/Registration/cReg/NiftiImageData3DBSpline.cpp index c43fb624c..7091149f3 100644 --- a/src/Registration/cReg/NiftiImageData3DBSpline.cpp +++ b/src/Registration/cReg/NiftiImageData3DBSpline.cpp @@ -33,25 +33,6 @@ limitations under the License. using namespace sirf; -template -NiftiImageData3DBSpline::NiftiImageData3DBSpline(const NiftiImageData3DDeformation &def, const float spacing[]) -{ - // not marked const, so copy - float spacing_nonconst[3] = {spacing[0], spacing[1], spacing[2]}; - // Get any of the tensor components as a 3d image - auto ref_sptr = def.get_tensor_component(0); - nifti_image *ref_ptr = ref_sptr->get_raw_nifti_sptr().get(); - // Create the NiftyMoMo bspline transformation class - NiftyMoMo::BSplineTransformation bspline(ref_ptr, 1, spacing_nonconst); - // Get cpg_ptr - nifti_image *cpg_ptr = bspline.GetTransformationAsImage(); - // Convert DVF to CPG - cpg_ptr->data = bspline.GetDVFGradientWRTTransformationParameters(def.clone()->get_raw_nifti_sptr().get(), ref_ptr); - cpg_ptr->intent_p1 = SPLINE_VEL_GRID; - *this = NiftiImageData3DBSpline(*cpg_ptr); - this->check_dimensions(NiftiImageData::_3DBSpl); -} - template void NiftiImageData3DBSpline::create_from_3D_image(const NiftiImageData &image) { diff --git a/src/Registration/cReg/cReg.cpp b/src/Registration/cReg/cReg.cpp index ba9822161..ef696c92a 100644 --- a/src/Registration/cReg/cReg.cpp +++ b/src/Registration/cReg/cReg.cpp @@ -638,20 +638,6 @@ void* cReg_NiftiImageData3DDisplacement_create_from_def(const void* def_ptr) CATCH; } -// -------------------------------------------------------------------------------- // -// NiftiImageData3DBSpline -// -------------------------------------------------------------------------------- // -extern "C" -void* cReg_NiftiImageData3DBSpline_create_from_def(const void* def_ptr, const float spacing_x, const float spacing_y, const float spacing_z) -{ - try { - NiftiImageData3DDeformation& def = objectFromHandle >(def_ptr); - const float spacing[3] = {spacing_x, spacing_y, spacing_z}; - return newObjectHandle(std::make_shared >(def, spacing)); - } - CATCH; -} - // -------------------------------------------------------------------------------- // // ControlPointGridToDeformationConverter // -------------------------------------------------------------------------------- // diff --git a/src/Registration/cReg/include/sirf/Reg/NiftiImageData3DBSpline.h b/src/Registration/cReg/include/sirf/Reg/NiftiImageData3DBSpline.h index f41e0f978..46f71a128 100644 --- a/src/Registration/cReg/include/sirf/Reg/NiftiImageData3DBSpline.h +++ b/src/Registration/cReg/include/sirf/Reg/NiftiImageData3DBSpline.h @@ -69,9 +69,6 @@ class NiftiImageData3DBSpline : public NiftiImageData3DTensor, public NiftiImageData3DBSpline(const NiftiImageData3D &x, const NiftiImageData3D &y, const NiftiImageData3D &z) : NiftiImageData3DTensor(x,y,z) { this->check_dimensions(this->_3DBSpl); } - /// Create from deformation field image - NiftiImageData3DBSpline(const NiftiImageData3DDeformation &def, const float spacing[]); - /// Create from 3D image void create_from_3D_image(const NiftiImageData &image); diff --git a/src/Registration/cReg/include/sirf/Reg/cReg.h b/src/Registration/cReg/include/sirf/Reg/cReg.h index 931a7ff0f..010468f5b 100644 --- a/src/Registration/cReg/include/sirf/Reg/cReg.h +++ b/src/Registration/cReg/include/sirf/Reg/cReg.h @@ -81,9 +81,6 @@ extern "C" { // NiftiImageData3DDisplacement void* cReg_NiftiImageData3DDisplacement_create_from_def(const void* def_ptr); - // NiftiImageData3DBSpline - void* cReg_NiftiImageData3DBSpline_create_from_def(const void* def_ptr, const float spacing_x, const float spacing_y, const float spacing_z); - // ControlPointGridToDeformationConverter void* cReg_CPG2DVF_set_cpg_spacing(const void* converter_ptr, const float spacing_x, const float spacing_y, const float spacing_z); void* cReg_CPG2DVF_set_ref_im(const void* converter_ptr, const void* ref_im_ptr); diff --git a/src/Registration/cReg/tests/test_cReg.cpp b/src/Registration/cReg/tests/test_cReg.cpp index aa02633c0..c2de2710d 100644 --- a/src/Registration/cReg/tests/test_cReg.cpp +++ b/src/Registration/cReg/tests/test_cReg.cpp @@ -1159,29 +1159,18 @@ int main(int argc, char* argv[]) float spacing[3]; for (unsigned i=0; i<3; ++i) spacing[i] = dvf_sptr->get_raw_nifti_sptr()->pixdim[i+1] * 2.f; - NiftiImageData3DBSpline dvf_to_cpg(*dvf_sptr, spacing); - if (std::abs(dvf_to_cpg.get_max()) < 1.e-4f || std::abs(dvf_to_cpg.get_min()) < 1.e-4f) - throw std::runtime_error("NiftiImageData3DBSpline::NiftiImageData3DBSpline(DVF): contains only zeroes."); - // DVF->CPG->DVF - auto dvf_to_cpg_to_dvf = dvf_to_cpg.get_as_deformation_field(*dvf_sptr->get_tensor_component(0)); - - NiftiImageData::print_headers({ref_aladin.get(), dvf_sptr.get(), - &dvf_to_cpg, &dvf_to_cpg_to_dvf}); - - // Compare - if (*dvf_sptr != dvf_to_cpg_to_dvf) - throw std::runtime_error("DVF->CPG->DVF != DVF."); - - // Do the same, using the converter + // DVF->CPG with converter ControlPointGridToDeformationConverter cpg_2_dvf_converter; cpg_2_dvf_converter.set_cpg_spacing(spacing); cpg_2_dvf_converter.set_reference_image(*dvf_sptr->get_tensor_component(0)); + // DVF->CPG auto dvf_to_cpg_w_converter = cpg_2_dvf_converter.backward(*dvf_sptr); + // DVF->CPG->DVF auto dvf_to_cpg_to_dvf_w_converter = cpg_2_dvf_converter.forward(dvf_to_cpg_w_converter); // Compare - if (dvf_to_cpg_to_dvf != dvf_to_cpg_to_dvf_w_converter) + if (*dvf_sptr != dvf_to_cpg_to_dvf_w_converter) throw std::runtime_error("ControlPointGridToDeformationConverter DVF->CPG->DVF failed."); std::cout << "// ----------------------------------------------------------------------- //\n"; diff --git a/src/Registration/pReg/Reg.py.in b/src/Registration/pReg/Reg.py.in index 2e62ad59a..16d32be53 100644 --- a/src/Registration/pReg/Reg.py.in +++ b/src/Registration/pReg/Reg.py.in @@ -663,14 +663,6 @@ class NiftiImageData3DBSpline(NiftiImageData3DTensor, _Transformation): self.handle = pyreg.\ cReg_NiftiImageData3DTensor_construct_from_3_components( self.name, src1.handle, src2.handle, src3.handle) - # from deformation - elif isinstance(src1, NiftiImageData3DDeformation) and len(src2) == 3: - spacing = src2 - self.handle = pyreg.\ - cReg_NiftiImageData3DBSpline_create_from_def(src1.handle, - float(spacing[0]), - float(spacing[1]), - float(spacing[2])) else: raise error('Wrong source in NiftiImageData3DBSpline constructor') check_status(self.handle) diff --git a/src/Registration/pReg/tests/test_pReg.py b/src/Registration/pReg/tests/test_pReg.py index d64b20810..ee7833c48 100644 --- a/src/Registration/pReg/tests/test_pReg.py +++ b/src/Registration/pReg/tests/test_pReg.py @@ -1072,30 +1072,19 @@ def try_cgp_dvf_conversion(na): time.sleep(0.5) dvf = na.get_deformation_field_forward() - - # DVF->CPG spacing = dvf.get_voxel_sizes()[1:4] * 2.0 - dvf_to_cpg = sirf.Reg.NiftiImageData3DBSpline(dvf, spacing) - - if abs(dvf_to_cpg.get_max()) < 1.e-4 or abs(dvf_to_cpg.get_min()) < 1.e-4: - raise AssertionError("NiftiImageData3DBSpline::NiftiImageData3DBSpline(DVF): contains only zeroes.") - - # DVF->CPG->DVF - dvf_to_cpg_to_dvf = dvf_to_cpg.get_as_deformation_field(dvf.get_tensor_component(0)) - # Compare - if dvf != dvf_to_cpg_to_dvf: - raise AssertionError("DVF->CPG->DVF != DVF.") - - # Do the same, using the converter + # DVF->CPG with converter cpg_2_dvf_converter = sirf.Reg.ControlPointGridToDeformationConverter() cpg_2_dvf_converter.set_cpg_spacing(spacing) cpg_2_dvf_converter.set_reference_image(dvf.get_tensor_component(0)) + # DVF->CPG dvf_to_cpg_w_converter = cpg_2_dvf_converter.backward(dvf) + # DVF->CPG->DVF dvf_to_cpg_to_dvf_w_converter = cpg_2_dvf_converter.forward(dvf_to_cpg_w_converter) # Compare - if dvf_to_cpg_to_dvf != dvf_to_cpg_to_dvf_w_converter: + if dvf != dvf_to_cpg_to_dvf_w_converter: raise AssertionError("ControlPointGridToDeformationConverter DVF->CPG->DVF failed.") # Check the adjoint is truly the adjoint with: | - | / 0.5*(||+||) < epsilon From e3763671ed522bda2ef5a5dc0344c625773b739d Mon Sep 17 00:00:00 2001 From: richard Date: Fri, 3 Jul 2020 15:55:51 +0100 Subject: [PATCH 10/42] updated --- ...ControlPointGridToDeformationConverter.cpp | 3 +-- .../ControlPointGridToDeformationConverter.h | 2 +- src/Registration/cReg/tests/test_cReg.cpp | 14 +++++++++---- src/Registration/pReg/tests/test_pReg.py | 21 ++++++++++++++++--- 4 files changed, 30 insertions(+), 10 deletions(-) diff --git a/src/Registration/cReg/ControlPointGridToDeformationConverter.cpp b/src/Registration/cReg/ControlPointGridToDeformationConverter.cpp index d66ce4840..469bac122 100644 --- a/src/Registration/cReg/ControlPointGridToDeformationConverter.cpp +++ b/src/Registration/cReg/ControlPointGridToDeformationConverter.cpp @@ -77,8 +77,7 @@ backward(const NiftiImageData3DDeformation &dvf) // not marked const, so copy float spacing_nonconst[3] = {_spacing[0], _spacing[1], _spacing[2]}; // Get any of the tensor components as a 3d image - auto ref_sptr = dvf.get_tensor_component(0); - nifti_image *ref_ptr = ref_sptr->get_raw_nifti_sptr().get(); + nifti_image *ref_ptr = _template_ref_sptr->get_raw_nifti_sptr().get(); // Create the NiftyMoMo bspline transformation class NiftyMoMo::BSplineTransformation bspline(ref_ptr, 1, spacing_nonconst); // Get cpg_ptr diff --git a/src/Registration/cReg/include/sirf/Reg/ControlPointGridToDeformationConverter.h b/src/Registration/cReg/include/sirf/Reg/ControlPointGridToDeformationConverter.h index 729385175..5c79182fd 100644 --- a/src/Registration/cReg/include/sirf/Reg/ControlPointGridToDeformationConverter.h +++ b/src/Registration/cReg/include/sirf/Reg/ControlPointGridToDeformationConverter.h @@ -71,6 +71,6 @@ class ControlPointGridToDeformationConverter void check_is_set_up() const; float _spacing[3]; - std::shared_ptr > _template_ref_sptr; + std::shared_ptr > _template_ref_sptr; }; } diff --git a/src/Registration/cReg/tests/test_cReg.cpp b/src/Registration/cReg/tests/test_cReg.cpp index c2de2710d..cd61a4c12 100644 --- a/src/Registration/cReg/tests/test_cReg.cpp +++ b/src/Registration/cReg/tests/test_cReg.cpp @@ -1163,14 +1163,20 @@ int main(int argc, char* argv[]) // DVF->CPG with converter ControlPointGridToDeformationConverter cpg_2_dvf_converter; cpg_2_dvf_converter.set_cpg_spacing(spacing); - cpg_2_dvf_converter.set_reference_image(*dvf_sptr->get_tensor_component(0)); + cpg_2_dvf_converter.set_reference_image(*ref_sptr); // DVF->CPG - auto dvf_to_cpg_w_converter = cpg_2_dvf_converter.backward(*dvf_sptr); + auto dvf_to_cpg = cpg_2_dvf_converter.backward(*dvf_sptr); + + // Check CPG contains non-zeroes + if (std::abs(dvf_to_cpg.get_max()) < 1.e-4f || std::abs(dvf_to_cpg.get_min()) < 1.e-4f) + throw std::runtime_error("NiftiImageData3DBSpline::NiftiImageData3DBSpline(DVF): contains only zeroes."); + // DVF->CPG->DVF - auto dvf_to_cpg_to_dvf_w_converter = cpg_2_dvf_converter.forward(dvf_to_cpg_w_converter); + auto dvf_to_cpg_to_dvf = cpg_2_dvf_converter.forward(dvf_to_cpg); + NiftiImageData::print_headers({dvf_sptr.get(), &dvf_to_cpg, &dvf_to_cpg_to_dvf}); // Compare - if (*dvf_sptr != dvf_to_cpg_to_dvf_w_converter) + if (*dvf_sptr != dvf_to_cpg_to_dvf) throw std::runtime_error("ControlPointGridToDeformationConverter DVF->CPG->DVF failed."); std::cout << "// ----------------------------------------------------------------------- //\n"; diff --git a/src/Registration/pReg/tests/test_pReg.py b/src/Registration/pReg/tests/test_pReg.py index ee7833c48..7e46f6483 100644 --- a/src/Registration/pReg/tests/test_pReg.py +++ b/src/Registration/pReg/tests/test_pReg.py @@ -1079,12 +1079,12 @@ def try_cgp_dvf_conversion(na): cpg_2_dvf_converter.set_cpg_spacing(spacing) cpg_2_dvf_converter.set_reference_image(dvf.get_tensor_component(0)) # DVF->CPG - dvf_to_cpg_w_converter = cpg_2_dvf_converter.backward(dvf) + dvf_to_cpg = cpg_2_dvf_converter.backward(dvf) # DVF->CPG->DVF - dvf_to_cpg_to_dvf_w_converter = cpg_2_dvf_converter.forward(dvf_to_cpg_w_converter) + dvf_to_cpg_to_dvf = cpg_2_dvf_converter.forward(dvf_to_cpg) # Compare - if dvf != dvf_to_cpg_to_dvf_w_converter: + if dvf != dvf_to_cpg_to_dvf: raise AssertionError("ControlPointGridToDeformationConverter DVF->CPG->DVF failed.") # Check the adjoint is truly the adjoint with: | - | / 0.5*(||+||) < epsilon @@ -1092,6 +1092,21 @@ def try_cgp_dvf_conversion(na): if not is_operator_adjoint(cpg_2_dvf_converter): raise AssertionError("ControlPointGridToDeformationConverter::adjoint() failed") + x = dvf_to_cpg + # y = na.get_deformation_field_inverse() + y = sirf.Reg.NiftiImageData3DDeformation(aladin_def_inverse) + y_hat = cpg_2_dvf_converter.forward(x) + x_hat = cpg_2_dvf_converter.backward(y) + y_dot = y_hat.dot(y) + x_dot = x_hat.dot(x) + diff = abs(y_dot - x_dot) + avg = 0.5 * (abs(y_dot) + abs(x_dot)) + + norm_err = diff/avg + max_err = 10e-5 + if norm_err > max_err: + raise AssertionError("ControlPointGridToDeformationConverter::adjoint() failed") + time.sleep(0.5) sys.stderr.write('\n# --------------------------------------------------------------------------------- #\n') sys.stderr.write('# Finished CGP<->DVF test.\n') From 60ededc75885ce634e274c3a536e4c659d187b21 Mon Sep 17 00:00:00 2001 From: richard Date: Sat, 4 Jul 2020 17:39:16 +0100 Subject: [PATCH 11/42] c++ test works --- .../NiftyMoMo/BSplineTransformation.cpp | 2 +- ...ControlPointGridToDeformationConverter.cpp | 9 +- src/Registration/cReg/NiftiImageData.cpp | 2 + .../ControlPointGridToDeformationConverter.h | 4 +- src/Registration/cReg/tests/test_cReg.cpp | 125 ++++++++++++++---- 5 files changed, 110 insertions(+), 32 deletions(-) diff --git a/src/Registration/NiftyMoMo/BSplineTransformation.cpp b/src/Registration/NiftyMoMo/BSplineTransformation.cpp index b0c3265d0..a47d9b5d2 100644 --- a/src/Registration/NiftyMoMo/BSplineTransformation.cpp +++ b/src/Registration/NiftyMoMo/BSplineTransformation.cpp @@ -1373,7 +1373,7 @@ BSplineTransformation::PrecisionType* BSplineTransformation::GetDVFGradientWRTTr // Note: Performing the reorientation here is way more efficient, since // only the transformation parameters need to be touched (and not) // the complete DVF - this->ReorientateVectorImage( outDVFGradWRTTrafoParams, sourceImage->sto_ijk ); +// this->ReorientateVectorImage( outDVFGradWRTTrafoParams, sourceImage->sto_ijk ); // Copy over the data pointer from the the image and detach it. Then delete the image. PrecisionType* outDVFGradWRTTrafoParamData = (PrecisionType*) outDVFGradWRTTrafoParams->data; diff --git a/src/Registration/cReg/ControlPointGridToDeformationConverter.cpp b/src/Registration/cReg/ControlPointGridToDeformationConverter.cpp index 469bac122..058907311 100644 --- a/src/Registration/cReg/ControlPointGridToDeformationConverter.cpp +++ b/src/Registration/cReg/ControlPointGridToDeformationConverter.cpp @@ -62,21 +62,24 @@ set_reference_image(const NiftiImageData &ref) template NiftiImageData3DDeformation ControlPointGridToDeformationConverter:: -forward(const NiftiImageData3DBSpline &cpg) +forward(const NiftiImageData3DBSpline &cpg) const { check_is_set_up(); +// NiftiImageData3DDeformation dvf; +// dvf.create_from_cpp(cpg, *_template_ref_sptr); +// return dvf; return cpg.get_as_deformation_field(*_template_ref_sptr); } template NiftiImageData3DBSpline ControlPointGridToDeformationConverter:: -backward(const NiftiImageData3DDeformation &dvf) +backward(const NiftiImageData3DDeformation &dvf) const { check_is_set_up(); // not marked const, so copy float spacing_nonconst[3] = {_spacing[0], _spacing[1], _spacing[2]}; - // Get any of the tensor components as a 3d image + // Get raw nifti_image from reference image nifti_image *ref_ptr = _template_ref_sptr->get_raw_nifti_sptr().get(); // Create the NiftyMoMo bspline transformation class NiftyMoMo::BSplineTransformation bspline(ref_ptr, 1, spacing_nonconst); diff --git a/src/Registration/cReg/NiftiImageData.cpp b/src/Registration/cReg/NiftiImageData.cpp index a5af16718..46752c80f 100644 --- a/src/Registration/cReg/NiftiImageData.cpp +++ b/src/Registration/cReg/NiftiImageData.cpp @@ -523,11 +523,13 @@ void NiftiImageData::check_dimensions(const NiftiImageDataType image_t else if (intent_code == NIFTI_INTENT_VECTOR) ss << ", intent_code = Vector"; if (intent_p1 == 0) ss << ", intent_p1 = Deformation"; else if (intent_p1 == 1) ss << ", intent_p1 = Displacement"; + else if (intent_p1 == SPLINE_VEL_GRID) ss << ", intent_p1 = Control point grid"; ss << "\n\t\tActual params: ndim = " << _nifti_image->ndim << ", nu = " << _nifti_image->nu << ", nt = " << _nifti_image->nt; if (_nifti_image->intent_code == NIFTI_INTENT_NONE) ss << ", intent_code = None"; else if (_nifti_image->intent_code == NIFTI_INTENT_VECTOR) ss << ", intent_code = Vector"; if (intent_p1 != -1 && _nifti_image->intent_p1 == 0) ss << ", intent_p1 = Deformation"; else if (intent_p1 != -1 && _nifti_image->intent_p1 == 1) ss << ", intent_p1 = Displacement"; + else if (intent_p1 != -1 && _nifti_image->intent_p1 == SPLINE_VEL_GRID) ss << ", intent_p1 = Control point grid"; //std::cout << ss.str() << "\n"; throw std::runtime_error(ss.str()); } diff --git a/src/Registration/cReg/include/sirf/Reg/ControlPointGridToDeformationConverter.h b/src/Registration/cReg/include/sirf/Reg/ControlPointGridToDeformationConverter.h index 5c79182fd..464655a85 100644 --- a/src/Registration/cReg/include/sirf/Reg/ControlPointGridToDeformationConverter.h +++ b/src/Registration/cReg/include/sirf/Reg/ControlPointGridToDeformationConverter.h @@ -60,10 +60,10 @@ class ControlPointGridToDeformationConverter void set_reference_image(const NiftiImageData &ref); /// CPG to DVF - NiftiImageData3DDeformation forward(const NiftiImageData3DBSpline &cpg); + NiftiImageData3DDeformation forward(const NiftiImageData3DBSpline &cpg) const; /// DVF to CPG - NiftiImageData3DBSpline backward(const NiftiImageData3DDeformation &dvf); + NiftiImageData3DBSpline backward(const NiftiImageData3DDeformation &dvf) const; private: diff --git a/src/Registration/cReg/tests/test_cReg.cpp b/src/Registration/cReg/tests/test_cReg.cpp index cd61a4c12..fd01398b8 100644 --- a/src/Registration/cReg/tests/test_cReg.cpp +++ b/src/Registration/cReg/tests/test_cReg.cpp @@ -47,6 +47,54 @@ limitations under the License. using namespace sirf; + +void check_non_zero(const NiftiImageData &im, + const std::string &explanation) +{ + if (std::abs(im.get_min()) < 1e-4f && std::abs(im.get_max()) < 1e-4f) + throw std::runtime_error(explanation + ": contains no non-zeroes"); +} +NiftiImageData3DDeformation +CPG2DVF(const ControlPointGridToDeformationConverter &converter, + const NiftiImageData3DBSpline &cpg) +{ + check_non_zero(cpg, "converter::forward (input)"); + auto dvf = converter.forward(cpg); + check_non_zero(dvf, "converter::forward (output)"); + return dvf; +} +NiftiImageData3DBSpline +DVF2CPG(const ControlPointGridToDeformationConverter &converter, + const NiftiImageData3DDeformation &dvf) +{ + check_non_zero(dvf, "converter::backward (input)"); + auto cpg = converter.backward(dvf); + check_non_zero(cpg, "converter::backward (output)"); + return cpg; +} +NiftiImageData3DDeformation +rand_dvf( + NiftiImageData3DDisplacement &disp, + const float min_disp = -10.f, const float max_disp = 10.f) +{ + for (unsigned i=0; i(rand()) /(static_cast(RAND_MAX/(max_disp-min_disp))); + auto dvf = NiftiImageData3DDeformation(disp); + check_non_zero(dvf, "Rand DVF"); + return dvf; +} +NiftiImageData3DBSpline +rand_cpg( + const ControlPointGridToDeformationConverter &converter, + NiftiImageData3DDisplacement &disp, + const float min_disp = -10.f, const float max_disp = 10.f) +{ + auto dvf = rand_dvf(disp, min_disp, max_disp); + auto cpg = DVF2CPG(converter,dvf); + check_non_zero(cpg, "Rand CPG"); + return cpg; +} + int main(int argc, char* argv[]) { @@ -1152,32 +1200,57 @@ int main(int argc, char* argv[]) std::cout << "// Starting CGP<->DVF test...\n"; std::cout << "//------------------------------------------------------------------------ //\n"; - auto dvf_sptr = std::dynamic_pointer_cast >( - NA.get_deformation_field_forward_sptr()); - - // DVF->CPG - float spacing[3]; - for (unsigned i=0; i<3; ++i) - spacing[i] = dvf_sptr->get_raw_nifti_sptr()->pixdim[i+1] * 2.f; - - // DVF->CPG with converter - ControlPointGridToDeformationConverter cpg_2_dvf_converter; - cpg_2_dvf_converter.set_cpg_spacing(spacing); - cpg_2_dvf_converter.set_reference_image(*ref_sptr); - // DVF->CPG - auto dvf_to_cpg = cpg_2_dvf_converter.backward(*dvf_sptr); - - // Check CPG contains non-zeroes - if (std::abs(dvf_to_cpg.get_max()) < 1.e-4f || std::abs(dvf_to_cpg.get_min()) < 1.e-4f) - throw std::runtime_error("NiftiImageData3DBSpline::NiftiImageData3DBSpline(DVF): contains only zeroes."); - - // DVF->CPG->DVF - auto dvf_to_cpg_to_dvf = cpg_2_dvf_converter.forward(dvf_to_cpg); - - NiftiImageData::print_headers({dvf_sptr.get(), &dvf_to_cpg, &dvf_to_cpg_to_dvf}); - // Compare - if (*dvf_sptr != dvf_to_cpg_to_dvf) - throw std::runtime_error("ControlPointGridToDeformationConverter DVF->CPG->DVF failed."); + // Test both 2D and 3D cases + for (unsigned is_3d=0; is_3d<2; ++is_3d) { + unsigned int z_size = is_3d ? 32 : 1; + // Generate image + VoxelisedGeometricalInfo3D::Size size({150,125,z_size}); + VoxelisedGeometricalInfo3D::Spacing spacing_dvf({2.f,3.f,5.f}); + VoxelisedGeometricalInfo3D::Offset offset({0.f,0.f,0.f}); + std::array dm_row_1({1.f,0.f,0.f}); + std::array dm_row_2({0.f,1.f,0.f}); + std::array dm_row_3({0.f,0.f,1.f}); + VoxelisedGeometricalInfo3D::DirectionMatrix dm({dm_row_1,dm_row_2, dm_row_3}); + VoxelisedGeometricalInfo3D geom_info(offset, spacing_dvf, size, dm); + // Create displacement, convert to deformation and reference image + NiftiImageData3DDisplacement disp( + *NiftiImageData::create_from_geom_info( + geom_info,true, NREG_TRANS_TYPE::DISP_FIELD)); + NiftiImageData3DDeformation dvf(disp); + NiftiImageData ref = *dvf.get_tensor_component(0); + + // CPG spacing double the dvf spacing + float cpg_spacing[3] = {4.f * spacing_dvf[0], 4.f * spacing_dvf[1], 4.f * spacing_dvf[2]}; + + // set up DVF<->CPG converter + ControlPointGridToDeformationConverter cpg_2_dvf_converter; + cpg_2_dvf_converter.set_cpg_spacing(cpg_spacing); + cpg_2_dvf_converter.set_reference_image(ref); + + // ok, now ready to do adjoint test using: + // | - | / 0.5*(||+||) < epsilon + + for (unsigned i=0; i<10; ++i) { + // Get random CPG and DVF + auto x = rand_cpg(cpg_2_dvf_converter, disp); + auto y = rand_dvf(disp); + + // Convert random CPG to DVF and random DVF to CPG + auto Tx = CPG2DVF(cpg_2_dvf_converter,x); + auto Tsy = DVF2CPG(cpg_2_dvf_converter,y); + + // Get inner products + float x_dot, y_dot; + dynamic_cast(x).dot(Tsy, &x_dot); + dynamic_cast(y).dot(Tx, &y_dot); + + float adjoint_test = std::abs(x_dot - y_dot) / (0.5f * (std::abs(x_dot) + std::abs(y_dot))); + std::cout << "\t| - | / 0.5*(||+||) = " << adjoint_test << "\n"; + float epsilon = 1e-4f; + if (adjoint_test > epsilon) + throw std::runtime_error("adjoint test > " + std::to_string(epsilon)); + } + } std::cout << "// ----------------------------------------------------------------------- //\n"; std::cout << "// Finished CGP<->DVF test.\n"; From 65a81c2a7ca4ecbc5f2e9d987948bdf84c5ac75d Mon Sep 17 00:00:00 2001 From: richard Date: Sat, 4 Jul 2020 17:51:12 +0100 Subject: [PATCH 12/42] create from cpp const --- src/Registration/cReg/NiftiImageData3DDeformation.cpp | 6 ++++-- .../cReg/include/sirf/Reg/NiftiImageData3DDeformation.h | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/Registration/cReg/NiftiImageData3DDeformation.cpp b/src/Registration/cReg/NiftiImageData3DDeformation.cpp index 1dbe9eec9..68dd41fa9 100644 --- a/src/Registration/cReg/NiftiImageData3DDeformation.cpp +++ b/src/Registration/cReg/NiftiImageData3DDeformation.cpp @@ -54,11 +54,13 @@ void NiftiImageData3DDeformation::create_from_3D_image(const NiftiImag } template -void NiftiImageData3DDeformation::create_from_cpp(NiftiImageData3DTensor &cpp, const NiftiImageData &ref) +void NiftiImageData3DDeformation::create_from_cpp(const NiftiImageData3DTensor &cpp, const NiftiImageData &ref) { this->create_from_3D_image(ref); - reg_spline_getDeformationField(cpp.get_raw_nifti_sptr().get(), + auto cpp_clone = cpp.clone(); + + reg_spline_getDeformationField(cpp_clone->get_raw_nifti_sptr().get(), this->_nifti_image.get(), NULL, false, //composition diff --git a/src/Registration/cReg/include/sirf/Reg/NiftiImageData3DDeformation.h b/src/Registration/cReg/include/sirf/Reg/NiftiImageData3DDeformation.h index 874baa451..a1d84607a 100644 --- a/src/Registration/cReg/include/sirf/Reg/NiftiImageData3DDeformation.h +++ b/src/Registration/cReg/include/sirf/Reg/NiftiImageData3DDeformation.h @@ -85,7 +85,7 @@ class NiftiImageData3DDeformation : public NiftiImageData3DTensor, pub void create_from_3D_image(const NiftiImageData &image); /// Create from control point grid image - void create_from_cpp(NiftiImageData3DTensor &cpp, const NiftiImageData &ref); + void create_from_cpp(const NiftiImageData3DTensor &cpp, const NiftiImageData &ref); /// Get as deformation field virtual NiftiImageData3DDeformation get_as_deformation_field(const NiftiImageData &ref) const; From 66cfe90e89766fc6e728e7bae8d3bb2c8640bf26 Mon Sep 17 00:00:00 2001 From: richard Date: Sat, 4 Jul 2020 18:16:49 +0100 Subject: [PATCH 13/42] python test passes --- src/Registration/pReg/tests/test_pReg.py | 48 +++++++----------------- 1 file changed, 13 insertions(+), 35 deletions(-) diff --git a/src/Registration/pReg/tests/test_pReg.py b/src/Registration/pReg/tests/test_pReg.py index 7e46f6483..94189663e 100644 --- a/src/Registration/pReg/tests/test_pReg.py +++ b/src/Registration/pReg/tests/test_pReg.py @@ -1083,30 +1083,11 @@ def try_cgp_dvf_conversion(na): # DVF->CPG->DVF dvf_to_cpg_to_dvf = cpg_2_dvf_converter.forward(dvf_to_cpg) - # Compare - if dvf != dvf_to_cpg_to_dvf: - raise AssertionError("ControlPointGridToDeformationConverter DVF->CPG->DVF failed.") - # Check the adjoint is truly the adjoint with: | - | / 0.5*(||+||) < epsilon cpg_2_dvf_converter._set_up_for_adjoint_test(dvf, dvf_to_cpg) if not is_operator_adjoint(cpg_2_dvf_converter): raise AssertionError("ControlPointGridToDeformationConverter::adjoint() failed") - x = dvf_to_cpg - # y = na.get_deformation_field_inverse() - y = sirf.Reg.NiftiImageData3DDeformation(aladin_def_inverse) - y_hat = cpg_2_dvf_converter.forward(x) - x_hat = cpg_2_dvf_converter.backward(y) - y_dot = y_hat.dot(y) - x_dot = x_hat.dot(x) - diff = abs(y_dot - x_dot) - avg = 0.5 * (abs(y_dot) + abs(x_dot)) - - norm_err = diff/avg - max_err = 10e-5 - if norm_err > max_err: - raise AssertionError("ControlPointGridToDeformationConverter::adjoint() failed") - time.sleep(0.5) sys.stderr.write('\n# --------------------------------------------------------------------------------- #\n') sys.stderr.write('# Finished CGP<->DVF test.\n') @@ -1261,24 +1242,21 @@ def try_quaternion(): def test(): - try_niftiimage() - try_niftiimage3d() - try_niftiimage3dtensor() - try_niftiimage3ddisplacement() - try_niftiimage3ddeformation() + # try_niftiimage() + # try_niftiimage3d() + # try_niftiimage3dtensor() + # try_niftiimage3ddisplacement() + # try_niftiimage3ddeformation() na = try_niftyaladin() - try_niftyf3d() - try_transformations(na) - try_resample(na) - try_niftymomo(na) - try_weighted_mean(na) + # try_niftyf3d() + # try_transformations(na) + # try_resample(na) + # try_niftymomo(na) + # try_weighted_mean(na) try_cgp_dvf_conversion(na) - try_affinetransformation(na) - try_quaternion() + # try_affinetransformation(na) + # try_quaternion() if __name__ == "__main__": - try: - test() - except: - raise error("Error encountered.") + test() From ec8d6fb3de8cd99beef37c706372dc999c407d00 Mon Sep 17 00:00:00 2001 From: richard Date: Mon, 6 Jul 2020 13:03:34 +0100 Subject: [PATCH 14/42] codacy changes --- src/Registration/pReg/tests/test_pReg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Registration/pReg/tests/test_pReg.py b/src/Registration/pReg/tests/test_pReg.py index 94189663e..fc877d7c8 100644 --- a/src/Registration/pReg/tests/test_pReg.py +++ b/src/Registration/pReg/tests/test_pReg.py @@ -23,7 +23,7 @@ import numpy as np import nibabel as nib import sirf.Reg -from pUtilities import * +from sirf.Utilities import is_operator_adjoint # Paths SIRF_PATH = os.environ.get('SIRF_PATH') @@ -1081,7 +1081,7 @@ def try_cgp_dvf_conversion(na): # DVF->CPG dvf_to_cpg = cpg_2_dvf_converter.backward(dvf) # DVF->CPG->DVF - dvf_to_cpg_to_dvf = cpg_2_dvf_converter.forward(dvf_to_cpg) + _ = cpg_2_dvf_converter.forward(dvf_to_cpg) # Check the adjoint is truly the adjoint with: | - | / 0.5*(||+||) < epsilon cpg_2_dvf_converter._set_up_for_adjoint_test(dvf, dvf_to_cpg) From 30691c02437125eb04e809a7cd7050d770f99df1 Mon Sep 17 00:00:00 2001 From: richard Date: Tue, 7 Jul 2020 17:10:52 +0100 Subject: [PATCH 15/42] image gradient wrt transformation --- src/Registration/cReg/CMakeLists.txt | 1 + .../cReg/ImageGradientWRTTransformation.cpp | 61 ++++++++++++++++ src/Registration/cReg/NiftyResample.cpp | 55 +++++++++++++++ .../sirf/Reg/ImageGradientWRTTransformation.h | 69 +++++++++++++++++++ .../cReg/include/sirf/Reg/NiftyResample.h | 6 ++ .../cReg/include/sirf/Reg/Resample.h | 6 ++ 6 files changed, 198 insertions(+) create mode 100644 src/Registration/cReg/ImageGradientWRTTransformation.cpp create mode 100644 src/Registration/cReg/include/sirf/Reg/ImageGradientWRTTransformation.h diff --git a/src/Registration/cReg/CMakeLists.txt b/src/Registration/cReg/CMakeLists.txt index d8d4069ca..70a0d37cf 100644 --- a/src/Registration/cReg/CMakeLists.txt +++ b/src/Registration/cReg/CMakeLists.txt @@ -46,6 +46,7 @@ SET(SOURCES "NiftiImageData3DDisplacement.cpp" "NiftiImageData3DBSpline.cpp" "ControlPointGridToDeformationConverter.cpp" + "ImageGradientWRTTransformation.cpp" ) # If we're also wrapping to python or matlab, include the c-files diff --git a/src/Registration/cReg/ImageGradientWRTTransformation.cpp b/src/Registration/cReg/ImageGradientWRTTransformation.cpp new file mode 100644 index 000000000..7416bfc7b --- /dev/null +++ b/src/Registration/cReg/ImageGradientWRTTransformation.cpp @@ -0,0 +1,61 @@ +/* +SyneRBI Synergistic Image Reconstruction Framework (SIRF) +Copyright 2020 University College London + +This is software developed for the Collaborative Computational +Project in Synergistic Reconstruction for Biomedical Imaging (formerly CCP PETMR) +(http://www.ccpsynerbi.ac.uk/). + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +*/ + +/*! +\file +\ingroup Registration +\brief Class for getting image gradient WRT a transformation. + +\author Richard Brown +\author SyneRBI +*/ + +#include "sirf/Reg/ImageGradientWRTTransformation.h" +#include "sirf/Reg/Resample.h" + +using namespace sirf; + +template +void +ImageGradientWRTTransformation:: +set_resampler(const std::shared_ptr > resampler_sptr) +{ + _resampler_sptr = resampler_sptr; +} + +template +void +ImageGradientWRTTransformation:: +forward(std::shared_ptr > &output_transformation_sptr, const std::shared_ptr source_im_sptr) +{ + _resampler_sptr->get_image_gradient_wrt_transformation(output_transformation_sptr, source_im_sptr); +} + +template +std::shared_ptr > +ImageGradientWRTTransformation:: +forward(const std::shared_ptr source_im_sptr) +{ + return std::move(_resampler_sptr->get_image_gradient_wrt_transformation(source_im_sptr)); +} + +namespace sirf { +template class ImageGradientWRTTransformation; +} diff --git a/src/Registration/cReg/NiftyResample.cpp b/src/Registration/cReg/NiftyResample.cpp index fbfe3ef51..d09f6087c 100644 --- a/src/Registration/cReg/NiftyResample.cpp +++ b/src/Registration/cReg/NiftyResample.cpp @@ -28,6 +28,7 @@ limitations under the License. */ #include "sirf/Reg/NiftyResample.h" +#include "sirf/Reg/NiftiImageData3D.h" #include "sirf/Reg/NiftiImageData3DTensor.h" #include "sirf/Reg/NiftiImageData3DDeformation.h" #include "sirf/Reg/AffineTransformation.h" @@ -338,6 +339,60 @@ void NiftyResample::adjoint(std::shared_ptr output_sptr, co set_post_resample_outputs(output_sptr, this->_output_image_sptr, _output_image_adjoint_niftis); } +template +static +void convert_to_NiftiImageData_if_not_already(std::shared_ptr > &output_sptr, const std::shared_ptr &input_sptr) +{ + // Try to dynamic cast from ImageData to (const) NiftiImageData. This will only succeed if original type was NiftiImageData + output_sptr = std::dynamic_pointer_cast >(input_sptr); + // If output is a null pointer, it means that a different image type was supplied (e.g., STIRImageData). + // In this case, construct a NiftiImageData + if (!output_sptr) + output_sptr = std::make_shared >(*input_sptr); +} + +template +void +NiftyResample:: +get_image_gradient_wrt_transformation(std::shared_ptr > &output_transformation_sptr, + const std::shared_ptr source_im_sptr) +{ + // Call the set up + set_up(); + + auto output_deformation_sptr = + std::dynamic_pointer_cast >(output_transformation_sptr); + + // Convert image to NiftiImageData if not already + std::shared_ptr > im_nii_sptr; + convert_to_NiftiImageData_if_not_already(im_nii_sptr, source_im_sptr->clone()); + + // Get raw nifti_image pointer + nifti_image * source_nii_ptr = im_nii_sptr->get_raw_nifti_sptr().get(); + + // Get image gradient + reg_getImageGradient(source_nii_ptr, + output_deformation_sptr->get_raw_nifti_sptr().get(), + this->_deformation_sptr->get_raw_nifti_sptr().get(), + nullptr, + this->_interpolation_type, + this->_padding_value, + 0); +} + +template +std::shared_ptr > +NiftyResample:: +get_image_gradient_wrt_transformation(const std::shared_ptr source_im_sptr) +{ + // Call the set up + set_up(); + + std::shared_ptr > output_deformation_sptr = this->_deformation_sptr->clone(); + get_image_gradient_wrt_transformation(output_deformation_sptr, source_im_sptr); + return std::move(output_deformation_sptr); +} + namespace sirf { template class NiftyResample; } diff --git a/src/Registration/cReg/include/sirf/Reg/ImageGradientWRTTransformation.h b/src/Registration/cReg/include/sirf/Reg/ImageGradientWRTTransformation.h new file mode 100644 index 000000000..399ee94cd --- /dev/null +++ b/src/Registration/cReg/include/sirf/Reg/ImageGradientWRTTransformation.h @@ -0,0 +1,69 @@ +/* +SyneRBI Synergistic Image Reconstruction Framework (SIRF) +Copyright 2020 University College London + +This is software developed for the Collaborative Computational +Project in Synergistic Reconstruction for Biomedical Imaging (formerly CCP PETMR) +(http://www.ccpsynerbi.ac.uk/). + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +*/ + +/*! +\file +\ingroup Registration +\brief Class for getting image gradient WRT a transformation. + +\author Richard Brown +\author SyneRBI +*/ + +#pragma once + +#include + +namespace sirf { + +template class Resample; +template class Transformation; +class ImageData; + +/*! +\ingroup Registration +\brief Class for converting control point grids to deformation field transformations. + +\author Richard Brown +\author SyneRBI +*/ +template +class ImageGradientWRTTransformation +{ +public: + + /// Constructor + ImageGradientWRTTransformation(); + + /// Set the resampler + void set_resampler(const std::shared_ptr > resampler_sptr); + + /// Forward in place (get image gradient wrt transformation) + virtual void forward(std::shared_ptr > &output_transformation_sptr, const std::shared_ptr source_im_sptr); + + /// Forward (get image gradient wrt transformation) + virtual std::shared_ptr > forward(const std::shared_ptr source_im_sptr); + +private: + + /// Resampler + std::shared_ptr > _resampler_sptr; +}; +} diff --git a/src/Registration/cReg/include/sirf/Reg/NiftyResample.h b/src/Registration/cReg/include/sirf/Reg/NiftyResample.h index c79043d43..8e3428bef 100644 --- a/src/Registration/cReg/include/sirf/Reg/NiftyResample.h +++ b/src/Registration/cReg/include/sirf/Reg/NiftyResample.h @@ -123,6 +123,12 @@ class NiftyResample : public Resample /// Do the adjoint transformation virtual void adjoint(std::shared_ptr output_sptr, const std::shared_ptr input_sptr); + /// Get image gradient wrt transformation in place + virtual void get_image_gradient_wrt_transformation(std::shared_ptr > &output_transformation_sptr, const std::shared_ptr source_im_sptr); + + /// Get image gradient wrt transformation + virtual std::shared_ptr > get_image_gradient_wrt_transformation(const std::shared_ptr source_im_sptr); + protected: /// Set up diff --git a/src/Registration/cReg/include/sirf/Reg/Resample.h b/src/Registration/cReg/include/sirf/Reg/Resample.h index b85b28e3c..c9f9ee210 100644 --- a/src/Registration/cReg/include/sirf/Reg/Resample.h +++ b/src/Registration/cReg/include/sirf/Reg/Resample.h @@ -128,6 +128,12 @@ class Resample /// Backward. Alias for Adjoint virtual void backward(std::shared_ptr output_sptr, const std::shared_ptr input_sptr); + /// Get image gradient wrt transformation in place + virtual void get_image_gradient_wrt_transformation(std::shared_ptr > &output_transformation_sptr, const std::shared_ptr source_im_sptr) = 0; + + /// Get image gradient wrt transformation + virtual std::shared_ptr > get_image_gradient_wrt_transformation(const std::shared_ptr source_im_sptr) = 0; + protected: /// Set up From bf448e5935ed20fbee3bb5530457a3a2db32d588 Mon Sep 17 00:00:00 2001 From: richard Date: Thu, 9 Jul 2020 10:50:26 +0100 Subject: [PATCH 16/42] multiply tensor component by scalar image --- .../cReg/NiftiImageData3DTensor.cpp | 32 +++++++++++++++++++ .../include/sirf/Reg/NiftiImageData3DTensor.h | 3 ++ 2 files changed, 35 insertions(+) diff --git a/src/Registration/cReg/NiftiImageData3DTensor.cpp b/src/Registration/cReg/NiftiImageData3DTensor.cpp index aec54adb1..92341d161 100644 --- a/src/Registration/cReg/NiftiImageData3DTensor.cpp +++ b/src/Registration/cReg/NiftiImageData3DTensor.cpp @@ -177,6 +177,38 @@ void NiftiImageData3DTensor::flip_component(const int dim) this->_data[i] = -this->_data[i]; } +template +void NiftiImageData3DTensor:: +multiply_tensor_component +(const int dim, const std::shared_ptr &scalar_im_sptr) +{ + // Check the dimension to multiply, that dims==5 and nu==3 + if (dim < 0 || dim > 2) + throw std::runtime_error("\n\tDimension to multiply should be between 0 and 2."); + + std::shared_ptr > nii_scalar_im_sptr = + std::dynamic_pointer_cast >(scalar_im_sptr); + if (!nii_scalar_im_sptr) + nii_scalar_im_sptr = std::make_shared >(*scalar_im_sptr); + + // Check dimensions match (except the tensor component, obviously) + const int *tensor_dims = this->get_dimensions(); + const int *scalar_dims = nii_scalar_im_sptr->get_dimensions(); + for (unsigned i=1; i<7; ++i) { + // skip tensor component (u, which is 5) + if (i!=5 && tensor_dims[i] != scalar_dims[i]) + throw std::runtime_error("NiftiImageData3DTensor::multiply_tensor_component mismatch in image sizes"); + } + + // Data is ordered such that the multicomponent info is last. + // So, the first third of the data is the x-values, second third is y and last third is z. + // Start index is therefore = dim_number * num_voxels/3 + const unsigned tensor_index_offset = dim * int(this->_nifti_image->nvox/3); + + for (unsigned i=0; iget_num_voxels(); ++i) + (*this)(i+tensor_index_offset) *= (*nii_scalar_im_sptr)(i); +} + namespace sirf { template class NiftiImageData3DTensor; } diff --git a/src/Registration/cReg/include/sirf/Reg/NiftiImageData3DTensor.h b/src/Registration/cReg/include/sirf/Reg/NiftiImageData3DTensor.h index bd19b14ff..270ebfe3f 100644 --- a/src/Registration/cReg/include/sirf/Reg/NiftiImageData3DTensor.h +++ b/src/Registration/cReg/include/sirf/Reg/NiftiImageData3DTensor.h @@ -90,6 +90,9 @@ class NiftiImageData3DTensor : public NiftiImageData /// Flip component of nu void flip_component(const int dim); + /// Multiply tensor component by image + void multiply_tensor_component(const int dim, const std::shared_ptr &scalar_im_sptr); + virtual ObjectHandle* new_data_container_handle() const { return new ObjectHandle From bab142e3188815c291fbbfe746798b4e50e9bd07 Mon Sep 17 00:00:00 2001 From: richard Date: Thu, 9 Jul 2020 10:50:45 +0100 Subject: [PATCH 17/42] add forward method --- .../cReg/ImageGradientWRTTransformation.cpp | 38 +++++++++++++++---- src/Registration/cReg/NiftyResample.cpp | 30 ++++++++------- .../sirf/Reg/ImageGradientWRTTransformation.h | 22 +++++++---- .../cReg/include/sirf/Reg/NiftyResample.h | 4 +- .../cReg/include/sirf/Reg/Resample.h | 6 --- 5 files changed, 63 insertions(+), 37 deletions(-) diff --git a/src/Registration/cReg/ImageGradientWRTTransformation.cpp b/src/Registration/cReg/ImageGradientWRTTransformation.cpp index 7416bfc7b..271a8b6ae 100644 --- a/src/Registration/cReg/ImageGradientWRTTransformation.cpp +++ b/src/Registration/cReg/ImageGradientWRTTransformation.cpp @@ -28,14 +28,15 @@ limitations under the License. */ #include "sirf/Reg/ImageGradientWRTTransformation.h" -#include "sirf/Reg/Resample.h" +#include "sirf/Reg/NiftyResample.h" +#include "sirf/Reg/NiftiImageData3DDeformation.h" using namespace sirf; template void ImageGradientWRTTransformation:: -set_resampler(const std::shared_ptr > resampler_sptr) +set_resampler(const std::shared_ptr > resampler_sptr) { _resampler_sptr = resampler_sptr; } @@ -43,17 +44,40 @@ set_resampler(const std::shared_ptr > resampler_sptr) template void ImageGradientWRTTransformation:: -forward(std::shared_ptr > &output_transformation_sptr, const std::shared_ptr source_im_sptr) +forward(std::shared_ptr &im_sptr, + const std::shared_ptr > &deformation_sptr) { - _resampler_sptr->get_image_gradient_wrt_transformation(output_transformation_sptr, source_im_sptr); + _resampler_sptr->clear_transformations(); + _resampler_sptr->add_transformation(deformation_sptr); + _resampler_sptr->process(); + im_sptr->fill(*_resampler_sptr->get_output_sptr()); } template -std::shared_ptr > +std::shared_ptr ImageGradientWRTTransformation:: -forward(const std::shared_ptr source_im_sptr) +forward(const std::shared_ptr > deformation_sptr) { - return std::move(_resampler_sptr->get_image_gradient_wrt_transformation(source_im_sptr)); + _resampler_sptr->clear_transformations(); + _resampler_sptr->add_transformation(deformation_sptr); + _resampler_sptr->process(); + return _resampler_sptr->get_output_sptr(); +} + +template +void +ImageGradientWRTTransformation:: +backward(std::shared_ptr > &output_deformation_sptr, const std::shared_ptr image_to_multiply_sptr) +{ + _resampler_sptr->get_image_gradient_wrt_deformation_times_image(output_deformation_sptr, image_to_multiply_sptr); +} + +template +std::shared_ptr > +ImageGradientWRTTransformation:: +backward(const std::shared_ptr image_to_multiply_sptr) +{ + return std::move(_resampler_sptr->get_image_gradient_wrt_deformation_times_image(image_to_multiply_sptr)); } namespace sirf { diff --git a/src/Registration/cReg/NiftyResample.cpp b/src/Registration/cReg/NiftyResample.cpp index d09f6087c..fc61f019c 100644 --- a/src/Registration/cReg/NiftyResample.cpp +++ b/src/Registration/cReg/NiftyResample.cpp @@ -354,42 +354,44 @@ void convert_to_NiftiImageData_if_not_already(std::shared_ptr void NiftyResample:: -get_image_gradient_wrt_transformation(std::shared_ptr > &output_transformation_sptr, - const std::shared_ptr source_im_sptr) +get_image_gradient_wrt_deformation_times_image( + std::shared_ptr > &output_deformation_sptr, + const std::shared_ptr image_to_multiply_sptr) { // Call the set up set_up(); - auto output_deformation_sptr = - std::dynamic_pointer_cast >(output_transformation_sptr); - - // Convert image to NiftiImageData if not already - std::shared_ptr > im_nii_sptr; - convert_to_NiftiImageData_if_not_already(im_nii_sptr, source_im_sptr->clone()); + // Not implemented for complex images + if (this->_floating_image_niftis.is_complex() || image_to_multiply_sptr->is_complex()) + throw std::runtime_error("NiftyResample::get_image_gradient_wrt_deformation_times_image not yet implemented for complex images"); // Get raw nifti_image pointer - nifti_image * source_nii_ptr = im_nii_sptr->get_raw_nifti_sptr().get(); + nifti_image * floating_nii_ptr = this->_floating_image_niftis.real()->clone()->get_raw_nifti_sptr().get(); // Get image gradient - reg_getImageGradient(source_nii_ptr, + reg_getImageGradient(floating_nii_ptr, output_deformation_sptr->get_raw_nifti_sptr().get(), this->_deformation_sptr->get_raw_nifti_sptr().get(), nullptr, this->_interpolation_type, this->_padding_value, 0); + + // Now multiply the scalar image to each of the DVF components + for (unsigned i=0; i<3; ++i) + output_deformation_sptr->multiply_tensor_component(i, image_to_multiply_sptr); } template -std::shared_ptr > +std::shared_ptr > NiftyResample:: -get_image_gradient_wrt_transformation(const std::shared_ptr source_im_sptr) +get_image_gradient_wrt_deformation_times_image(const std::shared_ptr image_to_multiply_sptr) { // Call the set up set_up(); - std::shared_ptr > output_deformation_sptr = this->_deformation_sptr->clone(); - get_image_gradient_wrt_transformation(output_deformation_sptr, source_im_sptr); + std::shared_ptr > output_deformation_sptr = this->_deformation_sptr->clone(); + get_image_gradient_wrt_deformation_times_image(output_deformation_sptr, image_to_multiply_sptr); return std::move(output_deformation_sptr); } diff --git a/src/Registration/cReg/include/sirf/Reg/ImageGradientWRTTransformation.h b/src/Registration/cReg/include/sirf/Reg/ImageGradientWRTTransformation.h index 399ee94cd..3fb3a9b1b 100644 --- a/src/Registration/cReg/include/sirf/Reg/ImageGradientWRTTransformation.h +++ b/src/Registration/cReg/include/sirf/Reg/ImageGradientWRTTransformation.h @@ -33,8 +33,8 @@ limitations under the License. namespace sirf { -template class Resample; -template class Transformation; +template class NiftyResample; +template class NiftiImageData3DDeformation; class ImageData; /*! @@ -53,17 +53,23 @@ class ImageGradientWRTTransformation ImageGradientWRTTransformation(); /// Set the resampler - void set_resampler(const std::shared_ptr > resampler_sptr); + void set_resampler(const std::shared_ptr > resampler_sptr); - /// Forward in place (get image gradient wrt transformation) - virtual void forward(std::shared_ptr > &output_transformation_sptr, const std::shared_ptr source_im_sptr); + /// Forward in place (resample image) + virtual void forward(std::shared_ptr &im_sptr, const std::shared_ptr > &deformation_sptr); - /// Forward (get image gradient wrt transformation) - virtual std::shared_ptr > forward(const std::shared_ptr source_im_sptr); + /// Forward (resample image) + virtual std::shared_ptr forward(const std::shared_ptr > deformation_sptr); + + /// Backward in place (get image gradient wrt transformation) + virtual void backward(std::shared_ptr > &output_transformation_sptr, const std::shared_ptr image_to_multiply_sptr); + + /// Backward (get image gradient wrt transformation) + virtual std::shared_ptr > backward(const std::shared_ptr image_to_multiply_sptr); private: /// Resampler - std::shared_ptr > _resampler_sptr; + std::shared_ptr > _resampler_sptr; }; } diff --git a/src/Registration/cReg/include/sirf/Reg/NiftyResample.h b/src/Registration/cReg/include/sirf/Reg/NiftyResample.h index 8e3428bef..5e1e8f405 100644 --- a/src/Registration/cReg/include/sirf/Reg/NiftyResample.h +++ b/src/Registration/cReg/include/sirf/Reg/NiftyResample.h @@ -124,10 +124,10 @@ class NiftyResample : public Resample virtual void adjoint(std::shared_ptr output_sptr, const std::shared_ptr input_sptr); /// Get image gradient wrt transformation in place - virtual void get_image_gradient_wrt_transformation(std::shared_ptr > &output_transformation_sptr, const std::shared_ptr source_im_sptr); + virtual void get_image_gradient_wrt_deformation_times_image(std::shared_ptr > &output_deformation_sptr, const std::shared_ptr image_to_multiply_sptr); /// Get image gradient wrt transformation - virtual std::shared_ptr > get_image_gradient_wrt_transformation(const std::shared_ptr source_im_sptr); + virtual std::shared_ptr > get_image_gradient_wrt_deformation_times_image(const std::shared_ptr image_to_multiply_sptr); protected: diff --git a/src/Registration/cReg/include/sirf/Reg/Resample.h b/src/Registration/cReg/include/sirf/Reg/Resample.h index c9f9ee210..b85b28e3c 100644 --- a/src/Registration/cReg/include/sirf/Reg/Resample.h +++ b/src/Registration/cReg/include/sirf/Reg/Resample.h @@ -128,12 +128,6 @@ class Resample /// Backward. Alias for Adjoint virtual void backward(std::shared_ptr output_sptr, const std::shared_ptr input_sptr); - /// Get image gradient wrt transformation in place - virtual void get_image_gradient_wrt_transformation(std::shared_ptr > &output_transformation_sptr, const std::shared_ptr source_im_sptr) = 0; - - /// Get image gradient wrt transformation - virtual std::shared_ptr > get_image_gradient_wrt_transformation(const std::shared_ptr source_im_sptr) = 0; - protected: /// Set up From 86a500e9c66f0a0591ec365c36b1e3be701c76e5 Mon Sep 17 00:00:00 2001 From: richard Date: Thu, 9 Jul 2020 10:55:50 +0100 Subject: [PATCH 18/42] image grad wrt DEF TIMES IM --- src/Registration/cReg/CMakeLists.txt | 2 +- ...=> ImageGradientWRTDeformationTimesImage.cpp} | 16 ++++++++-------- ...h => ImageGradientWRTDeformationTimesImage.h} | 8 ++++---- 3 files changed, 13 insertions(+), 13 deletions(-) rename src/Registration/cReg/{ImageGradientWRTTransformation.cpp => ImageGradientWRTDeformationTimesImage.cpp} (84%) rename src/Registration/cReg/include/sirf/Reg/{ImageGradientWRTTransformation.h => ImageGradientWRTDeformationTimesImage.h} (89%) diff --git a/src/Registration/cReg/CMakeLists.txt b/src/Registration/cReg/CMakeLists.txt index a0d1a4e9d..76fc84d95 100644 --- a/src/Registration/cReg/CMakeLists.txt +++ b/src/Registration/cReg/CMakeLists.txt @@ -35,7 +35,7 @@ SET(SOURCES "NiftiImageData3DDisplacement.cpp" "NiftiImageData3DBSpline.cpp" "ControlPointGridToDeformationConverter.cpp" - "ImageGradientWRTTransformation.cpp" + "ImageGradientWRTDeformationTimesImage.cpp" ) # If we're also wrapping to python or matlab, include the c-files diff --git a/src/Registration/cReg/ImageGradientWRTTransformation.cpp b/src/Registration/cReg/ImageGradientWRTDeformationTimesImage.cpp similarity index 84% rename from src/Registration/cReg/ImageGradientWRTTransformation.cpp rename to src/Registration/cReg/ImageGradientWRTDeformationTimesImage.cpp index 271a8b6ae..8d4cf4d6a 100644 --- a/src/Registration/cReg/ImageGradientWRTTransformation.cpp +++ b/src/Registration/cReg/ImageGradientWRTDeformationTimesImage.cpp @@ -21,13 +21,13 @@ limitations under the License. /*! \file \ingroup Registration -\brief Class for getting image gradient WRT a transformation. +\brief Class for getting image gradient WRT a transformation and multiplying by image. \author Richard Brown \author SyneRBI */ -#include "sirf/Reg/ImageGradientWRTTransformation.h" +#include "sirf/Reg/ImageGradientWRTDeformationTimesImage.h" #include "sirf/Reg/NiftyResample.h" #include "sirf/Reg/NiftiImageData3DDeformation.h" @@ -35,7 +35,7 @@ using namespace sirf; template void -ImageGradientWRTTransformation:: +ImageGradientWRTDeformationTimesImage:: set_resampler(const std::shared_ptr > resampler_sptr) { _resampler_sptr = resampler_sptr; @@ -43,7 +43,7 @@ set_resampler(const std::shared_ptr > resampler_sptr) template void -ImageGradientWRTTransformation:: +ImageGradientWRTDeformationTimesImage:: forward(std::shared_ptr &im_sptr, const std::shared_ptr > &deformation_sptr) { @@ -55,7 +55,7 @@ forward(std::shared_ptr &im_sptr, template std::shared_ptr -ImageGradientWRTTransformation:: +ImageGradientWRTDeformationTimesImage:: forward(const std::shared_ptr > deformation_sptr) { _resampler_sptr->clear_transformations(); @@ -66,7 +66,7 @@ forward(const std::shared_ptr > defo template void -ImageGradientWRTTransformation:: +ImageGradientWRTDeformationTimesImage:: backward(std::shared_ptr > &output_deformation_sptr, const std::shared_ptr image_to_multiply_sptr) { _resampler_sptr->get_image_gradient_wrt_deformation_times_image(output_deformation_sptr, image_to_multiply_sptr); @@ -74,12 +74,12 @@ backward(std::shared_ptr > &output_deforma template std::shared_ptr > -ImageGradientWRTTransformation:: +ImageGradientWRTDeformationTimesImage:: backward(const std::shared_ptr image_to_multiply_sptr) { return std::move(_resampler_sptr->get_image_gradient_wrt_deformation_times_image(image_to_multiply_sptr)); } namespace sirf { -template class ImageGradientWRTTransformation; +template class ImageGradientWRTDeformationTimesImage; } diff --git a/src/Registration/cReg/include/sirf/Reg/ImageGradientWRTTransformation.h b/src/Registration/cReg/include/sirf/Reg/ImageGradientWRTDeformationTimesImage.h similarity index 89% rename from src/Registration/cReg/include/sirf/Reg/ImageGradientWRTTransformation.h rename to src/Registration/cReg/include/sirf/Reg/ImageGradientWRTDeformationTimesImage.h index 3fb3a9b1b..d81a9e250 100644 --- a/src/Registration/cReg/include/sirf/Reg/ImageGradientWRTTransformation.h +++ b/src/Registration/cReg/include/sirf/Reg/ImageGradientWRTDeformationTimesImage.h @@ -21,7 +21,7 @@ limitations under the License. /*! \file \ingroup Registration -\brief Class for getting image gradient WRT a transformation. +\brief Class for getting image gradient WRT a transformation and multiplying by image. \author Richard Brown \author SyneRBI @@ -39,18 +39,18 @@ class ImageData; /*! \ingroup Registration -\brief Class for converting control point grids to deformation field transformations. +\brief Class for getting image gradient WRT a transformation and multiplying by image. \author Richard Brown \author SyneRBI */ template -class ImageGradientWRTTransformation +class ImageGradientWRTDeformationTimesImage { public: /// Constructor - ImageGradientWRTTransformation(); + ImageGradientWRTDeformationTimesImage(); /// Set the resampler void set_resampler(const std::shared_ptr > resampler_sptr); From d2f3e4dbed0d6ded0a0523f45844f8589627c641 Mon Sep 17 00:00:00 2001 From: richard Date: Thu, 9 Jul 2020 16:58:12 +0100 Subject: [PATCH 19/42] use mean to compare images --- src/Registration/cReg/NiftiImageData.cpp | 28 +++++++++---------- .../cReg/include/sirf/Reg/NiftiImageData.h | 2 +- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/Registration/cReg/NiftiImageData.cpp b/src/Registration/cReg/NiftiImageData.cpp index a490dd391..22dcf769b 100644 --- a/src/Registration/cReg/NiftiImageData.cpp +++ b/src/Registration/cReg/NiftiImageData.cpp @@ -1487,7 +1487,7 @@ get_inner_product(const NiftiImageData &other) const } template -bool NiftiImageData::are_equal_to_given_accuracy(const NiftiImageData &im1, const NiftiImageData &im2, const float required_accuracy_compared_to_max) +bool NiftiImageData::are_equal_to_given_accuracy(const NiftiImageData &im1, const NiftiImageData &im2, const float required_accuracy_compared_to_mean) { if(!im1.is_initialised()) throw std::runtime_error("NiftiImageData::are_equal_to_given_accuracy: Image 1 not initialised."); @@ -1502,8 +1502,8 @@ bool NiftiImageData::are_equal_to_given_accuracy(const NiftiImageData // Get required accuracy compared to the image maxes float norm; - float epsilon = (std::abs(im1.get_max())+std::abs(im2.get_max()))/2.F; - epsilon *= required_accuracy_compared_to_max; + float epsilon = (std::abs(im1.get_mean())+std::abs(im2.get_mean()))/2.F; + epsilon *= required_accuracy_compared_to_mean; // If metadata match, get the norm if (do_nifti_image_metadata_match(im1,im2, false)) @@ -1530,17 +1530,17 @@ bool NiftiImageData::are_equal_to_given_accuracy(const NiftiImageData return true; std::cout << "\nImages are not equal (norm > epsilon).\n"; - std::cout << "\tmax1 = " << im1.get_max() << "\n"; - std::cout << "\tmax2 = " << im2.get_max() << "\n"; - std::cout << "\tmin1 = " << im1.get_min() << "\n"; - std::cout << "\tmin2 = " << im2.get_min() << "\n"; - std::cout << "\tmean1 = " << im1.get_mean() << "\n"; - std::cout << "\tmean2 = " << im2.get_mean() << "\n"; - std::cout << "\tstandard deviation1 = " << im1.get_standard_deviation() << "\n"; - std::cout << "\tstandard deviation2 = " << im2.get_standard_deviation() << "\n"; - std::cout << "\trequired accuracy compared to max = " << required_accuracy_compared_to_max << "\n"; - std::cout << "\tepsilon = " << epsilon << "\n"; - std::cout << "\tnorm/num_vox = " << norm << "\n"; + std::cout << "\tmax1 = " << im1.get_max() << "\n"; + std::cout << "\tmax2 = " << im2.get_max() << "\n"; + std::cout << "\tmin1 = " << im1.get_min() << "\n"; + std::cout << "\tmin2 = " << im2.get_min() << "\n"; + std::cout << "\tmean1 = " << im1.get_mean() << "\n"; + std::cout << "\tmean2 = " << im2.get_mean() << "\n"; + std::cout << "\tstandard deviation1 = " << im1.get_standard_deviation() << "\n"; + std::cout << "\tstandard deviation2 = " << im2.get_standard_deviation() << "\n"; + std::cout << "\trequired accuracy compared to mean = " << required_accuracy_compared_to_mean << "\n"; + std::cout << "\tepsilon = " << epsilon << "\n"; + std::cout << "\tnorm/num_vox = " << norm << "\n"; return false; } diff --git a/src/Registration/cReg/include/sirf/Reg/NiftiImageData.h b/src/Registration/cReg/include/sirf/Reg/NiftiImageData.h index 8600a81cc..d57409154 100644 --- a/src/Registration/cReg/include/sirf/Reg/NiftiImageData.h +++ b/src/Registration/cReg/include/sirf/Reg/NiftiImageData.h @@ -368,7 +368,7 @@ class NiftiImageData : public ImageData int get_original_datatype() const { return _original_datatype; } /// Check if the norms of two images are equal to a given accuracy. - static bool are_equal_to_given_accuracy(const NiftiImageData &im1, const NiftiImageData &im2, const float required_accuracy_compared_to_max); + static bool are_equal_to_given_accuracy(const NiftiImageData &im1, const NiftiImageData &im2, const float required_accuracy_compared_to_mean); /// Point is in bounds? bool is_in_bounds(const int index[7]) const; From ed1554baaf45756d47433883866221dc50e573db Mon Sep 17 00:00:00 2001 From: richard Date: Thu, 9 Jul 2020 16:58:34 +0100 Subject: [PATCH 20/42] give index access for separate values --- src/Registration/cReg/NiftiImageData.cpp | 18 ++++++++++++++++++ .../cReg/include/sirf/Reg/NiftiImageData.h | 6 ++++++ 2 files changed, 24 insertions(+) diff --git a/src/Registration/cReg/NiftiImageData.cpp b/src/Registration/cReg/NiftiImageData.cpp index 22dcf769b..f67c71a4a 100644 --- a/src/Registration/cReg/NiftiImageData.cpp +++ b/src/Registration/cReg/NiftiImageData.cpp @@ -331,6 +331,24 @@ float &NiftiImageData::operator()(const int index[7]) return _data[index_1d]; } +template +float NiftiImageData::operator()(const int x, const int y, const int z, + const int t, const int u, const int v, + const int w) const +{ + const int idx[7] = { x, y, z, t, u, v, w }; + return (*this)(idx); +} + +template +float &NiftiImageData::operator()(const int x, const int y, const int z, + const int t, const int u, const int v, + const int w) +{ + const int idx[7] = { x, y, z, t, u, v, w }; + return (*this)(idx); +} + template std::shared_ptr NiftiImageData::get_raw_nifti_sptr() const { diff --git a/src/Registration/cReg/include/sirf/Reg/NiftiImageData.h b/src/Registration/cReg/include/sirf/Reg/NiftiImageData.h index d57409154..13444a74a 100644 --- a/src/Registration/cReg/include/sirf/Reg/NiftiImageData.h +++ b/src/Registration/cReg/include/sirf/Reg/NiftiImageData.h @@ -297,6 +297,12 @@ class NiftiImageData : public ImageData /// Access data element via 7D index float &operator()(const int index[7]); + /// Access data element via 7D index (const) + float operator()(const int x, const int y, const int z, const int t=0, const int u=0, const int v=0, const int w=0) const; + + /// Access data element via 7D index + float &operator()(const int x, const int y, const int z, const int t=0, const int u=0, const int v=0, const int w=0); + /// Is the image initialised? (Should be unless default constructor was used.) bool is_initialised() const { return (_nifti_image && _data && _nifti_image->datatype == DT_FLOAT32 ? true : false); } From 564b42748b75431c36782550f87c11d878610ce5 Mon Sep 17 00:00:00 2001 From: richard Date: Thu, 9 Jul 2020 17:00:04 +0100 Subject: [PATCH 21/42] add tensor maths with scalar image --- .../cReg/NiftiImageData3DTensor.cpp | 33 ++++++++++++++++--- .../include/sirf/Reg/NiftiImageData3DTensor.h | 6 ++++ 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/src/Registration/cReg/NiftiImageData3DTensor.cpp b/src/Registration/cReg/NiftiImageData3DTensor.cpp index 92341d161..aaa15d6bb 100644 --- a/src/Registration/cReg/NiftiImageData3DTensor.cpp +++ b/src/Registration/cReg/NiftiImageData3DTensor.cpp @@ -179,12 +179,14 @@ void NiftiImageData3DTensor::flip_component(const int dim) template void NiftiImageData3DTensor:: -multiply_tensor_component -(const int dim, const std::shared_ptr &scalar_im_sptr) +tensor_component_maths( + const int dim, + const std::shared_ptr &scalar_im_sptr, + const typename NiftiImageData::MathsType maths_type) { // Check the dimension to multiply, that dims==5 and nu==3 if (dim < 0 || dim > 2) - throw std::runtime_error("\n\tDimension to multiply should be between 0 and 2."); + throw std::runtime_error("\n\tDimension to do tensor maths should be between 0 and 2."); std::shared_ptr > nii_scalar_im_sptr = std::dynamic_pointer_cast >(scalar_im_sptr); @@ -205,8 +207,29 @@ multiply_tensor_component // Start index is therefore = dim_number * num_voxels/3 const unsigned tensor_index_offset = dim * int(this->_nifti_image->nvox/3); - for (unsigned i=0; iget_num_voxels(); ++i) - (*this)(i+tensor_index_offset) *= (*nii_scalar_im_sptr)(i); + for (unsigned i=0; iget_num_voxels(); ++i) { + if (maths_type == NiftiImageData::mul) + (*this)(i+tensor_index_offset) *= (*nii_scalar_im_sptr)(i); + else if (maths_type == NiftiImageData::add) + (*this)(i+tensor_index_offset) += (*nii_scalar_im_sptr)(i); + } +} + +template +void NiftiImageData3DTensor:: +multiply_tensor_component +(const int dim, const std::shared_ptr &scalar_im_sptr) +{ + this->tensor_component_maths(dim, scalar_im_sptr, NiftiImageData::mul); +} + + +template +void NiftiImageData3DTensor:: +add_to_tensor_component +(const int dim, const std::shared_ptr &scalar_im_sptr) +{ + this->tensor_component_maths(dim, scalar_im_sptr, NiftiImageData::add); } namespace sirf { diff --git a/src/Registration/cReg/include/sirf/Reg/NiftiImageData3DTensor.h b/src/Registration/cReg/include/sirf/Reg/NiftiImageData3DTensor.h index 270ebfe3f..d7862cb21 100644 --- a/src/Registration/cReg/include/sirf/Reg/NiftiImageData3DTensor.h +++ b/src/Registration/cReg/include/sirf/Reg/NiftiImageData3DTensor.h @@ -90,9 +90,15 @@ class NiftiImageData3DTensor : public NiftiImageData /// Flip component of nu void flip_component(const int dim); + /// Tensor component maths + void tensor_component_maths(const int dim, const std::shared_ptr &scalar_im_sptr, const typename NiftiImageData::MathsType maths_type); + /// Multiply tensor component by image void multiply_tensor_component(const int dim, const std::shared_ptr &scalar_im_sptr); + /// Add image to tensor component + void add_to_tensor_component(const int dim, const std::shared_ptr &scalar_im_sptr); + virtual ObjectHandle* new_data_container_handle() const { return new ObjectHandle From abcd3f5d1f981079c1e290c60f957a95c3358319 Mon Sep 17 00:00:00 2001 From: richard Date: Thu, 9 Jul 2020 17:00:59 +0100 Subject: [PATCH 22/42] update functionality --- .../ImageGradientWRTDeformationTimesImage.cpp | 10 +++++----- src/Registration/cReg/NiftyResample.cpp | 17 +++++++++++++---- .../Reg/ImageGradientWRTDeformationTimesImage.h | 8 ++++---- .../cReg/include/sirf/Reg/NiftyResample.h | 2 +- 4 files changed, 23 insertions(+), 14 deletions(-) diff --git a/src/Registration/cReg/ImageGradientWRTDeformationTimesImage.cpp b/src/Registration/cReg/ImageGradientWRTDeformationTimesImage.cpp index 8d4cf4d6a..56719e8c6 100644 --- a/src/Registration/cReg/ImageGradientWRTDeformationTimesImage.cpp +++ b/src/Registration/cReg/ImageGradientWRTDeformationTimesImage.cpp @@ -44,8 +44,8 @@ set_resampler(const std::shared_ptr > resampler_sptr) template void ImageGradientWRTDeformationTimesImage:: -forward(std::shared_ptr &im_sptr, - const std::shared_ptr > &deformation_sptr) +forward(std::shared_ptr im_sptr, + const std::shared_ptr > deformation_sptr) { _resampler_sptr->clear_transformations(); _resampler_sptr->add_transformation(deformation_sptr); @@ -54,14 +54,14 @@ forward(std::shared_ptr &im_sptr, } template -std::shared_ptr +std::shared_ptr ImageGradientWRTDeformationTimesImage:: forward(const std::shared_ptr > deformation_sptr) { _resampler_sptr->clear_transformations(); _resampler_sptr->add_transformation(deformation_sptr); _resampler_sptr->process(); - return _resampler_sptr->get_output_sptr(); + return _resampler_sptr->get_output_sptr()->clone(); } template @@ -73,7 +73,7 @@ backward(std::shared_ptr > &output_deforma } template -std::shared_ptr > +std::shared_ptr > ImageGradientWRTDeformationTimesImage:: backward(const std::shared_ptr image_to_multiply_sptr) { diff --git a/src/Registration/cReg/NiftyResample.cpp b/src/Registration/cReg/NiftyResample.cpp index fc61f019c..22d151be1 100644 --- a/src/Registration/cReg/NiftyResample.cpp +++ b/src/Registration/cReg/NiftyResample.cpp @@ -365,11 +365,11 @@ get_image_gradient_wrt_deformation_times_image( if (this->_floating_image_niftis.is_complex() || image_to_multiply_sptr->is_complex()) throw std::runtime_error("NiftyResample::get_image_gradient_wrt_deformation_times_image not yet implemented for complex images"); - // Get raw nifti_image pointer - nifti_image * floating_nii_ptr = this->_floating_image_niftis.real()->clone()->get_raw_nifti_sptr().get(); + // Get real part of floating image + std::shared_ptr > floating_sptr = this->_floating_image_niftis.real()->clone(); // Get image gradient - reg_getImageGradient(floating_nii_ptr, + reg_getImageGradient(floating_sptr->get_raw_nifti_sptr().get(), output_deformation_sptr->get_raw_nifti_sptr().get(), this->_deformation_sptr->get_raw_nifti_sptr().get(), nullptr, @@ -377,13 +377,22 @@ get_image_gradient_wrt_deformation_times_image( this->_padding_value, 0); + std::shared_ptr > temp = output_deformation_sptr->clone(); + // Now multiply the scalar image to each of the DVF components for (unsigned i=0; i<3; ++i) output_deformation_sptr->multiply_tensor_component(i, image_to_multiply_sptr); + + NiftiImageData::print_headers({this->_floating_image_niftis.real().get(), + this->_deformation_sptr.get(), + temp.get(), + output_deformation_sptr.get(), + std::dynamic_pointer_cast >(image_to_multiply_sptr).get()}); +// throw std::runtime_error("hi im here"); } template -std::shared_ptr > +std::shared_ptr > NiftyResample:: get_image_gradient_wrt_deformation_times_image(const std::shared_ptr image_to_multiply_sptr) { diff --git a/src/Registration/cReg/include/sirf/Reg/ImageGradientWRTDeformationTimesImage.h b/src/Registration/cReg/include/sirf/Reg/ImageGradientWRTDeformationTimesImage.h index d81a9e250..a714109e0 100644 --- a/src/Registration/cReg/include/sirf/Reg/ImageGradientWRTDeformationTimesImage.h +++ b/src/Registration/cReg/include/sirf/Reg/ImageGradientWRTDeformationTimesImage.h @@ -50,22 +50,22 @@ class ImageGradientWRTDeformationTimesImage public: /// Constructor - ImageGradientWRTDeformationTimesImage(); + ImageGradientWRTDeformationTimesImage() {} /// Set the resampler void set_resampler(const std::shared_ptr > resampler_sptr); /// Forward in place (resample image) - virtual void forward(std::shared_ptr &im_sptr, const std::shared_ptr > &deformation_sptr); + virtual void forward(std::shared_ptr im_sptr, const std::shared_ptr > deformation_sptr); /// Forward (resample image) - virtual std::shared_ptr forward(const std::shared_ptr > deformation_sptr); + virtual std::shared_ptr forward(const std::shared_ptr > deformation_sptr); /// Backward in place (get image gradient wrt transformation) virtual void backward(std::shared_ptr > &output_transformation_sptr, const std::shared_ptr image_to_multiply_sptr); /// Backward (get image gradient wrt transformation) - virtual std::shared_ptr > backward(const std::shared_ptr image_to_multiply_sptr); + virtual std::shared_ptr > backward(const std::shared_ptr image_to_multiply_sptr); private: diff --git a/src/Registration/cReg/include/sirf/Reg/NiftyResample.h b/src/Registration/cReg/include/sirf/Reg/NiftyResample.h index 5e1e8f405..3cf4217b1 100644 --- a/src/Registration/cReg/include/sirf/Reg/NiftyResample.h +++ b/src/Registration/cReg/include/sirf/Reg/NiftyResample.h @@ -127,7 +127,7 @@ class NiftyResample : public Resample virtual void get_image_gradient_wrt_deformation_times_image(std::shared_ptr > &output_deformation_sptr, const std::shared_ptr image_to_multiply_sptr); /// Get image gradient wrt transformation - virtual std::shared_ptr > get_image_gradient_wrt_deformation_times_image(const std::shared_ptr image_to_multiply_sptr); + virtual std::shared_ptr > get_image_gradient_wrt_deformation_times_image(const std::shared_ptr image_to_multiply_sptr); protected: From fd311bf2815d33a38be4459e1e5b6f9157478e9b Mon Sep 17 00:00:00 2001 From: richard Date: Thu, 9 Jul 2020 17:01:27 +0100 Subject: [PATCH 23/42] update testing --- src/Registration/cReg/tests/test_cReg.cpp | 86 +++++++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/src/Registration/cReg/tests/test_cReg.cpp b/src/Registration/cReg/tests/test_cReg.cpp index dbda0ff3d..250e8c8be 100644 --- a/src/Registration/cReg/tests/test_cReg.cpp +++ b/src/Registration/cReg/tests/test_cReg.cpp @@ -39,8 +39,10 @@ limitations under the License. #include "sirf/Reg/Quaternion.h" #include "sirf/Reg/NiftiImageData3DBSpline.h" #include "sirf/Reg/ControlPointGridToDeformationConverter.h" +#include "sirf/Reg/ImageGradientWRTDeformationTimesImage.h" #include #include +#include #ifdef SIRF_SPM #include "sirf/Reg/SPMRegistration.h" #endif @@ -160,6 +162,8 @@ int main(int argc, char* argv[]) const std::shared_ptr > ref_f3d (new NiftiImageData3D( ref_f3d_filename )); const std::shared_ptr > flo_f3d (new NiftiImageData3D( flo_f3d_filename )); + // Need a random seex + srand(time(NULL)); { std::cout << "// ----------------------------------------------------------------------- //\n"; std::cout << "// Starting NiftiImageData test...\n"; @@ -1295,6 +1299,88 @@ int main(int argc, char* argv[]) std::cout << "// Finished CGP<->DVF test.\n"; std::cout << "//------------------------------------------------------------------------ //\n"; } + { + std::cout << "// ----------------------------------------------------------------------- //\n"; + std::cout << "// Starting im grad wrt def (times image) test...\n"; + std::cout << "//------------------------------------------------------------------------ //\n"; + + // Get some images to use as template and Gaussian smooth (to reduce number of non-zero voxels) + std::shared_ptr > lambda_sptr = ref_aladin->clone(); + lambda_sptr->kernel_convolution(10.f); + + std::shared_ptr > lambda_hat_sptr = lambda_sptr->clone(); + NiftiImageData3DDisplacement disp; + disp.create_from_3D_image(*lambda_sptr); + std::shared_ptr > deformation_sptr = + std::make_shared >(); + *deformation_sptr = rand_dvf(disp,-1.f,1.f); + + // We'll need a niftyreg resampler + std::shared_ptr > nr_sptr = + std::make_shared >(); + nr_sptr->set_reference_image(lambda_hat_sptr); + nr_sptr->set_floating_image(lambda_sptr); + nr_sptr->set_interpolation_type_to_cubic_spline(); + nr_sptr->set_padding_value(0.f); + nr_sptr->add_transformation(deformation_sptr); + + // And we'll need the ImageGradientWRTDeformationTimesImage class (call it a resampler) + ImageGradientWRTDeformationTimesImage resampler; + resampler.set_resampler(nr_sptr); + + // lambda hat is forward of lambda + resampler.forward(lambda_hat_sptr, deformation_sptr); + + // Rand voxel inside of image (let's ignore a 30% margin around the edge in x,y,z directions) + const float margin = 0.3; + std::random_device rd; // obtain a random number from hardware + std::mt19937 gen(rd()); // seed the generator + const int *dims = lambda_hat_sptr->get_dimensions(); + int min_idx, max_idx, rand_idx[3]; + for (unsigned i=0; i<3; ++i) { + min_idx = int(margin * float(dims[i+1])); + max_idx = dims[i+1] - min_idx; + std::uniform_int_distribution<> distr(min_idx, max_idx); + rand_idx[i] = distr(gen); + } + + // lambda tilde is copy of lambda hat with all voxels = 0, and 1 voxel = cnst + std::shared_ptr > lambda_tilde_sptr = lambda_hat_sptr->clone(); + lambda_tilde_sptr->fill(0.f); + const float val_range[2] = {1., 10.}; + const float rand_val = static_cast(rand()) / static_cast(RAND_MAX/(val_range[1]-val_range[0])); + (*lambda_tilde_sptr)(rand_idx[0],rand_idx[1],rand_idx[2]) = rand_val; + + // img grad wrt dvf times image + auto dvf1_sptr = resampler.backward(lambda_tilde_sptr); + + // dvf2 is a clone of dvf1. initially filled with zeroes + auto dvf2_sptr = dvf1_sptr->clone(); + dvf2_sptr->fill(0.f); + + // Need a small permutation, epsilon + const float epsilon_range[2] = {1., 2.}; + const float epsilon = static_cast(rand()) / static_cast(RAND_MAX/(epsilon_range[1]-epsilon_range[0])); + + // Loop over the 3 tensor components: u=[x,y,z] + for (unsigned u=0; u<3; ++u) { + std::shared_ptr > d_shifted_sptr = deformation_sptr->clone(); + (*d_shifted_sptr)(rand_idx[0],rand_idx[1],rand_idx[2],0,u) = epsilon; + std::shared_ptr > d_lambda_times_rand_val_sptr = lambda_hat_sptr->clone(); + *d_lambda_times_rand_val_sptr = (*resampler.forward(d_shifted_sptr) - *lambda_hat_sptr) / epsilon; + *d_lambda_times_rand_val_sptr *= rand_val; + dvf2_sptr->add_to_tensor_component(u, d_lambda_times_rand_val_sptr); + } + + if (*dvf1_sptr != *dvf2_sptr) { + NiftiImageData::print_headers({dvf1_sptr.get(), dvf2_sptr.get()}); + throw std::runtime_error("im grad wrt def (times image) test failed"); + } + + std::cout << "// ----------------------------------------------------------------------- //\n"; + std::cout << "// Finished im grad wrt def (times image) test.\n"; + std::cout << "//------------------------------------------------------------------------ //\n"; + } { std::cout << "// ----------------------------------------------------------------------- //\n"; From 448967acbad0b268bf1e5cefd37bebbbd7348108 Mon Sep 17 00:00:00 2001 From: richard Date: Thu, 9 Jul 2020 17:11:39 +0100 Subject: [PATCH 24/42] codacy --- src/Registration/cReg/tests/test_cReg.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/Registration/cReg/tests/test_cReg.cpp b/src/Registration/cReg/tests/test_cReg.cpp index 250e8c8be..352ce7688 100644 --- a/src/Registration/cReg/tests/test_cReg.cpp +++ b/src/Registration/cReg/tests/test_cReg.cpp @@ -1309,10 +1309,12 @@ int main(int argc, char* argv[]) lambda_sptr->kernel_convolution(10.f); std::shared_ptr > lambda_hat_sptr = lambda_sptr->clone(); + // Create blank displacement, same size as lambda NiftiImageData3DDisplacement disp; disp.create_from_3D_image(*lambda_sptr); std::shared_ptr > deformation_sptr = std::make_shared >(); + // deformation is represented as a random displacement with min and max of -1 and 1, respectively *deformation_sptr = rand_dvf(disp,-1.f,1.f); // We'll need a niftyreg resampler @@ -1336,10 +1338,10 @@ int main(int argc, char* argv[]) std::random_device rd; // obtain a random number from hardware std::mt19937 gen(rd()); // seed the generator const int *dims = lambda_hat_sptr->get_dimensions(); - int min_idx, max_idx, rand_idx[3]; + int rand_idx[3]; for (unsigned i=0; i<3; ++i) { - min_idx = int(margin * float(dims[i+1])); - max_idx = dims[i+1] - min_idx; + const int min_idx = int(margin * float(dims[i+1])); + const int max_idx = dims[i+1] - min_idx; std::uniform_int_distribution<> distr(min_idx, max_idx); rand_idx[i] = distr(gen); } From fa4faac109be572a3f773cf15e42191df370e2b9 Mon Sep 17 00:00:00 2001 From: richard Date: Fri, 10 Jul 2020 13:13:08 +0100 Subject: [PATCH 25/42] fill from image --- src/Registration/cReg/NiftiImageData.cpp | 10 ++++++++++ .../cReg/include/sirf/Reg/NiftiImageData.h | 3 +++ 2 files changed, 13 insertions(+) diff --git a/src/Registration/cReg/NiftiImageData.cpp b/src/Registration/cReg/NiftiImageData.cpp index f67c71a4a..f64b4b2c2 100644 --- a/src/Registration/cReg/NiftiImageData.cpp +++ b/src/Registration/cReg/NiftiImageData.cpp @@ -493,6 +493,16 @@ void NiftiImageData::fill(const dataType *v) _data[i] = v[i]; } + +template +void NiftiImageData::fill(const NiftiImageData &im) +{ + if(!im.is_initialised()) + throw std::runtime_error("NiftiImageData::fill(): Argument image not initialised."); + + this->fill(static_cast(im.get_raw_nifti_sptr()->data)); +} + template float NiftiImageData::get_norm(const NiftiImageData& other) const { diff --git a/src/Registration/cReg/include/sirf/Reg/NiftiImageData.h b/src/Registration/cReg/include/sirf/Reg/NiftiImageData.h index 13444a74a..d8d154d4e 100644 --- a/src/Registration/cReg/include/sirf/Reg/NiftiImageData.h +++ b/src/Registration/cReg/include/sirf/Reg/NiftiImageData.h @@ -346,6 +346,9 @@ class NiftiImageData : public ImageData /// Fill from array void fill(const dataType *v); + /// Fill from array + void fill(const NiftiImageData &im); + /// Get norm float get_norm(const NiftiImageData&) const; From 73627b6d8b6649d30fc0d9c3b617531b610df45f Mon Sep 17 00:00:00 2001 From: richard Date: Fri, 10 Jul 2020 13:13:21 +0100 Subject: [PATCH 26/42] correct /= --- src/Registration/cReg/NiftiImageData.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Registration/cReg/NiftiImageData.cpp b/src/Registration/cReg/NiftiImageData.cpp index f64b4b2c2..1c4770741 100644 --- a/src/Registration/cReg/NiftiImageData.cpp +++ b/src/Registration/cReg/NiftiImageData.cpp @@ -297,7 +297,7 @@ NiftiImageData& NiftiImageData::operator*=(const float val) template NiftiImageData& NiftiImageData::operator/=(const float val) { - maths(1.f/val, add); + maths(1.f/val, mul); return *this; } From 2af785e63de2bd8449fbce25893df35b1209fcb3 Mon Sep 17 00:00:00 2001 From: richard Date: Fri, 10 Jul 2020 13:13:39 +0100 Subject: [PATCH 27/42] update sto_xyz/sto_ijk when cropping image --- src/Registration/cReg/NiftiImageData.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/Registration/cReg/NiftiImageData.cpp b/src/Registration/cReg/NiftiImageData.cpp index 1c4770741..7c2f03059 100644 --- a/src/Registration/cReg/NiftiImageData.cpp +++ b/src/Registration/cReg/NiftiImageData.cpp @@ -820,10 +820,14 @@ void NiftiImageData::crop(const int min_index[7], const int max_index[ } // If the minimum has been changed, need to alter the origin. - for (int i=0; i<3; ++i) + for (int i=0; i<3; ++i) { _nifti_image->qto_ijk.m[i][3] -= min_idx[i]; + _nifti_image->sto_ijk.m[i][3] -= min_idx[i]; + } _nifti_image->qto_xyz = nifti_mat44_inverse(_nifti_image->qto_ijk); + _nifti_image->sto_xyz = + nifti_mat44_inverse(_nifti_image->sto_ijk); nifti_mat44_to_quatern( _nifti_image->qto_xyz, &_nifti_image->quatern_b, &_nifti_image->quatern_c, From 6fded33a71e03c0ae2551b880145961eb607ce23 Mon Sep 17 00:00:00 2001 From: richard Date: Fri, 10 Jul 2020 13:14:30 +0100 Subject: [PATCH 28/42] current attempt --- src/Registration/cReg/NiftyResample.cpp | 7 --- src/Registration/cReg/tests/test_cReg.cpp | 66 +++++++++++++++++------ 2 files changed, 51 insertions(+), 22 deletions(-) diff --git a/src/Registration/cReg/NiftyResample.cpp b/src/Registration/cReg/NiftyResample.cpp index 22d151be1..a6e4c5765 100644 --- a/src/Registration/cReg/NiftyResample.cpp +++ b/src/Registration/cReg/NiftyResample.cpp @@ -382,13 +382,6 @@ get_image_gradient_wrt_deformation_times_image( // Now multiply the scalar image to each of the DVF components for (unsigned i=0; i<3; ++i) output_deformation_sptr->multiply_tensor_component(i, image_to_multiply_sptr); - - NiftiImageData::print_headers({this->_floating_image_niftis.real().get(), - this->_deformation_sptr.get(), - temp.get(), - output_deformation_sptr.get(), - std::dynamic_pointer_cast >(image_to_multiply_sptr).get()}); -// throw std::runtime_error("hi im here"); } template diff --git a/src/Registration/cReg/tests/test_cReg.cpp b/src/Registration/cReg/tests/test_cReg.cpp index 352ce7688..cf5911c4e 100644 --- a/src/Registration/cReg/tests/test_cReg.cpp +++ b/src/Registration/cReg/tests/test_cReg.cpp @@ -56,6 +56,16 @@ void check_non_zero(const NiftiImageData &im, if (std::abs(im.get_min()) < 1e-4f && std::abs(im.get_max()) < 1e-4f) throw std::runtime_error(explanation + ": contains no non-zeroes"); } +static float get_rand_val(const float min, const float max) { + float random = ((float) rand()) / (float) RAND_MAX; + float diff = max - min; + float r = random * diff; + return min + r; +} + +static float get_rand_val(const float range[]) { + return get_rand_val(range[0], range[1]); +} NiftiImageData3DDeformation CPG2DVF(const ControlPointGridToDeformationConverter &converter, const NiftiImageData3DBSpline &cpg) @@ -80,7 +90,7 @@ rand_dvf( const float min_disp = -10.f, const float max_disp = 10.f) { for (unsigned i=0; i(rand()) /(static_cast(RAND_MAX/(max_disp-min_disp))); + disp(i) = get_rand_val(min_disp, max_disp); auto dvf = NiftiImageData3DDeformation(disp); check_non_zero(dvf, "Rand DVF"); return dvf; @@ -1306,16 +1316,20 @@ int main(int argc, char* argv[]) // Get some images to use as template and Gaussian smooth (to reduce number of non-zero voxels) std::shared_ptr > lambda_sptr = ref_aladin->clone(); - lambda_sptr->kernel_convolution(10.f); + const int min[7] = {30,30,30,0,0,0,0}; + const int max[7] = {35,35,35,0,0,0,0}; + lambda_sptr->crop(min,max); +// lambda_sptr->kernel_convolution(10.f); std::shared_ptr > lambda_hat_sptr = lambda_sptr->clone(); // Create blank displacement, same size as lambda NiftiImageData3DDisplacement disp; disp.create_from_3D_image(*lambda_sptr); std::shared_ptr > deformation_sptr = - std::make_shared >(); +// std::make_shared >(); + std::make_shared >(disp); // deformation is represented as a random displacement with min and max of -1 and 1, respectively - *deformation_sptr = rand_dvf(disp,-1.f,1.f); +// *deformation_sptr = rand_dvf(disp,-1.f,1.f); // We'll need a niftyreg resampler std::shared_ptr > nr_sptr = @@ -1344,13 +1358,15 @@ int main(int argc, char* argv[]) const int max_idx = dims[i+1] - min_idx; std::uniform_int_distribution<> distr(min_idx, max_idx); rand_idx[i] = distr(gen); + rand_idx[i] = 3; } // lambda tilde is copy of lambda hat with all voxels = 0, and 1 voxel = cnst std::shared_ptr > lambda_tilde_sptr = lambda_hat_sptr->clone(); - lambda_tilde_sptr->fill(0.f); + lambda_tilde_sptr->fill(1.f); const float val_range[2] = {1., 10.}; - const float rand_val = static_cast(rand()) / static_cast(RAND_MAX/(val_range[1]-val_range[0])); +// const float rand_val = get_rand_val(val_range); + const float rand_val = 1.f; (*lambda_tilde_sptr)(rand_idx[0],rand_idx[1],rand_idx[2]) = rand_val; // img grad wrt dvf times image @@ -1361,23 +1377,43 @@ int main(int argc, char* argv[]) dvf2_sptr->fill(0.f); // Need a small permutation, epsilon - const float epsilon_range[2] = {1., 2.}; - const float epsilon = static_cast(rand()) / static_cast(RAND_MAX/(epsilon_range[1]-epsilon_range[0])); + const float epsilon_range[2] = {.5f, 1.f}; + const float epsilon = 3.f;//get_rand_val(epsilon_range); + + std::shared_ptr > d_shifted_sptr = deformation_sptr->clone(); + std::shared_ptr > d_lambda_times_rand_val_sptr = lambda_hat_sptr->clone(); // Loop over the 3 tensor components: u=[x,y,z] - for (unsigned u=0; u<3; ++u) { - std::shared_ptr > d_shifted_sptr = deformation_sptr->clone(); - (*d_shifted_sptr)(rand_idx[0],rand_idx[1],rand_idx[2],0,u) = epsilon; - std::shared_ptr > d_lambda_times_rand_val_sptr = lambda_hat_sptr->clone(); - *d_lambda_times_rand_val_sptr = (*resampler.forward(d_shifted_sptr) - *lambda_hat_sptr) / epsilon; - *d_lambda_times_rand_val_sptr *= rand_val; - dvf2_sptr->add_to_tensor_component(u, d_lambda_times_rand_val_sptr); + for (unsigned i=0; iget_num_voxels(); ++i) { + if (i % 100 == 0) + std::cout << "\ndoing " << i << " of " << lambda_hat_sptr->get_num_voxels() << "\n" << std::flush; + if (std::abs((*lambda_tilde_sptr)(i)) < 1e-4f) + continue; + for (unsigned u=0; u<3; ++u) { + d_shifted_sptr->fill(*deformation_sptr); + (*d_shifted_sptr)(u*lambda_hat_sptr->get_num_voxels()+i) += epsilon; + d_lambda_times_rand_val_sptr->fill(*lambda_hat_sptr); + std::shared_ptr > res_forward = + std::dynamic_pointer_cast > (resampler.forward(d_shifted_sptr)); + *d_lambda_times_rand_val_sptr = (*res_forward - *lambda_hat_sptr); + *d_lambda_times_rand_val_sptr /= epsilon; + *d_lambda_times_rand_val_sptr *= rand_val; + dvf2_sptr->add_to_tensor_component(u, d_lambda_times_rand_val_sptr); + } } + lambda_sptr->write("/Users/rich/Desktop/original_im"); + deformation_sptr->write("/Users/rich/Desktop/original_dvf"); + dvf1_sptr->write("/Users/rich/Desktop/dvf1"); + dvf2_sptr->write("/Users/rich/Desktop/dvf2"); + lambda_tilde_sptr->write("/Users/rich/Desktop/lambda_tilde"); + if (*dvf1_sptr != *dvf2_sptr) { NiftiImageData::print_headers({dvf1_sptr.get(), dvf2_sptr.get()}); throw std::runtime_error("im grad wrt def (times image) test failed"); } + std::cout << "\n\n\n\n\n SUCCESS \n\n\n\n"; + exit(0); std::cout << "// ----------------------------------------------------------------------- //\n"; std::cout << "// Finished im grad wrt def (times image) test.\n"; From b8b4a8adc99d94f87b153aab794c8f97ae40f597 Mon Sep 17 00:00:00 2001 From: richard Date: Tue, 14 Jul 2020 12:28:31 +0100 Subject: [PATCH 29/42] fix Reg python import --- src/Registration/pReg/Reg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Registration/pReg/Reg.py b/src/Registration/pReg/Reg.py index c00035ec3..53ac98694 100644 --- a/src/Registration/pReg/Reg.py +++ b/src/Registration/pReg/Reg.py @@ -22,7 +22,7 @@ import sys import inspect -from sirf.Utilities import error, check_status, try_calling, +from sirf.Utilities import error, check_status, try_calling, \ assert_validity from sirf import SIRF import pyiutilities as pyiutil From 774e6892dddab3019a6cdb116ee07fbe406be9e6 Mon Sep 17 00:00:00 2001 From: richard Date: Fri, 17 Jul 2020 11:15:51 +0100 Subject: [PATCH 30/42] update to align with niftymomo --- .../cReg/ControlPointGridToDeformationConverter.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Registration/cReg/ControlPointGridToDeformationConverter.cpp b/src/Registration/cReg/ControlPointGridToDeformationConverter.cpp index 058907311..387b06353 100644 --- a/src/Registration/cReg/ControlPointGridToDeformationConverter.cpp +++ b/src/Registration/cReg/ControlPointGridToDeformationConverter.cpp @@ -86,7 +86,8 @@ backward(const NiftiImageData3DDeformation &dvf) const // Get cpg_ptr nifti_image *cpg_ptr = bspline.GetTransformationAsImage(); // Convert DVF to CPG - cpg_ptr->data = bspline.GetDVFGradientWRTTransformationParameters(dvf.clone()->get_raw_nifti_sptr().get(), ref_ptr); + std::shared_ptr > dvf_sptr = dvf.clone(); + cpg_ptr->data = bspline.GetDVFGradientWRTTransformationParameters(dvf_sptr->get_raw_nifti_sptr().get()); cpg_ptr->intent_p1 = SPLINE_VEL_GRID; return NiftiImageData3DBSpline(*cpg_ptr); } From 33ddd207e70ef9298cfa5d20ecaa241a6f124e15 Mon Sep 17 00:00:00 2001 From: richard Date: Fri, 17 Jul 2020 11:24:57 +0100 Subject: [PATCH 31/42] remove temp variable --- src/Registration/cReg/NiftyResample.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/Registration/cReg/NiftyResample.cpp b/src/Registration/cReg/NiftyResample.cpp index a6e4c5765..cb3ff0901 100644 --- a/src/Registration/cReg/NiftyResample.cpp +++ b/src/Registration/cReg/NiftyResample.cpp @@ -377,8 +377,6 @@ get_image_gradient_wrt_deformation_times_image( this->_padding_value, 0); - std::shared_ptr > temp = output_deformation_sptr->clone(); - // Now multiply the scalar image to each of the DVF components for (unsigned i=0; i<3; ++i) output_deformation_sptr->multiply_tensor_component(i, image_to_multiply_sptr); From bafd8818131870abac9c15ddf3324d230ebd2e77 Mon Sep 17 00:00:00 2001 From: richard Date: Fri, 17 Jul 2020 15:37:02 +0100 Subject: [PATCH 32/42] divide by spacing --- .../cReg/NiftiImageData3DTensor.cpp | 41 +++++++++++++++++++ src/Registration/cReg/NiftyResample.cpp | 9 +++- .../include/sirf/Reg/NiftiImageData3DTensor.h | 9 ++++ 3 files changed, 58 insertions(+), 1 deletion(-) diff --git a/src/Registration/cReg/NiftiImageData3DTensor.cpp b/src/Registration/cReg/NiftiImageData3DTensor.cpp index aaa15d6bb..9d078166c 100644 --- a/src/Registration/cReg/NiftiImageData3DTensor.cpp +++ b/src/Registration/cReg/NiftiImageData3DTensor.cpp @@ -232,6 +232,47 @@ add_to_tensor_component this->tensor_component_maths(dim, scalar_im_sptr, NiftiImageData::add); } +template +void NiftiImageData3DTensor:: +tensor_component_maths( + const int dim, + const dataType scalar, + const typename NiftiImageData::MathsType maths_type) +{ + // Check the dimension to multiply, that dims==5 and nu==3 + if (dim < 0 || dim > 2) + throw std::runtime_error("\n\tDimension to do tensor maths should be between 0 and 2."); + + // Data is ordered such that the multicomponent info is last. + // So, the first third of the data is the x-values, second third is y and last third is z. + // Start index is therefore = dim_number * num_voxels/3 + const unsigned tensor_index_offset = dim * int(this->_nifti_image->nvox/3); + + for (unsigned i=0; iget_num_voxels()/3; ++i) { + if (maths_type == NiftiImageData::mul) + (*this)(i+tensor_index_offset) *= scalar; + else if (maths_type == NiftiImageData::add) + (*this)(i+tensor_index_offset) += scalar; + } +} + +template +void NiftiImageData3DTensor:: +multiply_tensor_component +(const int dim, const dataType scalar) +{ + this->tensor_component_maths(dim, scalar, NiftiImageData::mul); +} + + +template +void NiftiImageData3DTensor:: +add_to_tensor_component +(const int dim, const dataType scalar) +{ + this->tensor_component_maths(dim, scalar, NiftiImageData::add); +} + namespace sirf { template class NiftiImageData3DTensor; } diff --git a/src/Registration/cReg/NiftyResample.cpp b/src/Registration/cReg/NiftyResample.cpp index cb3ff0901..bd9d5ef27 100644 --- a/src/Registration/cReg/NiftyResample.cpp +++ b/src/Registration/cReg/NiftyResample.cpp @@ -377,9 +377,16 @@ get_image_gradient_wrt_deformation_times_image( this->_padding_value, 0); + + const float *im_spacing = output_deformation_sptr->get_raw_nifti_sptr()->pixdim; + // Now multiply the scalar image to each of the DVF components - for (unsigned i=0; i<3; ++i) + for (unsigned i=0; i<3; ++i) { output_deformation_sptr->multiply_tensor_component(i, image_to_multiply_sptr); + // divide by spacing to get to mm + output_deformation_sptr->multiply_tensor_component(i, 1.f/im_spacing[i+1]); + + } } template diff --git a/src/Registration/cReg/include/sirf/Reg/NiftiImageData3DTensor.h b/src/Registration/cReg/include/sirf/Reg/NiftiImageData3DTensor.h index d7862cb21..6d5254f70 100644 --- a/src/Registration/cReg/include/sirf/Reg/NiftiImageData3DTensor.h +++ b/src/Registration/cReg/include/sirf/Reg/NiftiImageData3DTensor.h @@ -99,6 +99,15 @@ class NiftiImageData3DTensor : public NiftiImageData /// Add image to tensor component void add_to_tensor_component(const int dim, const std::shared_ptr &scalar_im_sptr); + /// Tensor component maths with scalar + void tensor_component_maths(const int dim, const dataType scalar, const typename NiftiImageData::MathsType maths_type); + + /// Multiply tensor component by scalar + void multiply_tensor_component(const int dim, const dataType scalar); + + /// Add image to tensor component + void add_to_tensor_component(const int dim, const dataType scalar); + virtual ObjectHandle* new_data_container_handle() const { return new ObjectHandle From ad660d143db2df97f9e95f3fb22d041dff6383ca Mon Sep 17 00:00:00 2001 From: richard Date: Fri, 17 Jul 2020 16:18:50 +0100 Subject: [PATCH 33/42] correct the test, only working for linear interpolation --- src/Registration/cReg/NiftyResample.cpp | 4 ++ src/Registration/cReg/tests/test_cReg.cpp | 57 ++++++++--------------- 2 files changed, 23 insertions(+), 38 deletions(-) diff --git a/src/Registration/cReg/NiftyResample.cpp b/src/Registration/cReg/NiftyResample.cpp index bd9d5ef27..84f9e5b34 100644 --- a/src/Registration/cReg/NiftyResample.cpp +++ b/src/Registration/cReg/NiftyResample.cpp @@ -361,6 +361,10 @@ get_image_gradient_wrt_deformation_times_image( // Call the set up set_up(); + // Only tested for linear interpolation + if (this->_interpolation_type != Resample::LINEAR) + throw std::runtime_error("NiftyResample::get_image_gradient_wrt_deformation_times_image only implemented for linear interpolation"); + // Not implemented for complex images if (this->_floating_image_niftis.is_complex() || image_to_multiply_sptr->is_complex()) throw std::runtime_error("NiftyResample::get_image_gradient_wrt_deformation_times_image not yet implemented for complex images"); diff --git a/src/Registration/cReg/tests/test_cReg.cpp b/src/Registration/cReg/tests/test_cReg.cpp index cf5911c4e..83a98b638 100644 --- a/src/Registration/cReg/tests/test_cReg.cpp +++ b/src/Registration/cReg/tests/test_cReg.cpp @@ -1316,27 +1316,24 @@ int main(int argc, char* argv[]) // Get some images to use as template and Gaussian smooth (to reduce number of non-zero voxels) std::shared_ptr > lambda_sptr = ref_aladin->clone(); - const int min[7] = {30,30,30,0,0,0,0}; + const int min[7] = {27,27,27,0,0,0,0}; const int max[7] = {35,35,35,0,0,0,0}; lambda_sptr->crop(min,max); -// lambda_sptr->kernel_convolution(10.f); std::shared_ptr > lambda_hat_sptr = lambda_sptr->clone(); // Create blank displacement, same size as lambda NiftiImageData3DDisplacement disp; disp.create_from_3D_image(*lambda_sptr); + // Convert displacement to identity deformation std::shared_ptr > deformation_sptr = -// std::make_shared >(); std::make_shared >(disp); - // deformation is represented as a random displacement with min and max of -1 and 1, respectively -// *deformation_sptr = rand_dvf(disp,-1.f,1.f); // We'll need a niftyreg resampler std::shared_ptr > nr_sptr = std::make_shared >(); nr_sptr->set_reference_image(lambda_hat_sptr); nr_sptr->set_floating_image(lambda_sptr); - nr_sptr->set_interpolation_type_to_cubic_spline(); + nr_sptr->set_interpolation_type_to_linear(); nr_sptr->set_padding_value(0.f); nr_sptr->add_transformation(deformation_sptr); @@ -1349,25 +1346,11 @@ int main(int argc, char* argv[]) // Rand voxel inside of image (let's ignore a 30% margin around the edge in x,y,z directions) const float margin = 0.3; - std::random_device rd; // obtain a random number from hardware - std::mt19937 gen(rd()); // seed the generator - const int *dims = lambda_hat_sptr->get_dimensions(); - int rand_idx[3]; - for (unsigned i=0; i<3; ++i) { - const int min_idx = int(margin * float(dims[i+1])); - const int max_idx = dims[i+1] - min_idx; - std::uniform_int_distribution<> distr(min_idx, max_idx); - rand_idx[i] = distr(gen); - rand_idx[i] = 3; - } - // lambda tilde is copy of lambda hat with all voxels = 0, and 1 voxel = cnst + // lambda tilde is copy of lambda hat with all voxels = cnst std::shared_ptr > lambda_tilde_sptr = lambda_hat_sptr->clone(); - lambda_tilde_sptr->fill(1.f); - const float val_range[2] = {1., 10.}; -// const float rand_val = get_rand_val(val_range); - const float rand_val = 1.f; - (*lambda_tilde_sptr)(rand_idx[0],rand_idx[1],rand_idx[2]) = rand_val; + const float fill_val = 2.f; + lambda_tilde_sptr->fill(fill_val); // img grad wrt dvf times image auto dvf1_sptr = resampler.backward(lambda_tilde_sptr); @@ -1377,16 +1360,14 @@ int main(int argc, char* argv[]) dvf2_sptr->fill(0.f); // Need a small permutation, epsilon - const float epsilon_range[2] = {.5f, 1.f}; - const float epsilon = 3.f;//get_rand_val(epsilon_range); + const float epsilon = 3.f; std::shared_ptr > d_shifted_sptr = deformation_sptr->clone(); std::shared_ptr > d_lambda_times_rand_val_sptr = lambda_hat_sptr->clone(); // Loop over the 3 tensor components: u=[x,y,z] for (unsigned i=0; iget_num_voxels(); ++i) { - if (i % 100 == 0) - std::cout << "\ndoing " << i << " of " << lambda_hat_sptr->get_num_voxels() << "\n" << std::flush; + // Skip if there's nothing in that voxel (to be general, but in this case, all voxels are filled) if (std::abs((*lambda_tilde_sptr)(i)) < 1e-4f) continue; for (unsigned u=0; u<3; ++u) { @@ -1397,23 +1378,23 @@ int main(int argc, char* argv[]) std::dynamic_pointer_cast > (resampler.forward(d_shifted_sptr)); *d_lambda_times_rand_val_sptr = (*res_forward - *lambda_hat_sptr); *d_lambda_times_rand_val_sptr /= epsilon; - *d_lambda_times_rand_val_sptr *= rand_val; + *d_lambda_times_rand_val_sptr *= fill_val; dvf2_sptr->add_to_tensor_component(u, d_lambda_times_rand_val_sptr); } + // print progress + if ((i+1) % 100 == 0) + std::cout << "done " << i+1 << " resamples out of " << lambda_hat_sptr->get_num_voxels() << " for numerical gradient test...\n" << std::flush; } - lambda_sptr->write("/Users/rich/Desktop/original_im"); - deformation_sptr->write("/Users/rich/Desktop/original_dvf"); - dvf1_sptr->write("/Users/rich/Desktop/dvf1"); - dvf2_sptr->write("/Users/rich/Desktop/dvf2"); - lambda_tilde_sptr->write("/Users/rich/Desktop/lambda_tilde"); + // Crop by 1 in all directions because we don't want to compare any edge problems + const int *dvf_size = dvf1_sptr->get_dimensions(); + const int dvf_crop_min[7] = {1,1,1,-1,-1,-1,-1}; + const int dvf_crop_max[7] = {dvf_size[1]-2,dvf_size[2]-2,dvf_size[3]-2,-1,-1,-1,-1}; + dvf1_sptr->crop(dvf_crop_min, dvf_crop_max); + dvf2_sptr->crop(dvf_crop_min, dvf_crop_max); - if (*dvf1_sptr != *dvf2_sptr) { - NiftiImageData::print_headers({dvf1_sptr.get(), dvf2_sptr.get()}); + if (*dvf1_sptr != *dvf2_sptr) throw std::runtime_error("im grad wrt def (times image) test failed"); - } - std::cout << "\n\n\n\n\n SUCCESS \n\n\n\n"; - exit(0); std::cout << "// ----------------------------------------------------------------------- //\n"; std::cout << "// Finished im grad wrt def (times image) test.\n"; From 52682a85587958a1095c1a50e569af7662482cd2 Mon Sep 17 00:00:00 2001 From: richard Date: Wed, 22 Jul 2020 11:20:24 +0100 Subject: [PATCH 34/42] python partially implemented --- src/Registration/cReg/cReg.cpp | 91 ++++++++++++++++++ src/Registration/cReg/include/sirf/Reg/cReg.h | 7 ++ src/Registration/cReg/tests/test_cReg.cpp | 7 +- src/Registration/pReg/Reg.py | 56 +++++++++++ src/Registration/pReg/tests/test_pReg.py | 96 +++++++++++++++++++ 5 files changed, 251 insertions(+), 6 deletions(-) diff --git a/src/Registration/cReg/cReg.cpp b/src/Registration/cReg/cReg.cpp index 29a23213c..4b94024cd 100644 --- a/src/Registration/cReg/cReg.cpp +++ b/src/Registration/cReg/cReg.cpp @@ -28,6 +28,7 @@ limitations under the License. #include "sirf/Reg/NiftiImageData3DDeformation.h" #include "sirf/Reg/NiftiImageData3DBSpline.h" #include "sirf/Reg/ControlPointGridToDeformationConverter.h" +#include "sirf/Reg/ImageGradientWRTDeformationTimesImage.h" #include "sirf/Reg/NiftyAladinSym.h" #include "sirf/Reg/NiftyF3dSym.h" #include "sirf/Reg/NiftyResample.h" @@ -74,6 +75,8 @@ void* cReg_newObject(const char* name) return newObjectHandle(std::shared_ptr >(new NiftiImageData3DBSpline)); if (strcmp(name, "ControlPointGridToDeformationConverter") == 0) return newObjectHandle(std::shared_ptr >(new ControlPointGridToDeformationConverter)); + if (strcmp(name, "ImageGradientWRTDeformationTimesImage") == 0) + return newObjectHandle(std::make_shared >()); if (strcmp(name, "NiftyAladinSym") == 0) return newObjectHandle(std::shared_ptr >(new NiftyAladinSym)); if (strcmp(name, "NiftyF3dSym") == 0) @@ -693,6 +696,94 @@ void* cReg_CPG2DVF_backward(const void* converter_ptr, const void* dvf_ptr) CATCH; } +// -------------------------------------------------------------------------------- // +// ImageGradientWRTDeformationTimesImage +// -------------------------------------------------------------------------------- // +extern "C" +void* cReg_ImGradWRTDef_set_resampler(const void* ptr, const void* resampler_ptr) +{ + try { + ImageGradientWRTDeformationTimesImage& im_grad_wrt_def_time_im = + objectFromHandle >(ptr); + std::shared_ptr > resampler_sptr; + getObjectSptrFromHandle >(resampler_ptr, resampler_sptr); + im_grad_wrt_def_time_im.set_resampler(resampler_sptr); + return new DataHandle; + } + CATCH; +} +extern "C" +void* cReg_ImGradWRTDef_forward_in_place( + const void* ptr, const void* deformation_ptr, const void* out_ptr) +{ + try { + ImageGradientWRTDeformationTimesImage& im_grad_wrt_def_time_im = + objectFromHandle >(ptr); + // Get deformation + std::shared_ptr > deformation_sptr; + getObjectSptrFromHandle >(deformation_ptr, deformation_sptr); + // Out + std::shared_ptr out_sptr; + getObjectSptrFromHandle(out_ptr, out_sptr); + // Do it. + im_grad_wrt_def_time_im.forward(out_sptr, deformation_sptr); + return new DataHandle; + } + CATCH; +} +extern "C" +void* cReg_ImGradWRTDef_forward( + const void* ptr, const void* deformation_ptr) +{ + try { + ImageGradientWRTDeformationTimesImage& im_grad_wrt_def_time_im = + objectFromHandle >(ptr); + // Get deformation + std::shared_ptr > deformation_sptr; + getObjectSptrFromHandle >(deformation_ptr, deformation_sptr); + // Do it. + auto out_sptr = im_grad_wrt_def_time_im.forward(deformation_sptr); + return newObjectHandle(out_sptr); + } + CATCH; +} +extern "C" +void* cReg_ImGradWRTDef_backward_in_place( + const void* ptr, const void* image_ptr, const void* out_ptr) +{ + try { + ImageGradientWRTDeformationTimesImage& im_grad_wrt_def_time_im = + objectFromHandle >(ptr); + // Get image + std::shared_ptr image_sptr; + getObjectSptrFromHandle(image_ptr, image_sptr); + // Out dvf (might be null pointer) + std::shared_ptr > out_sptr; + getObjectSptrFromHandle >(out_ptr, out_sptr); + // Do it. + im_grad_wrt_def_time_im.backward(out_sptr, image_sptr); + return new DataHandle; + } + CATCH; +} +extern "C" +void* cReg_ImGradWRTDef_backward( + const void* ptr, const void* image_ptr) +{ + try { + ImageGradientWRTDeformationTimesImage& im_grad_wrt_def_time_im = + objectFromHandle >(ptr); + // Get image + std::shared_ptr image_sptr; + getObjectSptrFromHandle(image_ptr, image_sptr); + // Do it. + auto out_sptr = im_grad_wrt_def_time_im.backward(image_sptr); + return newObjectHandle(out_sptr); + } + CATCH; +} + + // -------------------------------------------------------------------------------- // // Registration // -------------------------------------------------------------------------------- // diff --git a/src/Registration/cReg/include/sirf/Reg/cReg.h b/src/Registration/cReg/include/sirf/Reg/cReg.h index 010468f5b..6772a05ec 100644 --- a/src/Registration/cReg/include/sirf/Reg/cReg.h +++ b/src/Registration/cReg/include/sirf/Reg/cReg.h @@ -87,6 +87,13 @@ extern "C" { void* cReg_CPG2DVF_forward(const void* converter_ptr, const void* cpg_ptr); void* cReg_CPG2DVF_backward(const void* converter_ptr, const void* dvf_ptr); + // ImageGradientWRTDeformationTimesImage + void* cReg_ImGradWRTDef_set_resampler(const void* ptr, const void* resampler_ptr); + void* cReg_ImGradWRTDef_forward_in_place(const void* ptr, const void* deformation_ptr, const void* out_ptr); + void* cReg_ImGradWRTDef_forward(const void* ptr, const void* deformation_ptr); + void* cReg_ImGradWRTDef_backward_in_place(const void* ptr, const void* image_ptr, const void* out_ptr); + void* cReg_ImGradWRTDef_backward(const void* ptr, const void* image_ptr); + // Registration void* cReg_Registration_process(void* ptr); void* cReg_Registration_get_deformation_displacement_image(const void* ptr, const char *transform_type, const int idx); diff --git a/src/Registration/cReg/tests/test_cReg.cpp b/src/Registration/cReg/tests/test_cReg.cpp index 83a98b638..14a27160b 100644 --- a/src/Registration/cReg/tests/test_cReg.cpp +++ b/src/Registration/cReg/tests/test_cReg.cpp @@ -1344,9 +1344,6 @@ int main(int argc, char* argv[]) // lambda hat is forward of lambda resampler.forward(lambda_hat_sptr, deformation_sptr); - // Rand voxel inside of image (let's ignore a 30% margin around the edge in x,y,z directions) - const float margin = 0.3; - // lambda tilde is copy of lambda hat with all voxels = cnst std::shared_ptr > lambda_tilde_sptr = lambda_hat_sptr->clone(); const float fill_val = 2.f; @@ -1373,12 +1370,10 @@ int main(int argc, char* argv[]) for (unsigned u=0; u<3; ++u) { d_shifted_sptr->fill(*deformation_sptr); (*d_shifted_sptr)(u*lambda_hat_sptr->get_num_voxels()+i) += epsilon; - d_lambda_times_rand_val_sptr->fill(*lambda_hat_sptr); std::shared_ptr > res_forward = std::dynamic_pointer_cast > (resampler.forward(d_shifted_sptr)); *d_lambda_times_rand_val_sptr = (*res_forward - *lambda_hat_sptr); - *d_lambda_times_rand_val_sptr /= epsilon; - *d_lambda_times_rand_val_sptr *= fill_val; + *d_lambda_times_rand_val_sptr *= (fill_val/epsilon); dvf2_sptr->add_to_tensor_component(u, d_lambda_times_rand_val_sptr); } // print progress diff --git a/src/Registration/pReg/Reg.py b/src/Registration/pReg/Reg.py index cbf0ab846..6828e3510 100644 --- a/src/Registration/pReg/Reg.py +++ b/src/Registration/pReg/Reg.py @@ -844,6 +844,62 @@ def range_geometry(self): return self.dvf_template +class ImageGradientWRTDeformationTimesImage(object): + """ + Class for converting from control points grids to deformations and vice + versa. + """ + def __init__(self): + self.handle = None + self.name = 'ImageGradientWRTDeformationTimesImage' + self.handle = pyreg.cReg_newObject(self.name) + self.output_of_forward_method = None + check_status(self.handle) + + def __del__(self): + if self.handle is not None: + pyiutil.deleteDataHandle(self.handle) + + def set_resampler(self, resampler): + """Set resampler.""" + assert_validity(resampler, NiftyResample) + try_calling(pyreg.cReg_ImGradWRTDef_set_resampler( + self.handle, resampler.handle)) + self.output_of_forward_method = resampler.reference_image + + def forward(self, deformation, out=None): + """Forward (forward resample with given deformation).""" + assert_validity(deformation, NiftiImageData3DDeformation) + # If we need to create the output + if out is None: + out = self.output_of_forward_method.same_object() + out.handle = pyreg.cReg_ImGradWRTDef_forward( + self.handle, deformation.handle) + check_status(out.handle) + return out + # If in place + else: + assert_validity(out, SIRF.ImageData) + try_calling(pyreg.cReg_ImGradWRTDef_forward_in_place( + self.handle, deformation.handle, out.handle)) + + def backward(self, image, out=None): + """Backward (get im grad wrt deformation times image).""" + assert_validity(image, SIRF.ImageData) + # If we need to create the output + if out is None: + out = NiftiImageData3DDeformation() + out.handle = pyreg.cReg_ImGradWRTDef_backward( + self.handle, image.handle) + check_status(out.handle) + return out + # If in place + else: + assert_validity(out, NiftiImageData3DDeformation) + try_calling(pyreg.cReg_ImGradWRTDef_backward( + self.handle, image.handle, out.handle)) + + class _Registration(ABC): """Abstract base class for registration.""" diff --git a/src/Registration/pReg/tests/test_pReg.py b/src/Registration/pReg/tests/test_pReg.py index fc877d7c8..a89bbdd9e 100644 --- a/src/Registration/pReg/tests/test_pReg.py +++ b/src/Registration/pReg/tests/test_pReg.py @@ -1095,6 +1095,101 @@ def try_cgp_dvf_conversion(na): time.sleep(0.5) +# Im grad wrt def +def try_im_grad_wrt_def_times_im(): + time.sleep(0.5) + sys.stderr.write('\n# --------------------------------------------------------------------------------- #\n') + sys.stderr.write('# Starting im grad wrt def (times image) test...\n') + sys.stderr.write('# --------------------------------------------------------------------------------- #\n') + time.sleep(0.5) + + # Get image to use as template and crop (to reduce number of non-zero voxels) + lambda_im = ref_aladin.deep_copy() + min_idx = [27, 27, 27] + max_idx = [35, 35, 35] + lambda_im.crop(min_idx, max_idx) + + lambda_hat = lambda_im.deep_copy() + # Create blank displacement, same size as lambda + disp = sirf.Reg.NiftiImageData3DDisplacement() + disp.create_from_3D_image(lambda_im) + # Convert displacement to identity deformation + deformation = sirf.Reg.NiftiImageData3DDeformation(disp) + + # We'll need a niftyreg resampler + nr = sirf.Reg.NiftyResample() + nr.set_reference_image(lambda_hat) + nr.set_floating_image(lambda_im) + nr.set_interpolation_type_to_linear() + nr.set_padding_value(0) + nr.add_transformation(deformation) + + # And we'll need the ImageGradientWRTDeformationTimesImage class (call it a resampler) + resampler = sirf.Reg.ImageGradientWRTDeformationTimesImage() + resampler.set_resampler(nr) + + # lambda hat is forward of lambda + resampler.forward(deformation, out=lambda_hat) + + # lambda tilde is copy of lambda hat with all voxels = cnst + lambda_tilde = lambda_hat.deep_copy() + fill_val = 2.0 + lambda_tilde.fill(fill_val) + + # img grad wrt dvf times image + dvf1 = resampler.backward(lambda_tilde) + + # dvf2 is a clone of dvf1. initially filled with zeroes + dvf2 = dvf1.deep_copy() + dvf2.fill(0.0) + + # Need a small permutation, epsilon + epsilon = 3.0 + + # Loop over the 3 tensor components: u=[x,y,z] + d_shifted = deformation.deep_copy() + d_lambda_times_rand_val = lambda_hat.deep_copy() + lambda_tilde_arr = lambda_tilde.as_array() + lambda_tilde_shape = lambda_tilde_arr.shape + deformation_arr = deformation.as_array() + for ix, iy, iz in np.ndindex(lambda_tilde_shape): + # Skip if there's nothing in that voxel (to be general, but in this case, all voxels are filled) + if abs(lambda_tilde_arr[ix, iy, iz] < 1.0e-4): + continue + for iu in range(3): + # Start with d_shifted = deformation + d_shifted_arr = deformation.as_array() + d_shifted_arr[ix, iy, iz, 0, iu] += epsilon + d_shifted.fill(d_shifted_arr) + + res_forward = resampler.forward(d_shifted) + + d_lambda_times_rand_val = res_forward - lambda_hat + d_lambda_times_rand_val *= (fill_val / epsilon) + dvf2.add_to_tensor_component(u, d_lambda_times_rand_val) + + # print progress + i = numpy.ravel_multi_index((ix,iy,iz), lambda_tilde_shape) + if (i+1) % 100 == 0: + print("done " + str(i+1) + " resamples out of " + str(lambda_hat_arr.size) << " for numerical gradient test...") + + # Crop by 1 in all directions because we don't want to compare any edge problems + dvf_size = dvf1.as_array().shape + dvf_crop_min = [1,1,1] + dvf_crop_max = [dvf_size[1]-2,dvf_size[2]-2,dvf_size[3]-2] + dvf1.crop(dvf_crop_min, dvf_crop_max) + dvf2.crop(dvf_crop_min, dvf_crop_max) + + if (dvf1 != dvf2): + raise AssertionError('im grad wrt def (times image) test failed') + + time.sleep(0.5) + sys.stderr.write('\n# --------------------------------------------------------------------------------- #\n') + sys.stderr.write('# Finished im grad wrt def (times image) test.\n') + sys.stderr.write('# --------------------------------------------------------------------------------- #\n') + time.sleep(0.5) + + # AffineTransformation def try_affinetransformation(na): time.sleep(0.5) @@ -1254,6 +1349,7 @@ def test(): # try_niftymomo(na) # try_weighted_mean(na) try_cgp_dvf_conversion(na) + try_im_grad_wrt_def_times_im() # try_affinetransformation(na) # try_quaternion() From ca8ae58a059c3ef25a12483fc00444eb226083a6 Mon Sep 17 00:00:00 2001 From: richard Date: Wed, 22 Jul 2020 14:42:47 +0100 Subject: [PATCH 35/42] continue python implementation and test --- src/Registration/cReg/cReg.cpp | 22 +++++++++++++++++++ .../cReg/include/sirf/Reg/NiftiImageData.h | 4 ++-- src/Registration/cReg/include/sirf/Reg/cReg.h | 2 ++ src/Registration/pReg/Reg.py | 21 ++++++++++++++++++ src/Registration/pReg/tests/test_pReg.py | 4 ++-- 5 files changed, 49 insertions(+), 4 deletions(-) diff --git a/src/Registration/cReg/cReg.cpp b/src/Registration/cReg/cReg.cpp index 4b94024cd..b7f8a64c7 100644 --- a/src/Registration/cReg/cReg.cpp +++ b/src/Registration/cReg/cReg.cpp @@ -578,6 +578,28 @@ void* cReg_NiftiImageData3DTensor_get_tensor_component(const void *ptr, const in } CATCH; } +extern "C" +void* cReg_NiftiImageData3DTensor_tensor_component_maths_im(const void *ptr, const int dim, const void *im_ptr, const int maths_type) +{ + try { + NiftiImageData3DTensor& tensor = objectFromHandle >(ptr); + std::shared_ptr im_sptr; + getObjectSptrFromHandle(im_ptr, im_sptr); + tensor.tensor_component_maths(dim, im_sptr, static_cast::MathsType>(maths_type)); + return new DataHandle; + } + CATCH; +} +extern "C" +void* cReg_NiftiImageData3DTensor_tensor_component_maths_val(const void *ptr, const int dim, const float val, const int maths_type) +{ + try { + NiftiImageData3DTensor& tensor = objectFromHandle >(ptr); + tensor.tensor_component_maths(dim, val, static_cast::MathsType>(maths_type)); + return new DataHandle; + } + CATCH; +} // -------------------------------------------------------------------------------- // // NiftiImageData3DDeformation // -------------------------------------------------------------------------------- // diff --git a/src/Registration/cReg/include/sirf/Reg/NiftiImageData.h b/src/Registration/cReg/include/sirf/Reg/NiftiImageData.h index d8d154d4e..860a0a897 100644 --- a/src/Registration/cReg/include/sirf/Reg/NiftiImageData.h +++ b/src/Registration/cReg/include/sirf/Reg/NiftiImageData.h @@ -427,12 +427,12 @@ class NiftiImageData : public ImageData /// Standardise (subtract mean and divide by standard deviation). void standardise(); + enum MathsType { add, sub, mul }; + protected: enum NiftiImageDataType { _general, _3D, _3DTensor, _3DDisp, _3DDef, _3DBSpl}; - enum MathsType { add, sub, mul }; - /// Image data as a nifti object std::shared_ptr _nifti_image; diff --git a/src/Registration/cReg/include/sirf/Reg/cReg.h b/src/Registration/cReg/include/sirf/Reg/cReg.h index 6772a05ec..28419385f 100644 --- a/src/Registration/cReg/include/sirf/Reg/cReg.h +++ b/src/Registration/cReg/include/sirf/Reg/cReg.h @@ -72,6 +72,8 @@ extern "C" { void* cReg_NiftiImageData3DTensor_construct_from_3_components(const char* obj, const void *x_ptr, const void *y_ptr, const void *z_ptr); void* cReg_NiftiImageData3DTensor_flip_component(const void *ptr, const int dim); void* cReg_NiftiImageData3DTensor_get_tensor_component(const void *ptr, const int dim); + void* cReg_NiftiImageData3DTensor_tensor_component_maths_im(const void *ptr, const int dim, const void *im_ptr, const int maths_type); + void* cReg_NiftiImageData3DTensor_tensor_component_maths_val(const void *ptr, const int dim, const float val, const int maths_type); // NiftiImageData3DDeformation void* cReg_NiftiImageData3DDeformation_compose_single_deformation(const void* im, const char* types, const void* trans_vector_ptr); diff --git a/src/Registration/pReg/Reg.py b/src/Registration/pReg/Reg.py index 6828e3510..a2b0540fb 100644 --- a/src/Registration/pReg/Reg.py +++ b/src/Registration/pReg/Reg.py @@ -613,6 +613,27 @@ def get_tensor_component(self, dim): check_status(output.handle) return output + def tensor_component_maths(self, dim, arg, maths_type): + """Do tensor component maths.""" + if isinstance(arg, SIRF.ImageData): + try_calling( + pyreg.cReg_NiftiImageData3DTensor_tensor_component_maths_im( + self.handle, dim, arg.handle, maths_type)) + elif isnumeric(arg): + try_calling( + pyreg.cReg_NiftiImageData3DTensor_tensor_component_maths_val( + self.handle, dim, float(arg), maths_type)) + else: + raise error("tensor_component_maths: arg should be image or scalar") + + def multiply_tensor_component(self, dim, arg): + """Multiply tensor component with image or value.""" + self.tensor_component_maths(dim, arg, 2) + + def add_to_tensor_component(self, dim, arg): + """Add to tensor component with image or value.""" + self.tensor_component_maths(dim, arg, 0) + class NiftiImageData3DDisplacement(NiftiImageData3DTensor, _Transformation): """Class for 3D displacement nifti image data. diff --git a/src/Registration/pReg/tests/test_pReg.py b/src/Registration/pReg/tests/test_pReg.py index a89bbdd9e..65816e59d 100644 --- a/src/Registration/pReg/tests/test_pReg.py +++ b/src/Registration/pReg/tests/test_pReg.py @@ -1166,10 +1166,10 @@ def try_im_grad_wrt_def_times_im(): d_lambda_times_rand_val = res_forward - lambda_hat d_lambda_times_rand_val *= (fill_val / epsilon) - dvf2.add_to_tensor_component(u, d_lambda_times_rand_val) + dvf2.add_to_tensor_component(iu, d_lambda_times_rand_val) # print progress - i = numpy.ravel_multi_index((ix,iy,iz), lambda_tilde_shape) + i = np.ravel_multi_index((ix,iy,iz), lambda_tilde_shape) if (i+1) % 100 == 0: print("done " + str(i+1) + " resamples out of " + str(lambda_hat_arr.size) << " for numerical gradient test...") From b171ceb7919989e6b8b5a23af11c91e4f2d66fd6 Mon Sep 17 00:00:00 2001 From: richard Date: Wed, 22 Jul 2020 14:49:15 +0100 Subject: [PATCH 36/42] finished --- src/Registration/pReg/tests/test_pReg.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/Registration/pReg/tests/test_pReg.py b/src/Registration/pReg/tests/test_pReg.py index 65816e59d..57e0f08a8 100644 --- a/src/Registration/pReg/tests/test_pReg.py +++ b/src/Registration/pReg/tests/test_pReg.py @@ -1105,8 +1105,8 @@ def try_im_grad_wrt_def_times_im(): # Get image to use as template and crop (to reduce number of non-zero voxels) lambda_im = ref_aladin.deep_copy() - min_idx = [27, 27, 27] - max_idx = [35, 35, 35] + min_idx = [28, 28, 28] + max_idx = [34, 34, 34] lambda_im.crop(min_idx, max_idx) lambda_hat = lambda_im.deep_copy() @@ -1151,6 +1151,7 @@ def try_im_grad_wrt_def_times_im(): d_lambda_times_rand_val = lambda_hat.deep_copy() lambda_tilde_arr = lambda_tilde.as_array() lambda_tilde_shape = lambda_tilde_arr.shape + lambda_tilde_numel = lambda_tilde_arr.size deformation_arr = deformation.as_array() for ix, iy, iz in np.ndindex(lambda_tilde_shape): # Skip if there's nothing in that voxel (to be general, but in this case, all voxels are filled) @@ -1171,7 +1172,7 @@ def try_im_grad_wrt_def_times_im(): # print progress i = np.ravel_multi_index((ix,iy,iz), lambda_tilde_shape) if (i+1) % 100 == 0: - print("done " + str(i+1) + " resamples out of " + str(lambda_hat_arr.size) << " for numerical gradient test...") + print("done " + str(i+1) + " resamples out of " + str(lambda_tilde_numel) + " for numerical gradient test...") # Crop by 1 in all directions because we don't want to compare any edge problems dvf_size = dvf1.as_array().shape From 380f937c665a54e289ccab592fb51e875f087b9a Mon Sep 17 00:00:00 2001 From: richard Date: Wed, 22 Jul 2020 14:49:37 +0100 Subject: [PATCH 37/42] test everything --- src/Registration/pReg/tests/test_pReg.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/Registration/pReg/tests/test_pReg.py b/src/Registration/pReg/tests/test_pReg.py index 57e0f08a8..f2a1666a8 100644 --- a/src/Registration/pReg/tests/test_pReg.py +++ b/src/Registration/pReg/tests/test_pReg.py @@ -1338,21 +1338,21 @@ def try_quaternion(): def test(): - # try_niftiimage() - # try_niftiimage3d() - # try_niftiimage3dtensor() - # try_niftiimage3ddisplacement() - # try_niftiimage3ddeformation() + try_niftiimage() + try_niftiimage3d() + try_niftiimage3dtensor() + try_niftiimage3ddisplacement() + try_niftiimage3ddeformation() na = try_niftyaladin() - # try_niftyf3d() - # try_transformations(na) - # try_resample(na) - # try_niftymomo(na) - # try_weighted_mean(na) + try_niftyf3d() + try_transformations(na) + try_resample(na) + try_niftymomo(na) + try_weighted_mean(na) try_cgp_dvf_conversion(na) try_im_grad_wrt_def_times_im() - # try_affinetransformation(na) - # try_quaternion() + try_affinetransformation(na) + try_quaternion() if __name__ == "__main__": From fae6b20384e0575a7c06a111c262f56cf06903af Mon Sep 17 00:00:00 2001 From: richard Date: Wed, 22 Jul 2020 14:53:04 +0100 Subject: [PATCH 38/42] undo comparison against image mean --- src/Registration/cReg/NiftiImageData.cpp | 28 +++++++++---------- .../cReg/include/sirf/Reg/NiftiImageData.h | 2 +- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/Registration/cReg/NiftiImageData.cpp b/src/Registration/cReg/NiftiImageData.cpp index 880dd07fb..7717a62b5 100644 --- a/src/Registration/cReg/NiftiImageData.cpp +++ b/src/Registration/cReg/NiftiImageData.cpp @@ -1516,7 +1516,7 @@ get_inner_product(const NiftiImageData &other) const } template -bool NiftiImageData::are_equal_to_given_accuracy(const NiftiImageData &im1, const NiftiImageData &im2, const float required_accuracy_compared_to_mean) +bool NiftiImageData::are_equal_to_given_accuracy(const NiftiImageData &im1, const NiftiImageData &im2, const float required_accuracy_compared_to_max) { if(!im1.is_initialised()) throw std::runtime_error("NiftiImageData::are_equal_to_given_accuracy: Image 1 not initialised."); @@ -1531,8 +1531,8 @@ bool NiftiImageData::are_equal_to_given_accuracy(const NiftiImageData // Get required accuracy compared to the image maxes float norm; - float epsilon = (std::abs(im1.get_mean())+std::abs(im2.get_mean()))/2.F; - epsilon *= required_accuracy_compared_to_mean; + float epsilon = (std::abs(im1.get_max())+std::abs(im2.get_max()))/2.F; + epsilon *= required_accuracy_compared_to_max; // If metadata match, get the norm if (do_nifti_image_metadata_match(im1,im2, false)) @@ -1559,17 +1559,17 @@ bool NiftiImageData::are_equal_to_given_accuracy(const NiftiImageData return true; std::cout << "\nImages are not equal (norm > epsilon).\n"; - std::cout << "\tmax1 = " << im1.get_max() << "\n"; - std::cout << "\tmax2 = " << im2.get_max() << "\n"; - std::cout << "\tmin1 = " << im1.get_min() << "\n"; - std::cout << "\tmin2 = " << im2.get_min() << "\n"; - std::cout << "\tmean1 = " << im1.get_mean() << "\n"; - std::cout << "\tmean2 = " << im2.get_mean() << "\n"; - std::cout << "\tstandard deviation1 = " << im1.get_standard_deviation() << "\n"; - std::cout << "\tstandard deviation2 = " << im2.get_standard_deviation() << "\n"; - std::cout << "\trequired accuracy compared to mean = " << required_accuracy_compared_to_mean << "\n"; - std::cout << "\tepsilon = " << epsilon << "\n"; - std::cout << "\tnorm/num_vox = " << norm << "\n"; + std::cout << "\tmax1 = " << im1.get_max() << "\n"; + std::cout << "\tmax2 = " << im2.get_max() << "\n"; + std::cout << "\tmin1 = " << im1.get_min() << "\n"; + std::cout << "\tmin2 = " << im2.get_min() << "\n"; + std::cout << "\tmean1 = " << im1.get_mean() << "\n"; + std::cout << "\tmean2 = " << im2.get_mean() << "\n"; + std::cout << "\tstandard deviation1 = " << im1.get_standard_deviation() << "\n"; + std::cout << "\tstandard deviation2 = " << im2.get_standard_deviation() << "\n"; + std::cout << "\trequired accuracy compared to max = " << required_accuracy_compared_to_max << "\n"; + std::cout << "\tepsilon = " << epsilon << "\n"; + std::cout << "\tnorm/num_vox = " << norm << "\n"; return false; } diff --git a/src/Registration/cReg/include/sirf/Reg/NiftiImageData.h b/src/Registration/cReg/include/sirf/Reg/NiftiImageData.h index 860a0a897..c2190408d 100644 --- a/src/Registration/cReg/include/sirf/Reg/NiftiImageData.h +++ b/src/Registration/cReg/include/sirf/Reg/NiftiImageData.h @@ -377,7 +377,7 @@ class NiftiImageData : public ImageData int get_original_datatype() const { return _original_datatype; } /// Check if the norms of two images are equal to a given accuracy. - static bool are_equal_to_given_accuracy(const NiftiImageData &im1, const NiftiImageData &im2, const float required_accuracy_compared_to_mean); + static bool are_equal_to_given_accuracy(const NiftiImageData &im1, const NiftiImageData &im2, const float required_accuracy_compared_to_max); /// Point is in bounds? bool is_in_bounds(const int index[7]) const; From d5d435fbdc7b6e0f5697c9e61c3f0b6b68baca34 Mon Sep 17 00:00:00 2001 From: richard Date: Wed, 22 Jul 2020 18:46:51 +0100 Subject: [PATCH 39/42] registration attempt --- examples/Python/PETMR/registration.py | 276 ++++++++++++++++++++++++++ 1 file changed, 276 insertions(+) create mode 100644 examples/Python/PETMR/registration.py diff --git a/examples/Python/PETMR/registration.py b/examples/Python/PETMR/registration.py new file mode 100644 index 000000000..ca8a5f925 --- /dev/null +++ b/examples/Python/PETMR/registration.py @@ -0,0 +1,276 @@ +"""Registration demo. + +Usage: + registration [--help | options] + +Options: + -R , --ref= reference image + -F , --flo= floating image + --cpg_downsample= factor to downsample control point grid spacing + relative to reference image (e.g., 2 will mean + cpg spacing will be double that of reference + image) [default: 2] + --templ_sino= template sinogram + --rands= randoms sinogram + --attn= attenuation image + --norm= ECAT8 normalisation file + --num_iters= number of iterations [default: 10] + --num_subsets= number of subsets for PET projection [default: 21] + --space= space in which to perform registration (image or + sinogram) [default: sinogram] + --gpu use GPU +""" + +# SyneRBI Synergistic Image Reconstruction Framework (SIRF) +# Copyright 2015 - 2020 University College London. +# +# This is software developed for the Collaborative Computational +# Project in Synergistic Reconstruction for Biomedical Imaging +# (formerly CCP PETMR) +# (http://www.ccpsynerbi.ac.uk/). +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from docopt import docopt +from tqdm import tqdm +from sirf.Utilities import error +import sirf.STIR +import sirf.Reg + +__version__ = '0.1.0' +args = docopt(__doc__, version=__version__) + + +# process command-line options +ref_file = args['--ref'] +flo_file = args['--flo'] +ref_eng = args['--ref_eng'] +flo_eng = args['--flo_eng'] +cpg_downsample_factor = float(args['--cpg_downsample']) +sino_file = args['--templ_sino'] +rands_file = args['--rands'] +attn_file = args['--attn'] +norm_file = args['--norm'] +num_iters = float(args['--num_iters']) +num_subsets = float(args['--num_subsets']) +registration_space = args['--space'] +if registration_space not in {'sinogram', 'image'}: + raise error("Unknown registration space: " + registration_space) +use_gpu = True if args['--gpu'] else False + + +def check_file_exists(fname): + """Check file exists. Else, throw error.""" + if not os.path.isfile(filename): + raise error("File not found: " + filename) + + +def get_image(filename): + """Get an image from its filename.""" + check_file_exists(filename) + return sirf.STIR.ImageData(filename) + + +def get_sinogram(filename): + """Get an sinogram from its filename.""" + check_file_exists(filename) + return sirf.STIR.AcquisitionData(filename) + + +def get_cpg_2_dvf_converter(ref): + """Get CPG 2 DVF converter.""" + cpg_spacing = ref.get_voxel_sizes()[1:4] * cpg_downsample_factor + cpg_2_dvf_converter = sirf.Reg.ControlPointGridToDeformationConverter() + cpg_2_dvf_converter.set_cpg_spacing(cpg_spacing) + cpg_2_dvf_converter.set_reference_image(ref) + return cpg_2_dvf_converter + + +def get_dvf(ref): + """Get initial DVF.""" + disp = sirf.Reg.NiftiImageData3DDisplacement() + disp.create_from_3D_image(ref) + disp.fill(0) + dvf = sirf.Reg.NiftiImageData3DDeformation(disp) + return dvf + + +def get_resampler(ref, flo): + """Get resampler.""" + nr = sirf.Reg.NiftyResample() + nr.set_reference_image(ref) + nr.set_floating_image(flo) + nr.set_interpolation_type_to_linear() + nr.set_padding_value(0) + return nr + + +def get_asm_norm(norm_file): + if norm_file is None: + return None + check_file_exists(norm_file) + asm_norm = pet.AcquisitionSensitivityModel(norm_file) + return asm_norm + + +def get_asm_attn(sino, attn, acq_model): + """Get attn ASM from sino, attn image and acq model""" + if attn is None: + return None + asm_attn = pet.AcquisitionSensitivityModel(attn, acq_model) + # temporary fix pending attenuation offset fix in STIR: + # converting attenuation into 'bin efficiency' + asm_attn.set_up(sino) + bin_eff = pet.AcquisitionData(sino) + bin_eff.fill(1.0) + asm_attn.unnormalise(bin_eff) + asm_attn = pet.AcquisitionSensitivityModel(bin_eff) + return asm_attn + + +def get_acq_model(ref, sino): + """Get acquisition model.""" + if not use_gpu: + acq_model = sirf.Reg.AcquisitionModelUsingRayTracingMatrix() + else: + acq_model = sirf.Reg.AcquisitionModelUsingNiftyPET() + acq_model.set_use_truncation(True) + acq_model.set_cuda_verbosity(0) + + # Add randoms if desired + if rands_file: + rands = get_sinogram(rands_file) + acq_model.set_background_term(rands) + + return acq_model + + +def update_asm(acq_model, attn, asm_norm, sino, ref): + """Update acquisition sensitivity models.""" + + # Create attn ASM if necessary + asm_attn = get_asm_attn(attn) + + asm = None + if asm_norm and asm_attn: + asm = pet.AcquisitionSensitivityModel(asm_norm, asm_attn) + elif asm_norm: + asm = asm_norm + elif asm_attn: + asm = asm_attn + if asm: + acq_model.set_acquisition_sensitivity(asm) + + # Set up + acq_model.set_up(sino, ref) + + +def update_alpha_image(reference, current_estimate, + cpg_2_dvf_converter, + resampler, bsplines, alpha): + """Update alpha (b-spline CPG) in image space.""" + # Convert alpha to dvf + dvf = cpg_2_dvf_converter.forward(alpha) + # Get current estimate + transformed_estimate = resampler.forward(dvf, current_estimate) + # Get gradient of alpha + grad_obj_fn = (reference - transformed_estimate) * transformed_estimate + grad_dvf = resampler.backward(grad_obj_fn) + grad_alpha = cpg_2_dvf_converter.backward(grad_dvf) + return grad_alpha + + +def update_alpha_sino(reference, current_estimate, + cpg_2_dvf_converter, + resampler, bsplines, alpha, + attn, acq_model, asm_norm, sino): + """Update alpha (b-spline CPG) in sinogram space.""" + # Convert alpha to dvf + dvf = cpg_2_dvf_converter.forward(alpha) + # Get current estimate (need to resample both emission and attn) + transformed_estimate = resampler.forward(dvf, current_estimate) + transformed_attn = resampler.forward(dvf, attn) + + # Update the acq_model and objective function + # with the updated attenuation image. + update_asm(acq_model, transformed_attn, asm_norm, sino, ref) + + estimated_projection = acq_model.forward(transformed_estimate) + estimated_data = estimated_projection + acq_model.get_background_term() + grad_obj = -measured_data / estimated_data + 1 + grad_emission_image = acq_model.backward(grad_obj) + # New backward method + grad_dvf_em = resampler.backward(grad_emission_image) + # atn_acq_model - AcqModel for everything incl. attn + grad_attenuation_image = -acq_model.backward( + grad_obj * estimated_projection) + # New backward but now with attn. image + grad_dvf_atn = resampler.backward(grad_attenuation_image) + grad_dvf = grad_dvf_em + grad_dvf_atn + grad_alpha = cpg_2_dvf_converter.backward(grad_dvf) + return grad_alpha + + +def main(): + """Do the main function.""" + + # Read input files + ref = get_image(ref_file, ref_eng) + flo = get_image(flo_file, ref_eng) + template_sino = get_sinogram(sino_file) + attn = get_image(attn_file) if attn_file else None + asm_norm = get_asm_norm(norm_file) + + # Ability to convert between CPG and DVF + cpg_2_dvf_converter = get_cpg_2_dvf_converter(ref) + + # Create DVF and CPG (called alpha) + dvf = get_dvf(ref) + alpha = cpg_2_dvf_converter.backward(dvf) + + # We'll need a niftyreg resampler + nr = get_resampler(ref, flo) + + # And we'll need the ImageGradientWRTDeformationTimesImage class + # (which we'll call the resampler) + resampler = sirf.Reg.ImageGradientWRTDeformationTimesImage() + resampler.set_resampler(nr) + + # Get likelihood + acq_model = get_acq_model(ref, template_sino) + update_asm(acq_model, attn, asm_norm, template_sino, ref) + + # Registered image starts as copy of reference image + registered_im = ref.copy() + + sino = acq_model.forward(ref) + simulated_sino = sino.copy() + + # Optimisation loop + for iter in tqdm(range(num_iters)): + + # update alpha (b-spline CPG) in image or sinogram space + if regularisation_space == 'image': + alpha = update_alpha_image(reference, current_estimate, + cpg_2_dvf_converter, + resampler, alpha) + else: + alpha = update_alpha_sino(reference, current_estimate, + cpg_2_dvf_converter, + resampler, alpha, + attn, acq_model, asm_norm, sino) + + alpha + step * grad_alpha + + +if __name__ == "__main__": + main() From 22b8c9cfefa19330f7387a9603a59e510389d121 Mon Sep 17 00:00:00 2001 From: richard Date: Thu, 23 Jul 2020 18:32:59 +0100 Subject: [PATCH 40/42] grad wrt dvf specify all arguments --- .../ImageGradientWRTDeformationTimesImage.cpp | 21 +++++--- src/Registration/cReg/NiftyResample.cpp | 33 ++++++++---- src/Registration/cReg/cReg.cpp | 53 ++++++++++++++----- .../ImageGradientWRTDeformationTimesImage.h | 17 ++++-- .../cReg/include/sirf/Reg/NiftyResample.h | 11 +++- src/Registration/cReg/include/sirf/Reg/cReg.h | 8 +-- src/Registration/cReg/tests/test_cReg.cpp | 6 +-- src/Registration/pReg/Reg.py | 23 +++++--- src/Registration/pReg/tests/test_pReg.py | 8 +-- 9 files changed, 126 insertions(+), 54 deletions(-) diff --git a/src/Registration/cReg/ImageGradientWRTDeformationTimesImage.cpp b/src/Registration/cReg/ImageGradientWRTDeformationTimesImage.cpp index 56719e8c6..90fb7ba7c 100644 --- a/src/Registration/cReg/ImageGradientWRTDeformationTimesImage.cpp +++ b/src/Registration/cReg/ImageGradientWRTDeformationTimesImage.cpp @@ -45,7 +45,8 @@ template void ImageGradientWRTDeformationTimesImage:: forward(std::shared_ptr im_sptr, - const std::shared_ptr > deformation_sptr) + const std::shared_ptr > deformation_sptr, + const std::shared_ptr in_sptr) { _resampler_sptr->clear_transformations(); _resampler_sptr->add_transformation(deformation_sptr); @@ -56,7 +57,8 @@ forward(std::shared_ptr im_sptr, template std::shared_ptr ImageGradientWRTDeformationTimesImage:: -forward(const std::shared_ptr > deformation_sptr) +forward(const std::shared_ptr > deformation_sptr, + const std::shared_ptr in_sptr) { _resampler_sptr->clear_transformations(); _resampler_sptr->add_transformation(deformation_sptr); @@ -67,17 +69,24 @@ forward(const std::shared_ptr > defo template void ImageGradientWRTDeformationTimesImage:: -backward(std::shared_ptr > &output_deformation_sptr, const std::shared_ptr image_to_multiply_sptr) +backward(std::shared_ptr > &output_deformation_sptr, + const std::shared_ptr > &input_deformation_sptr, + const std::shared_ptr &image_for_gradient_sptr, + const std::shared_ptr &image_to_multiply_sptr) { - _resampler_sptr->get_image_gradient_wrt_deformation_times_image(output_deformation_sptr, image_to_multiply_sptr); + _resampler_sptr->get_image_gradient_wrt_deformation_times_image( + output_deformation_sptr, input_deformation_sptr, image_for_gradient_sptr, image_to_multiply_sptr); } template std::shared_ptr > ImageGradientWRTDeformationTimesImage:: -backward(const std::shared_ptr image_to_multiply_sptr) +backward(const std::shared_ptr > &input_deformation_sptr, + const std::shared_ptr &image_for_gradient_sptr, + const std::shared_ptr &image_to_multiply_sptr) { - return std::move(_resampler_sptr->get_image_gradient_wrt_deformation_times_image(image_to_multiply_sptr)); + return std::move(_resampler_sptr->get_image_gradient_wrt_deformation_times_image( + input_deformation_sptr, image_for_gradient_sptr, image_to_multiply_sptr)); } namespace sirf { diff --git a/src/Registration/cReg/NiftyResample.cpp b/src/Registration/cReg/NiftyResample.cpp index 84f9e5b34..b35668fc8 100644 --- a/src/Registration/cReg/NiftyResample.cpp +++ b/src/Registration/cReg/NiftyResample.cpp @@ -356,26 +356,38 @@ void NiftyResample:: get_image_gradient_wrt_deformation_times_image( std::shared_ptr > &output_deformation_sptr, - const std::shared_ptr image_to_multiply_sptr) + const std::shared_ptr > &input_deformation_sptr, + const std::shared_ptr &image_for_gradient_sptr, + const std::shared_ptr &image_to_multiply_sptr) { // Call the set up set_up(); + // Check metadata of input deformation matches that used for set up + if (!NiftiImageData::do_nifti_image_metadata_match(*this->_deformation_sptr,*input_deformation_sptr,false)) + throw std::runtime_error("NiftyResample::get_image_gradient_wrt_deformation_times_image: Metadata of input deformation should match that used for set up."); + // Only tested for linear interpolation if (this->_interpolation_type != Resample::LINEAR) throw std::runtime_error("NiftyResample::get_image_gradient_wrt_deformation_times_image only implemented for linear interpolation"); // Not implemented for complex images - if (this->_floating_image_niftis.is_complex() || image_to_multiply_sptr->is_complex()) + if (image_for_gradient_sptr->is_complex() || image_to_multiply_sptr->is_complex()) throw std::runtime_error("NiftyResample::get_image_gradient_wrt_deformation_times_image not yet implemented for complex images"); - // Get real part of floating image - std::shared_ptr > floating_sptr = this->_floating_image_niftis.real()->clone(); + // Get image for gradient as nifti and check metadata matches + std::shared_ptr > image_for_gradient_as_nifti_sptr = + std::make_shared >(*image_for_gradient_sptr); + if (!NiftiImageData::do_nifti_image_metadata_match(*image_for_gradient_as_nifti_sptr,*this->_floating_image_niftis.real(),false)) + throw std::runtime_error("NiftyResample::get_image_gradient_wrt_deformation_times_image: Metadata of image for gradient should match that used for set up."); + + // Clone input deformation as niftyreg method not marked const + auto input_deformation_clone_sptr = input_deformation_sptr->clone(); // Get image gradient - reg_getImageGradient(floating_sptr->get_raw_nifti_sptr().get(), + reg_getImageGradient(image_for_gradient_as_nifti_sptr->get_raw_nifti_sptr().get(), output_deformation_sptr->get_raw_nifti_sptr().get(), - this->_deformation_sptr->get_raw_nifti_sptr().get(), + input_deformation_clone_sptr->get_raw_nifti_sptr().get(), nullptr, this->_interpolation_type, this->_padding_value, @@ -389,20 +401,23 @@ get_image_gradient_wrt_deformation_times_image( output_deformation_sptr->multiply_tensor_component(i, image_to_multiply_sptr); // divide by spacing to get to mm output_deformation_sptr->multiply_tensor_component(i, 1.f/im_spacing[i+1]); - } } template std::shared_ptr > NiftyResample:: -get_image_gradient_wrt_deformation_times_image(const std::shared_ptr image_to_multiply_sptr) +get_image_gradient_wrt_deformation_times_image( + const std::shared_ptr > &input_deformation_sptr, + const std::shared_ptr &image_for_gradient_sptr, + const std::shared_ptr &image_to_multiply_sptr) { // Call the set up set_up(); std::shared_ptr > output_deformation_sptr = this->_deformation_sptr->clone(); - get_image_gradient_wrt_deformation_times_image(output_deformation_sptr, image_to_multiply_sptr); + get_image_gradient_wrt_deformation_times_image(output_deformation_sptr, input_deformation_sptr, + image_for_gradient_sptr, image_to_multiply_sptr); return std::move(output_deformation_sptr); } diff --git a/src/Registration/cReg/cReg.cpp b/src/Registration/cReg/cReg.cpp index b7f8a64c7..27090a051 100644 --- a/src/Registration/cReg/cReg.cpp +++ b/src/Registration/cReg/cReg.cpp @@ -736,7 +736,7 @@ void* cReg_ImGradWRTDef_set_resampler(const void* ptr, const void* resampler_ptr } extern "C" void* cReg_ImGradWRTDef_forward_in_place( - const void* ptr, const void* deformation_ptr, const void* out_ptr) + const void* ptr, const void* deformation_ptr, const void* in_ptr, const void* out_ptr) { try { ImageGradientWRTDeformationTimesImage& im_grad_wrt_def_time_im = @@ -744,18 +744,21 @@ void* cReg_ImGradWRTDef_forward_in_place( // Get deformation std::shared_ptr > deformation_sptr; getObjectSptrFromHandle >(deformation_ptr, deformation_sptr); + // input image + std::shared_ptr in_sptr; + getObjectSptrFromHandle(in_ptr, in_sptr); // Out std::shared_ptr out_sptr; getObjectSptrFromHandle(out_ptr, out_sptr); // Do it. - im_grad_wrt_def_time_im.forward(out_sptr, deformation_sptr); + im_grad_wrt_def_time_im.forward(out_sptr, deformation_sptr, in_sptr); return new DataHandle; } CATCH; } extern "C" void* cReg_ImGradWRTDef_forward( - const void* ptr, const void* deformation_ptr) + const void* ptr, const void* deformation_ptr, const void* in_ptr) { try { ImageGradientWRTDeformationTimesImage& im_grad_wrt_def_time_im = @@ -763,43 +766,65 @@ void* cReg_ImGradWRTDef_forward( // Get deformation std::shared_ptr > deformation_sptr; getObjectSptrFromHandle >(deformation_ptr, deformation_sptr); + // input image + std::shared_ptr in_sptr; + getObjectSptrFromHandle(in_ptr, in_sptr); // Do it. - auto out_sptr = im_grad_wrt_def_time_im.forward(deformation_sptr); + auto out_sptr = im_grad_wrt_def_time_im.forward(deformation_sptr, in_sptr); return newObjectHandle(out_sptr); } CATCH; } extern "C" void* cReg_ImGradWRTDef_backward_in_place( - const void* ptr, const void* image_ptr, const void* out_ptr) + const void* ptr, const void* input_deformation_ptr, + const void* image_for_gradient_ptr, + const void* image_to_multiply_ptr, + const void* out_ptr) { try { ImageGradientWRTDeformationTimesImage& im_grad_wrt_def_time_im = objectFromHandle >(ptr); - // Get image - std::shared_ptr image_sptr; - getObjectSptrFromHandle(image_ptr, image_sptr); + // Get image for multiplication + std::shared_ptr image_to_multiply_sptr; + getObjectSptrFromHandle(image_to_multiply_ptr, image_to_multiply_sptr); + // Get image for gradient + std::shared_ptr image_for_gradient_sptr; + getObjectSptrFromHandle(image_for_gradient_ptr, image_for_gradient_sptr); + // Input dvf + std::shared_ptr > input_deformation_sptr; + getObjectSptrFromHandle >(input_deformation_ptr, input_deformation_sptr); // Out dvf (might be null pointer) std::shared_ptr > out_sptr; getObjectSptrFromHandle >(out_ptr, out_sptr); // Do it. - im_grad_wrt_def_time_im.backward(out_sptr, image_sptr); + im_grad_wrt_def_time_im.backward(out_sptr, input_deformation_sptr, + image_for_gradient_sptr, image_to_multiply_sptr); return new DataHandle; } CATCH; } extern "C" void* cReg_ImGradWRTDef_backward( - const void* ptr, const void* image_ptr) + const void* ptr, + const void* input_deformation_ptr, + const void* image_for_gradient_ptr, + const void* image_to_multiply_ptr) { try { ImageGradientWRTDeformationTimesImage& im_grad_wrt_def_time_im = objectFromHandle >(ptr); - // Get image - std::shared_ptr image_sptr; - getObjectSptrFromHandle(image_ptr, image_sptr); + // Get image for multiplication + std::shared_ptr image_to_multiply_sptr; + getObjectSptrFromHandle(image_to_multiply_ptr, image_to_multiply_sptr); + // Get image for gradient + std::shared_ptr image_for_gradient_sptr; + getObjectSptrFromHandle(image_for_gradient_ptr, image_for_gradient_sptr); + // Input dvf + std::shared_ptr > input_deformation_sptr; + getObjectSptrFromHandle >(input_deformation_ptr, input_deformation_sptr); // Do it. - auto out_sptr = im_grad_wrt_def_time_im.backward(image_sptr); + auto out_sptr = im_grad_wrt_def_time_im.backward(input_deformation_sptr, image_for_gradient_sptr, image_to_multiply_sptr); return newObjectHandle(out_sptr); } CATCH; diff --git a/src/Registration/cReg/include/sirf/Reg/ImageGradientWRTDeformationTimesImage.h b/src/Registration/cReg/include/sirf/Reg/ImageGradientWRTDeformationTimesImage.h index a714109e0..1ac320ea3 100644 --- a/src/Registration/cReg/include/sirf/Reg/ImageGradientWRTDeformationTimesImage.h +++ b/src/Registration/cReg/include/sirf/Reg/ImageGradientWRTDeformationTimesImage.h @@ -56,20 +56,29 @@ class ImageGradientWRTDeformationTimesImage void set_resampler(const std::shared_ptr > resampler_sptr); /// Forward in place (resample image) - virtual void forward(std::shared_ptr im_sptr, const std::shared_ptr > deformation_sptr); + virtual void forward(std::shared_ptr out_sptr, const std::shared_ptr > deformation_sptr, const std::shared_ptr in_sptr); /// Forward (resample image) - virtual std::shared_ptr forward(const std::shared_ptr > deformation_sptr); + virtual std::shared_ptr forward(const std::shared_ptr > deformation_sptr, const std::shared_ptr in_sptr); /// Backward in place (get image gradient wrt transformation) - virtual void backward(std::shared_ptr > &output_transformation_sptr, const std::shared_ptr image_to_multiply_sptr); + virtual void backward( + std::shared_ptr > &output_transformation_sptr, + const std::shared_ptr > &deformation_sptr, + const std::shared_ptr &image_for_gradient_sptr, + const std::shared_ptr &image_to_multiply_sptr); /// Backward (get image gradient wrt transformation) - virtual std::shared_ptr > backward(const std::shared_ptr image_to_multiply_sptr); + virtual std::shared_ptr > backward( + const std::shared_ptr > &deformation_sptr, + const std::shared_ptr &image_for_gradient_sptr, + const std::shared_ptr &image_to_multiply_sptr); private: /// Resampler std::shared_ptr > _resampler_sptr; + /// Deformation + std::shared_ptr > _deformation_sptr; }; } diff --git a/src/Registration/cReg/include/sirf/Reg/NiftyResample.h b/src/Registration/cReg/include/sirf/Reg/NiftyResample.h index 3cf4217b1..95c10e123 100644 --- a/src/Registration/cReg/include/sirf/Reg/NiftyResample.h +++ b/src/Registration/cReg/include/sirf/Reg/NiftyResample.h @@ -124,10 +124,17 @@ class NiftyResample : public Resample virtual void adjoint(std::shared_ptr output_sptr, const std::shared_ptr input_sptr); /// Get image gradient wrt transformation in place - virtual void get_image_gradient_wrt_deformation_times_image(std::shared_ptr > &output_deformation_sptr, const std::shared_ptr image_to_multiply_sptr); + virtual void get_image_gradient_wrt_deformation_times_image( + std::shared_ptr > &output_deformation_sptr, + const std::shared_ptr > &input_deformation_sptr, + const std::shared_ptr &image_for_gradient_sptr, + const std::shared_ptr &image_to_multiply_sptr); /// Get image gradient wrt transformation - virtual std::shared_ptr > get_image_gradient_wrt_deformation_times_image(const std::shared_ptr image_to_multiply_sptr); + virtual std::shared_ptr > get_image_gradient_wrt_deformation_times_image( + const std::shared_ptr > &input_deformation_sptr, + const std::shared_ptr &image_for_gradient_sptr, + const std::shared_ptr &image_to_multiply_sptr); protected: diff --git a/src/Registration/cReg/include/sirf/Reg/cReg.h b/src/Registration/cReg/include/sirf/Reg/cReg.h index 28419385f..5578a1c70 100644 --- a/src/Registration/cReg/include/sirf/Reg/cReg.h +++ b/src/Registration/cReg/include/sirf/Reg/cReg.h @@ -91,10 +91,10 @@ extern "C" { // ImageGradientWRTDeformationTimesImage void* cReg_ImGradWRTDef_set_resampler(const void* ptr, const void* resampler_ptr); - void* cReg_ImGradWRTDef_forward_in_place(const void* ptr, const void* deformation_ptr, const void* out_ptr); - void* cReg_ImGradWRTDef_forward(const void* ptr, const void* deformation_ptr); - void* cReg_ImGradWRTDef_backward_in_place(const void* ptr, const void* image_ptr, const void* out_ptr); - void* cReg_ImGradWRTDef_backward(const void* ptr, const void* image_ptr); + void* cReg_ImGradWRTDef_forward_in_place(const void* ptr, const void* deformation_ptr, const void* in_ptr, const void* out_ptr); + void* cReg_ImGradWRTDef_forward(const void* ptr, const void* deformation_ptr, const void* in_ptr); + void* cReg_ImGradWRTDef_backward_in_place(const void* ptr, const void* input_deformation_ptr, const void* image_for_gradient_ptr, const void* image_to_multiply_ptr, const void* out_ptr); + void* cReg_ImGradWRTDef_backward(const void* ptr, const void* input_deformation_ptr, const void* image_for_gradient_ptr, const void* image_to_multiply_ptr); // Registration void* cReg_Registration_process(void* ptr); diff --git a/src/Registration/cReg/tests/test_cReg.cpp b/src/Registration/cReg/tests/test_cReg.cpp index 14a27160b..97ea65ba9 100644 --- a/src/Registration/cReg/tests/test_cReg.cpp +++ b/src/Registration/cReg/tests/test_cReg.cpp @@ -1342,7 +1342,7 @@ int main(int argc, char* argv[]) resampler.set_resampler(nr_sptr); // lambda hat is forward of lambda - resampler.forward(lambda_hat_sptr, deformation_sptr); + resampler.forward(lambda_hat_sptr, deformation_sptr, lambda_sptr); // lambda tilde is copy of lambda hat with all voxels = cnst std::shared_ptr > lambda_tilde_sptr = lambda_hat_sptr->clone(); @@ -1350,7 +1350,7 @@ int main(int argc, char* argv[]) lambda_tilde_sptr->fill(fill_val); // img grad wrt dvf times image - auto dvf1_sptr = resampler.backward(lambda_tilde_sptr); + auto dvf1_sptr = resampler.backward(deformation_sptr, lambda_sptr, lambda_tilde_sptr); // dvf2 is a clone of dvf1. initially filled with zeroes auto dvf2_sptr = dvf1_sptr->clone(); @@ -1371,7 +1371,7 @@ int main(int argc, char* argv[]) d_shifted_sptr->fill(*deformation_sptr); (*d_shifted_sptr)(u*lambda_hat_sptr->get_num_voxels()+i) += epsilon; std::shared_ptr > res_forward = - std::dynamic_pointer_cast > (resampler.forward(d_shifted_sptr)); + std::dynamic_pointer_cast > (resampler.forward(d_shifted_sptr, lambda_sptr)); *d_lambda_times_rand_val_sptr = (*res_forward - *lambda_hat_sptr); *d_lambda_times_rand_val_sptr *= (fill_val/epsilon); dvf2_sptr->add_to_tensor_component(u, d_lambda_times_rand_val_sptr); diff --git a/src/Registration/pReg/Reg.py b/src/Registration/pReg/Reg.py index a2b0540fb..47dac5976 100644 --- a/src/Registration/pReg/Reg.py +++ b/src/Registration/pReg/Reg.py @@ -888,37 +888,44 @@ def set_resampler(self, resampler): self.handle, resampler.handle)) self.output_of_forward_method = resampler.reference_image - def forward(self, deformation, out=None): + def forward(self, deformation, in_image, out=None): """Forward (forward resample with given deformation).""" assert_validity(deformation, NiftiImageData3DDeformation) + assert_validity(in_image, SIRF.ImageData) # If we need to create the output if out is None: out = self.output_of_forward_method.same_object() out.handle = pyreg.cReg_ImGradWRTDef_forward( - self.handle, deformation.handle) + self.handle, deformation.handle, in_image.handle) check_status(out.handle) return out # If in place else: assert_validity(out, SIRF.ImageData) try_calling(pyreg.cReg_ImGradWRTDef_forward_in_place( - self.handle, deformation.handle, out.handle)) + self.handle, deformation.handle, in_image.handle, out.handle)) - def backward(self, image, out=None): + def backward(self, deformation, image_for_gradient, + image_for_multiplication, out=None): """Backward (get im grad wrt deformation times image).""" - assert_validity(image, SIRF.ImageData) + assert_validity(deformation, NiftiImageData3DDeformation) + assert_validity(image_for_gradient, SIRF.ImageData) + assert_validity(image_for_multiplication, SIRF.ImageData) # If we need to create the output if out is None: out = NiftiImageData3DDeformation() out.handle = pyreg.cReg_ImGradWRTDef_backward( - self.handle, image.handle) + self.handle, deformation.handle, + image_for_gradient.handle, + image_for_multiplication.handle) check_status(out.handle) return out # If in place else: assert_validity(out, NiftiImageData3DDeformation) - try_calling(pyreg.cReg_ImGradWRTDef_backward( - self.handle, image.handle, out.handle)) + try_calling(pyreg.cReg_ImGradWRTDef_forward_in_place( + self.handle, deformation.handle, image_for_gradient.handle, + image_for_multiplication.handle, out.handle)) class _Registration(ABC): diff --git a/src/Registration/pReg/tests/test_pReg.py b/src/Registration/pReg/tests/test_pReg.py index f2a1666a8..c56c33adf 100644 --- a/src/Registration/pReg/tests/test_pReg.py +++ b/src/Registration/pReg/tests/test_pReg.py @@ -1085,7 +1085,7 @@ def try_cgp_dvf_conversion(na): # Check the adjoint is truly the adjoint with: | - | / 0.5*(||+||) < epsilon cpg_2_dvf_converter._set_up_for_adjoint_test(dvf, dvf_to_cpg) - if not is_operator_adjoint(cpg_2_dvf_converter): + if not is_operator_adjoint(cpg_2_dvf_converter, verbose=False): raise AssertionError("ControlPointGridToDeformationConverter::adjoint() failed") time.sleep(0.5) @@ -1129,7 +1129,7 @@ def try_im_grad_wrt_def_times_im(): resampler.set_resampler(nr) # lambda hat is forward of lambda - resampler.forward(deformation, out=lambda_hat) + resampler.forward(deformation, lambda_im, out=lambda_hat) # lambda tilde is copy of lambda hat with all voxels = cnst lambda_tilde = lambda_hat.deep_copy() @@ -1137,7 +1137,7 @@ def try_im_grad_wrt_def_times_im(): lambda_tilde.fill(fill_val) # img grad wrt dvf times image - dvf1 = resampler.backward(lambda_tilde) + dvf1 = resampler.backward(deformation, lambda_im, lambda_tilde) # dvf2 is a clone of dvf1. initially filled with zeroes dvf2 = dvf1.deep_copy() @@ -1163,7 +1163,7 @@ def try_im_grad_wrt_def_times_im(): d_shifted_arr[ix, iy, iz, 0, iu] += epsilon d_shifted.fill(d_shifted_arr) - res_forward = resampler.forward(d_shifted) + res_forward = resampler.forward(d_shifted, lambda_im) d_lambda_times_rand_val = res_forward - lambda_hat d_lambda_times_rand_val *= (fill_val / epsilon) From a029c2c8a7ee05ae5a3efa992f1e9e1ebdcdede8 Mon Sep 17 00:00:00 2001 From: richard Date: Thu, 23 Jul 2020 22:19:29 +0100 Subject: [PATCH 41/42] debugging registration in image space --- examples/Python/PETMR/registration.py | 240 ++++++++++++------ ...ControlPointGridToDeformationConverter.cpp | 8 +- src/Registration/cReg/cReg.cpp | 6 +- .../ControlPointGridToDeformationConverter.h | 3 +- src/Registration/cReg/tests/test_cReg.cpp | 4 +- src/Registration/pReg/Reg.py | 2 +- 6 files changed, 174 insertions(+), 89 deletions(-) diff --git a/examples/Python/PETMR/registration.py b/examples/Python/PETMR/registration.py index ca8a5f925..ef0bba18e 100644 --- a/examples/Python/PETMR/registration.py +++ b/examples/Python/PETMR/registration.py @@ -4,21 +4,29 @@ registration [--help | options] Options: - -R , --ref= reference image - -F , --flo= floating image - --cpg_downsample= factor to downsample control point grid spacing - relative to reference image (e.g., 2 will mean - cpg spacing will be double that of reference - image) [default: 2] - --templ_sino= template sinogram - --rands= randoms sinogram - --attn= attenuation image - --norm= ECAT8 normalisation file - --num_iters= number of iterations [default: 10] - --num_subsets= number of subsets for PET projection [default: 21] - --space= space in which to perform registration (image or - sinogram) [default: sinogram] - --gpu use GPU + -R , --ref= reference image + -F , --flo= floating image + --num_iters= number of iterations [default: 10] + -s , --steps= number of steepest descent steps [default: 3] + --optimal use locally optimal steepest ascent + --cpg_downsample= factor to downsample control point grid + spacing relative to reference image (e.g., 2 + will mean cpg spacing will be double that of + reference image) [default: 2] + --space= space in which to perform registration (image + or sinogram) [default: sinogram] + -o , --out_prefix= registered file prefix [default: registered] + -d , --dvf_prefix= dvf file prefix [default: dvf] + +Options when registering in sinogram space: + --templ_sino= template sinogram + --rands= randoms sinogram + --attn= attenuation image + --norm= ECAT8 normalisation file + + --num_subsets= number of subsets for PET projection + [default: 21] + --gpu use GPU (requires NiftyPET projector) """ # SyneRBI Synergistic Image Reconstruction Framework (SIRF) @@ -53,42 +61,59 @@ # process command-line options ref_file = args['--ref'] flo_file = args['--flo'] -ref_eng = args['--ref_eng'] -flo_eng = args['--flo_eng'] cpg_downsample_factor = float(args['--cpg_downsample']) sino_file = args['--templ_sino'] rands_file = args['--rands'] attn_file = args['--attn'] norm_file = args['--norm'] -num_iters = float(args['--num_iters']) -num_subsets = float(args['--num_subsets']) +num_iters = int(args['--num_iters']) +opt = args['--optimal'] +steps = int(args['--steps']) +out_prefix = args['--out_prefix'] +dvf_prefix = args['--dvf_prefix'] +num_subsets = int(args['--num_subsets']) registration_space = args['--space'] if registration_space not in {'sinogram', 'image'}: raise error("Unknown registration space: " + registration_space) use_gpu = True if args['--gpu'] else False +if opt: + import scipy.optimize -def check_file_exists(fname): + +def check_file_exists(filename): """Check file exists. Else, throw error.""" if not os.path.isfile(filename): raise error("File not found: " + filename) -def get_image(filename): +def get_image(filename, required): """Get an image from its filename.""" + if not required and filename is None: + return None check_file_exists(filename) - return sirf.STIR.ImageData(filename) + if registration_space == 'image': + try: + im = sirf.Reg.ImageData(filename) + except: + im = sirf.STIR.ImageData(filename) + else: + im = sirf.STIR.ImageData(filename) + return im -def get_sinogram(filename): +def get_sinogram(filename, required): """Get an sinogram from its filename.""" + if not required and filename is None: + return None check_file_exists(filename) return sirf.STIR.AcquisitionData(filename) def get_cpg_2_dvf_converter(ref): """Get CPG 2 DVF converter.""" - cpg_spacing = ref.get_voxel_sizes()[1:4] * cpg_downsample_factor + im_spacing = ref.get_geometrical_info().get_spacing() + cpg_spacing = [spacing * cpg_downsample_factor for spacing in im_spacing] cpg_2_dvf_converter = sirf.Reg.ControlPointGridToDeformationConverter() cpg_2_dvf_converter.set_cpg_spacing(cpg_spacing) cpg_2_dvf_converter.set_reference_image(ref) @@ -98,7 +123,8 @@ def get_cpg_2_dvf_converter(ref): def get_dvf(ref): """Get initial DVF.""" disp = sirf.Reg.NiftiImageData3DDisplacement() - disp.create_from_3D_image(ref) + ref_nii = sirf.Reg.NiftiImageData3D(ref) + disp.create_from_3D_image(ref_nii) disp.fill(0) dvf = sirf.Reg.NiftiImageData3DDeformation(disp) return dvf @@ -174,61 +200,98 @@ def update_asm(acq_model, attn, asm_norm, sino, ref): acq_model.set_up(sino, ref) -def update_alpha_image(reference, current_estimate, - cpg_2_dvf_converter, - resampler, bsplines, alpha): +def grad_alpha_image( + reference, floating, + current_estimate, + cpg_2_dvf_converter, + resampler, dvf): """Update alpha (b-spline CPG) in image space.""" - # Convert alpha to dvf - dvf = cpg_2_dvf_converter.forward(alpha) - # Get current estimate - transformed_estimate = resampler.forward(dvf, current_estimate) # Get gradient of alpha - grad_obj_fn = (reference - transformed_estimate) * transformed_estimate - grad_dvf = resampler.backward(grad_obj_fn) + grad_obj_fn = current_estimate - reference + # sirf.Reg.ImageData(grad_obj_fn).write( + # "grad_obj_fn") + + grad_dvf = resampler.backward(dvf, floating, grad_obj_fn) + # grad_dvf.write("grad_dvf") + # sirf.Reg.NiftiImageData3DDisplacement(grad_dvf).write("grad_dvf_as_disp") grad_alpha = cpg_2_dvf_converter.backward(grad_dvf) + + # Return grad alpha return grad_alpha -def update_alpha_sino(reference, current_estimate, - cpg_2_dvf_converter, - resampler, bsplines, alpha, - attn, acq_model, asm_norm, sino): +def grad_alpha_sino( + reference, estimate, + cpg_2_dvf_converter, + resampler, alpha, + attn, emission_acq_model, + attenuation_acq_model, asm_norm, sino): """Update alpha (b-spline CPG) in sinogram space.""" # Convert alpha to dvf dvf = cpg_2_dvf_converter.forward(alpha) # Get current estimate (need to resample both emission and attn) - transformed_estimate = resampler.forward(dvf, current_estimate) - transformed_attn = resampler.forward(dvf, attn) + transformed_estimate = resampler.forward(dvf, estimate) - # Update the acq_model and objective function + # Update the emission_acq_model # with the updated attenuation image. - update_asm(acq_model, transformed_attn, asm_norm, sino, ref) - - estimated_projection = acq_model.forward(transformed_estimate) - estimated_data = estimated_projection + acq_model.get_background_term() + if attn: + transformed_attn = resampler.forward(dvf, attn) + update_asm(emission_acq_model, transformed_attn, asm_norm, sino, ref) + + estimated_data = emission_acq_model.direct(transformed_estimate) + estimated_projection = \ + estimated_data - emission_acq_model.get_constant_term() + # Beware: below is -ve PoissonLogLikelihood, which is +ve KL. + # Just bear this in mind when thinking about maximisation or minimsation, + # especially when doing JRM (image reconstruction and registration at same + # time). grad_obj = -measured_data / estimated_data + 1 - grad_emission_image = acq_model.backward(grad_obj) + grad_emission_image = emission_acq_model.backward(grad_obj) # New backward method - grad_dvf_em = resampler.backward(grad_emission_image) - # atn_acq_model - AcqModel for everything incl. attn - grad_attenuation_image = -acq_model.backward( - grad_obj * estimated_projection) - # New backward but now with attn. image - grad_dvf_atn = resampler.backward(grad_attenuation_image) - grad_dvf = grad_dvf_em + grad_dvf_atn + grad_dvf_em = resampler.backward(dvf, emission, grad_emission_image) + if attn: + grad_attenuation_image = -emission_acq_model.backward( + grad_obj * estimated_projection) + # New backward but now with attn. image + grad_dvf_atn = resampler.backward(dvf, attn, grad_attenuation_image) + grad_dvf = grad_dvf_em + grad_dvf_atn + else: + grad_dvf = grad_dvf_em grad_alpha = cpg_2_dvf_converter.backward(grad_dvf) - return grad_alpha + + # Return grad alpha and the current estimate + return [grad_alpha, transformed_estimate] + + +# def get_initial_tau(alpha, grad_alpha): +# """Get initial tau""" +# lmd_max = 2*grad_alpha_0.norm()/alpha_0.norm() +# tau = 1/lmd_max +# return tau + + +# def update_tau(alpha, alpha_0, grad_alpha, grad_alpha_0, max_step): +# """Get updated tau.""" +# d_alpha = alpha - alpha_0 +# d_grad = grad_alpha - grad_alpha_0 +# # dg = H di, hence a rough idea about lmd_max is given by +# lmd_max = 2*d_grad.norm()/d_alpha.norm() +# # alternative smaller estimate for lmd_max is +# # lmd_max = -2*dg.dot(di)/di.dot(di) +# tau = min(tau_0, 1/lmd_max) def main(): """Do the main function.""" # Read input files - ref = get_image(ref_file, ref_eng) - flo = get_image(flo_file, ref_eng) - template_sino = get_sinogram(sino_file) - attn = get_image(attn_file) if attn_file else None - asm_norm = get_asm_norm(norm_file) + ref = get_image(ref_file, True) + flo = get_image(flo_file, True) + + if registration_space == 'sinogram': + template_sino = get_sinogram(sino_file, True) + attn = get_image(attn_file, False) + asm_norm = get_asm_norm(norm_file) # Ability to convert between CPG and DVF cpg_2_dvf_converter = get_cpg_2_dvf_converter(ref) @@ -245,31 +308,54 @@ def main(): resampler = sirf.Reg.ImageGradientWRTDeformationTimesImage() resampler.set_resampler(nr) - # Get likelihood - acq_model = get_acq_model(ref, template_sino) - update_asm(acq_model, attn, asm_norm, template_sino, ref) + if registration_space == 'sinogram': + # Get acquisition models. We'll need 2 - one for emission (which will + # contain all corrections: randoms, norms, up-to-date attn), and the + # other for attenuation (blank just for line integrals). + emission_acq_model = get_acq_model(ref, template_sino) + attenuation_acq_model = get_acq_model(ref, template_sino) + attenuation_acq_model.set_up(template_sino, ref) + update_asm(acq_model, attn, asm_norm, template_sino, ref) - # Registered image starts as copy of reference image - registered_im = ref.copy() + sino = acq_model.forward(ref) - sino = acq_model.forward(ref) - simulated_sino = sino.copy() + # Initial current estimate + current_estimate = resampler.forward(dvf, flo) + sirf.Reg.ImageData(current_estimate).write(out_prefix + "_0") + dvf.write(dvf_prefix + "_0") # Optimisation loop - for iter in tqdm(range(num_iters)): + for iter in range(num_iters): # update alpha (b-spline CPG) in image or sinogram space - if regularisation_space == 'image': - alpha = update_alpha_image(reference, current_estimate, - cpg_2_dvf_converter, - resampler, alpha) + # Image space + if registration_space == 'image': + grad_alpha = grad_alpha_image( + ref, flo, current_estimate, + cpg_2_dvf_converter, resampler, dvf) + # Sinogram space else: - alpha = update_alpha_sino(reference, current_estimate, - cpg_2_dvf_converter, - resampler, alpha, - attn, acq_model, asm_norm, sino) - - alpha + step * grad_alpha + grad_alpha = grad_alpha_sino( + ref, flo, current_estimate, cpg_2_dvf_converter, resampler, + dvf, attn, emission_acq_model, attenuation_acq_model, + asm_norm, sino) + + # update alpha + grad_alpha_max = grad_alpha.as_array().max() + if grad_alpha_max <= 0: + raise error("damn.") + print("Max alpha: " + str(alpha.as_array().max())) + print("Max grad_alpha: " + str(grad_alpha_max)) + alpha -= grad_alpha / grad_alpha.as_array().max() + + # Get current dvf and estimate + dvf = cpg_2_dvf_converter.forward(alpha) + resampler.forward(dvf, flo, current_estimate) + + sirf.Reg.ImageData(current_estimate).write( + out_prefix + "_" + str(iter+1)) + sirf.Reg.NiftiImageData3DDisplacement(dvf).write( + dvf_prefix + "_" + str(iter+1)) if __name__ == "__main__": diff --git a/src/Registration/cReg/ControlPointGridToDeformationConverter.cpp b/src/Registration/cReg/ControlPointGridToDeformationConverter.cpp index 387b06353..f36f3e980 100644 --- a/src/Registration/cReg/ControlPointGridToDeformationConverter.cpp +++ b/src/Registration/cReg/ControlPointGridToDeformationConverter.cpp @@ -54,9 +54,10 @@ set_cpg_spacing(const float spacing[3]) template void ControlPointGridToDeformationConverter:: -set_reference_image(const NiftiImageData &ref) +set_reference_image(const std::shared_ptr &ref_sptr) { - _template_ref_sptr = ref.clone(); + // Need a copy as we'll need a non-const version + _template_ref_sptr = std::make_shared >(*ref_sptr); } template @@ -65,9 +66,6 @@ ControlPointGridToDeformationConverter:: forward(const NiftiImageData3DBSpline &cpg) const { check_is_set_up(); -// NiftiImageData3DDeformation dvf; -// dvf.create_from_cpp(cpg, *_template_ref_sptr); -// return dvf; return cpg.get_as_deformation_field(*_template_ref_sptr); } diff --git a/src/Registration/cReg/cReg.cpp b/src/Registration/cReg/cReg.cpp index 27090a051..5d0c4fdad 100644 --- a/src/Registration/cReg/cReg.cpp +++ b/src/Registration/cReg/cReg.cpp @@ -684,9 +684,9 @@ void* cReg_CPG2DVF_set_ref_im(const void* converter_ptr, const void* ref_im_ptr) try { ControlPointGridToDeformationConverter& cpg_2_dvf_converter = objectFromHandle >(converter_ptr); - NiftiImageData& ref_im = - objectFromHandle >(ref_im_ptr); - cpg_2_dvf_converter.set_reference_image(ref_im); + std::shared_ptr ref_im_sptr; + getObjectSptrFromHandle(ref_im_ptr, ref_im_sptr); + cpg_2_dvf_converter.set_reference_image(ref_im_sptr); return new DataHandle; } CATCH; diff --git a/src/Registration/cReg/include/sirf/Reg/ControlPointGridToDeformationConverter.h b/src/Registration/cReg/include/sirf/Reg/ControlPointGridToDeformationConverter.h index 464655a85..1a2cac63b 100644 --- a/src/Registration/cReg/include/sirf/Reg/ControlPointGridToDeformationConverter.h +++ b/src/Registration/cReg/include/sirf/Reg/ControlPointGridToDeformationConverter.h @@ -34,6 +34,7 @@ limitations under the License. namespace sirf { // Forward declarations +class ImageData; template class NiftiImageData; template class NiftiImageData3DDeformation; template class NiftiImageData3DBSpline; @@ -57,7 +58,7 @@ class ControlPointGridToDeformationConverter void set_cpg_spacing(const float spacing[3]); /// Set reference image for generating dvfs - void set_reference_image(const NiftiImageData &ref); + void set_reference_image(const std::shared_ptr &ref_sptr); /// CPG to DVF NiftiImageData3DDeformation forward(const NiftiImageData3DBSpline &cpg) const; diff --git a/src/Registration/cReg/tests/test_cReg.cpp b/src/Registration/cReg/tests/test_cReg.cpp index 97ea65ba9..4c34d7fc5 100644 --- a/src/Registration/cReg/tests/test_cReg.cpp +++ b/src/Registration/cReg/tests/test_cReg.cpp @@ -1270,7 +1270,7 @@ int main(int argc, char* argv[]) *NiftiImageData::create_from_geom_info( geom_info,true, NREG_TRANS_TYPE::DISP_FIELD)); NiftiImageData3DDeformation dvf(disp); - NiftiImageData ref = *dvf.get_tensor_component(0); + std::shared_ptr > ref_sptr = dvf.get_tensor_component(0); // CPG spacing double the dvf spacing float cpg_spacing[3] = {4.f * spacing_dvf[0], 4.f * spacing_dvf[1], 4.f * spacing_dvf[2]}; @@ -1278,7 +1278,7 @@ int main(int argc, char* argv[]) // set up DVF<->CPG converter ControlPointGridToDeformationConverter cpg_2_dvf_converter; cpg_2_dvf_converter.set_cpg_spacing(cpg_spacing); - cpg_2_dvf_converter.set_reference_image(ref); + cpg_2_dvf_converter.set_reference_image(ref_sptr); // ok, now ready to do adjoint test using: // | - | / 0.5*(||+||) < epsilon diff --git a/src/Registration/pReg/Reg.py b/src/Registration/pReg/Reg.py index 47dac5976..6ae0a41e6 100644 --- a/src/Registration/pReg/Reg.py +++ b/src/Registration/pReg/Reg.py @@ -818,7 +818,7 @@ def set_cpg_spacing(self, spacing): def set_reference_image(self, ref_im): """Set reference image for generating dvfs.""" - assert_validity(ref_im, NiftiImageData3D) + assert_validity(ref_im, SIRF.ImageData) try_calling(pyreg.cReg_CPG2DVF_set_ref_im(self.handle, ref_im.handle)) def forward(self, cpg): From 89e292876bc4e3f471186427369c79180189d7b3 Mon Sep 17 00:00:00 2001 From: richard Date: Thu, 23 Jul 2020 22:20:25 +0100 Subject: [PATCH 42/42] kernel convolution in python --- src/Registration/cReg/cReg.cpp | 14 ++++++++++++++ src/Registration/cReg/include/sirf/Reg/cReg.h | 1 + src/Registration/pReg/Reg.py | 12 ++++++++++++ 3 files changed, 27 insertions(+) diff --git a/src/Registration/cReg/cReg.cpp b/src/Registration/cReg/cReg.cpp index 9a52feeb0..b24fdfd13 100644 --- a/src/Registration/cReg/cReg.cpp +++ b/src/Registration/cReg/cReg.cpp @@ -33,6 +33,7 @@ limitations under the License. #include "sirf/Reg/Transformation.h" #include "sirf/Reg/AffineTransformation.h" #include "sirf/Reg/Quaternion.h" +#include <_reg_tools.h> #ifdef SIRF_SPM #include "sirf/Reg/SPMRegistration.h" #endif @@ -494,6 +495,19 @@ void* cReg_NiftiImageData_are_equal_to_given_accuracy(void* im1_ptr, void* im2_p } CATCH; } + +extern "C" +void* cReg_NiftiImageData_kernel_convolution(void* im_ptr, const float sigma, const int type) +{ + try { + std::shared_ptr > im_sptr; + getObjectSptrFromHandle >(im_ptr, im_sptr); + im_sptr->kernel_convolution(sigma, static_cast(type)); + return new DataHandle; + } + CATCH; +} + // -------------------------------------------------------------------------------- // // NiftiImageData3DTensor // -------------------------------------------------------------------------------- // diff --git a/src/Registration/cReg/include/sirf/Reg/cReg.h b/src/Registration/cReg/include/sirf/Reg/cReg.h index 2172ff5d0..f8a052bf8 100644 --- a/src/Registration/cReg/include/sirf/Reg/cReg.h +++ b/src/Registration/cReg/include/sirf/Reg/cReg.h @@ -63,6 +63,7 @@ extern "C" { void* cReg_NiftiImageData_from_complex_ImageData_real_component(void* in_ptr); void* cReg_NiftiImageData_from_complex_ImageData_imag_component(void* in_ptr); void* cReg_NiftiImageData_are_equal_to_given_accuracy(void* im1_ptr, void* im2_ptr, const float accuracy); + void* cReg_NiftiImageData_kernel_convolution(void* im_ptr, const float sigma, const int type); // NiftiImageData3D diff --git a/src/Registration/pReg/Reg.py b/src/Registration/pReg/Reg.py index 745e4942f..0e9a46229 100644 --- a/src/Registration/pReg/Reg.py +++ b/src/Registration/pReg/Reg.py @@ -468,6 +468,18 @@ def get_inner_product(self, other): pyiutil.deleteDataHandle(handle) return inner_product + def kernel_convolution(self, sigma, convolution_type=2): + """Kernel convolution. + + convolution_type: + - MEAN_KERNEL = 0 + - LINEAR_KERNEL = 1 + - GAUSSIAN_KERNEL = 2 + - CUBIC_SPLINE_KERNEL = 3 + """ + try_calling(pyreg.cReg_NiftiImageData_kernel_convolution( + self.handle, float(sigma), int(convolution_type))) + @staticmethod def print_headers(to_print): """Print nifti header metadata of one or multiple nifti images."""