Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

image gradient wrt transformation #759

Open
wants to merge 64 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
4fee4ab
implementation
Jul 1, 2020
1686c4d
current progress
Jul 2, 2020
b62b75e
works
Jul 2, 2020
6c69d09
add converter
Jul 2, 2020
c38877b
correct case
Jul 2, 2020
52d3e5b
converter works
Jul 2, 2020
0d2137c
exported to python. tests pass
Jul 2, 2020
d97ec68
add adjoint test
Jul 2, 2020
8e328dc
remove bspline from deformation constructor
Jul 3, 2020
e376367
updated
Jul 3, 2020
60ededc
c++ test works
Jul 4, 2020
65a81c2
create from cpp const
Jul 4, 2020
78263eb
Merge branch 'add_fill_method_from_float_array' into BSpline2
Jul 4, 2020
5808c1e
Merge branch 'construct_from_geom_info_improvement' into BSpline2
Jul 4, 2020
e6eab54
Merge branch 'correct_check_dimensions' into BSpline2
Jul 4, 2020
66cfe90
python test passes
Jul 4, 2020
2b72e67
Merge remote-tracking branch 'SyneRBI/master' into BSpline2
Jul 6, 2020
ec8d6fb
codacy changes
Jul 6, 2020
c127e55
Merge branch 'master' into BSpline2
Jul 7, 2020
64c0833
Merge remote-tracking branch 'SyneRBI/master' into BSpline2
Jul 7, 2020
788da8d
Merge branch 'master' into BSpline2
Jul 7, 2020
30691c0
image gradient wrt transformation
Jul 7, 2020
56b01cf
Merge remote-tracking branch 'SyneRBI/master' into BSpline2
Jul 8, 2020
0970888
Merge branch 'BSpline2' into ImageGradWRTTrans
Jul 8, 2020
bf448e5
multiply tensor component by scalar image
Jul 9, 2020
bab142e
add forward method
Jul 9, 2020
86a500e
image grad wrt DEF TIMES IM
Jul 9, 2020
d2f3e4d
use mean to compare images
Jul 9, 2020
ed1554b
give index access for separate values
Jul 9, 2020
564b427
add tensor maths with scalar image
Jul 9, 2020
abcd3f5
update functionality
Jul 9, 2020
fd311bf
update testing
Jul 9, 2020
448967a
codacy
Jul 9, 2020
fa4faac
fill from image
Jul 10, 2020
73627b6
correct /=
Jul 10, 2020
2af785e
update sto_xyz/sto_ijk when cropping image
Jul 10, 2020
6fded33
current attempt
Jul 10, 2020
58584c0
Merge branch 'master' into BSpline2
Jul 14, 2020
8faff7d
Merge branch 'BSpline2' into ImageGradWRTTrans
Jul 14, 2020
b8b4a8a
fix Reg python import
Jul 14, 2020
fda8f05
Merge branch 'BSpline2' into ImageGradWRTTrans
Jul 14, 2020
16d4d20
Merge branch 'no_install_niftymomo_headers' into BSpline2
Jul 17, 2020
c64a370
Merge branch 'update_niftymomo' into BSpline2
Jul 17, 2020
774e689
update to align with niftymomo
Jul 17, 2020
ce506d7
Merge branch 'BSpline2' into ImageGradWRTTrans
Jul 17, 2020
33ddd20
remove temp variable
Jul 17, 2020
a66f26d
Merge remote-tracking branch 'SyneRBI/master' into BSpline2
Jul 17, 2020
f85c7a5
Merge branch 'update_niftymomo' into BSpline2
Jul 17, 2020
97e3c4c
Merge branch 'BSpline2' into ImageGradWRTTrans
Jul 17, 2020
bafd881
divide by spacing
Jul 17, 2020
ad660d1
correct the test, only working for linear interpolation
Jul 17, 2020
52682a8
python partially implemented
Jul 22, 2020
ca8ae58
continue python implementation and test
Jul 22, 2020
b171ceb
finished
Jul 22, 2020
380f937
test everything
Jul 22, 2020
f1c1d4e
Merge remote-tracking branch 'SyneRBI/master' into ImageGradWRTTrans
Jul 22, 2020
fae6b20
undo comparison against image mean
Jul 22, 2020
d5d435f
registration attempt
Jul 22, 2020
3c34b4c
Merge remote-tracking branch 'SyneRBI/master' into ImageGradWRTTrans
Jul 23, 2020
f13a50a
Merge branch 'master' into ImageGradWRTTrans
Jul 23, 2020
22b8c9c
grad wrt dvf specify all arguments
Jul 23, 2020
a029c2c
debugging registration in image space
Jul 23, 2020
89e2928
kernel convolution in python
Jul 23, 2020
23cd688
Merge branch 'python_kernel_convolution' into ImageGradWRTTrans
Jul 23, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
362 changes: 362 additions & 0 deletions examples/Python/PETMR/registration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,362 @@
"""Registration demo.

Usage:
registration [--help | options]

Options:
-R <file>, --ref=<file> reference image
-F <file>, --flo=<file> floating image
--num_iters=<val> number of iterations [default: 10]
-s <nstp>, --steps=<nstp> number of steepest descent steps [default: 3]
--optimal use locally optimal steepest ascent
--cpg_downsample=<val> factor to downsample control point grid
spacing relative to reference image (e.g., 2
will mean cpg spacing will be double that of
reference image) [default: 2]
--space=<str> space in which to perform registration (image
or sinogram) [default: sinogram]
-o <str>, --out_prefix=<str> registered file prefix [default: registered]
-d <str>, --dvf_prefix=<str> dvf file prefix [default: dvf]

Options when registering in sinogram space:
--templ_sino=<file> template sinogram
--rands=<file> randoms sinogram
--attn=<file> attenuation image
--norm=<file> ECAT8 normalisation file

--num_subsets=<val> number of subsets for PET projection
[default: 21]
--gpu use GPU (requires NiftyPET projector)
"""

# SyneRBI Synergistic Image Reconstruction Framework (SIRF)
# Copyright 2015 - 2020 University College London.
#
# This is software developed for the Collaborative Computational
# Project in Synergistic Reconstruction for Biomedical Imaging
# (formerly CCP PETMR)
# (http://www.ccpsynerbi.ac.uk/).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from docopt import docopt
from tqdm import tqdm
from sirf.Utilities import error
import sirf.STIR
import sirf.Reg

__version__ = '0.1.0'
args = docopt(__doc__, version=__version__)


# process command-line options
ref_file = args['--ref']
flo_file = args['--flo']
cpg_downsample_factor = float(args['--cpg_downsample'])
sino_file = args['--templ_sino']
rands_file = args['--rands']
attn_file = args['--attn']
norm_file = args['--norm']
num_iters = int(args['--num_iters'])
opt = args['--optimal']
steps = int(args['--steps'])
out_prefix = args['--out_prefix']
dvf_prefix = args['--dvf_prefix']
num_subsets = int(args['--num_subsets'])
registration_space = args['--space']
if registration_space not in {'sinogram', 'image'}:
raise error("Unknown registration space: " + registration_space)
use_gpu = True if args['--gpu'] else False

if opt:
import scipy.optimize


def check_file_exists(filename):
"""Check file exists. Else, throw error."""
if not os.path.isfile(filename):
raise error("File not found: " + filename)


def get_image(filename, required):
"""Get an image from its filename."""
if not required and filename is None:
return None
check_file_exists(filename)
if registration_space == 'image':
try:
im = sirf.Reg.ImageData(filename)
except:
im = sirf.STIR.ImageData(filename)
else:
im = sirf.STIR.ImageData(filename)
return im


def get_sinogram(filename, required):
"""Get an sinogram from its filename."""
if not required and filename is None:
return None
check_file_exists(filename)
return sirf.STIR.AcquisitionData(filename)


def get_cpg_2_dvf_converter(ref):
"""Get CPG 2 DVF converter."""
im_spacing = ref.get_geometrical_info().get_spacing()
cpg_spacing = [spacing * cpg_downsample_factor for spacing in im_spacing]
cpg_2_dvf_converter = sirf.Reg.ControlPointGridToDeformationConverter()
cpg_2_dvf_converter.set_cpg_spacing(cpg_spacing)
cpg_2_dvf_converter.set_reference_image(ref)
return cpg_2_dvf_converter


def get_dvf(ref):
"""Get initial DVF."""
disp = sirf.Reg.NiftiImageData3DDisplacement()
ref_nii = sirf.Reg.NiftiImageData3D(ref)
disp.create_from_3D_image(ref_nii)
disp.fill(0)
dvf = sirf.Reg.NiftiImageData3DDeformation(disp)
return dvf


def get_resampler(ref, flo):
"""Get resampler."""
nr = sirf.Reg.NiftyResample()
nr.set_reference_image(ref)
nr.set_floating_image(flo)
nr.set_interpolation_type_to_linear()
nr.set_padding_value(0)
return nr


def get_asm_norm(norm_file):
if norm_file is None:
return None
check_file_exists(norm_file)
asm_norm = pet.AcquisitionSensitivityModel(norm_file)
return asm_norm


def get_asm_attn(sino, attn, acq_model):
"""Get attn ASM from sino, attn image and acq model"""
if attn is None:
return None
asm_attn = pet.AcquisitionSensitivityModel(attn, acq_model)
# temporary fix pending attenuation offset fix in STIR:
# converting attenuation into 'bin efficiency'
asm_attn.set_up(sino)
bin_eff = pet.AcquisitionData(sino)
bin_eff.fill(1.0)
asm_attn.unnormalise(bin_eff)
asm_attn = pet.AcquisitionSensitivityModel(bin_eff)
return asm_attn


def get_acq_model(ref, sino):
"""Get acquisition model."""
if not use_gpu:
acq_model = sirf.Reg.AcquisitionModelUsingRayTracingMatrix()
else:
acq_model = sirf.Reg.AcquisitionModelUsingNiftyPET()
acq_model.set_use_truncation(True)
acq_model.set_cuda_verbosity(0)

# Add randoms if desired
if rands_file:
rands = get_sinogram(rands_file)
acq_model.set_background_term(rands)

return acq_model


def update_asm(acq_model, attn, asm_norm, sino, ref):
"""Update acquisition sensitivity models."""

# Create attn ASM if necessary
asm_attn = get_asm_attn(attn)

asm = None
if asm_norm and asm_attn:
asm = pet.AcquisitionSensitivityModel(asm_norm, asm_attn)
elif asm_norm:
asm = asm_norm
elif asm_attn:
asm = asm_attn
if asm:
acq_model.set_acquisition_sensitivity(asm)

# Set up
acq_model.set_up(sino, ref)


def grad_alpha_image(
reference, floating,
current_estimate,
cpg_2_dvf_converter,
resampler, dvf):
"""Update alpha (b-spline CPG) in image space."""
# Get gradient of alpha
grad_obj_fn = current_estimate - reference
# sirf.Reg.ImageData(grad_obj_fn).write(
# "grad_obj_fn")

grad_dvf = resampler.backward(dvf, floating, grad_obj_fn)
# grad_dvf.write("grad_dvf")
# sirf.Reg.NiftiImageData3DDisplacement(grad_dvf).write("grad_dvf_as_disp")
grad_alpha = cpg_2_dvf_converter.backward(grad_dvf)

# Return grad alpha
return grad_alpha


def grad_alpha_sino(
reference, estimate,
cpg_2_dvf_converter,
resampler, alpha,
attn, emission_acq_model,
attenuation_acq_model, asm_norm, sino):
"""Update alpha (b-spline CPG) in sinogram space."""
# Convert alpha to dvf
dvf = cpg_2_dvf_converter.forward(alpha)
# Get current estimate (need to resample both emission and attn)
transformed_estimate = resampler.forward(dvf, estimate)

# Update the emission_acq_model
# with the updated attenuation image.
if attn:
transformed_attn = resampler.forward(dvf, attn)
update_asm(emission_acq_model, transformed_attn, asm_norm, sino, ref)

estimated_data = emission_acq_model.direct(transformed_estimate)
estimated_projection = \
estimated_data - emission_acq_model.get_constant_term()
# Beware: below is -ve PoissonLogLikelihood, which is +ve KL.
# Just bear this in mind when thinking about maximisation or minimsation,
# especially when doing JRM (image reconstruction and registration at same
# time).
grad_obj = -measured_data / estimated_data + 1
grad_emission_image = emission_acq_model.backward(grad_obj)
# New backward method
grad_dvf_em = resampler.backward(dvf, emission, grad_emission_image)
if attn:
grad_attenuation_image = -emission_acq_model.backward(
grad_obj * estimated_projection)
# New backward but now with attn. image
grad_dvf_atn = resampler.backward(dvf, attn, grad_attenuation_image)
grad_dvf = grad_dvf_em + grad_dvf_atn
else:
grad_dvf = grad_dvf_em
grad_alpha = cpg_2_dvf_converter.backward(grad_dvf)

# Return grad alpha and the current estimate
return [grad_alpha, transformed_estimate]


# def get_initial_tau(alpha, grad_alpha):
# """Get initial tau"""
# lmd_max = 2*grad_alpha_0.norm()/alpha_0.norm()
# tau = 1/lmd_max
# return tau


# def update_tau(alpha, alpha_0, grad_alpha, grad_alpha_0, max_step):
# """Get updated tau."""
# d_alpha = alpha - alpha_0
# d_grad = grad_alpha - grad_alpha_0
# # dg = H di, hence a rough idea about lmd_max is given by
# lmd_max = 2*d_grad.norm()/d_alpha.norm()
# # alternative smaller estimate for lmd_max is
# # lmd_max = -2*dg.dot(di)/di.dot(di)
# tau = min(tau_0, 1/lmd_max)


def main():
"""Do the main function."""

# Read input files
ref = get_image(ref_file, True)
flo = get_image(flo_file, True)

if registration_space == 'sinogram':
template_sino = get_sinogram(sino_file, True)
attn = get_image(attn_file, False)
asm_norm = get_asm_norm(norm_file)

# Ability to convert between CPG and DVF
cpg_2_dvf_converter = get_cpg_2_dvf_converter(ref)

# Create DVF and CPG (called alpha)
dvf = get_dvf(ref)
alpha = cpg_2_dvf_converter.backward(dvf)

# We'll need a niftyreg resampler
nr = get_resampler(ref, flo)

# And we'll need the ImageGradientWRTDeformationTimesImage class
# (which we'll call the resampler)
resampler = sirf.Reg.ImageGradientWRTDeformationTimesImage()
resampler.set_resampler(nr)

if registration_space == 'sinogram':
# Get acquisition models. We'll need 2 - one for emission (which will
# contain all corrections: randoms, norms, up-to-date attn), and the
# other for attenuation (blank just for line integrals).
emission_acq_model = get_acq_model(ref, template_sino)
attenuation_acq_model = get_acq_model(ref, template_sino)
attenuation_acq_model.set_up(template_sino, ref)
update_asm(acq_model, attn, asm_norm, template_sino, ref)

sino = acq_model.forward(ref)

# Initial current estimate
current_estimate = resampler.forward(dvf, flo)
sirf.Reg.ImageData(current_estimate).write(out_prefix + "_0")
dvf.write(dvf_prefix + "_0")

# Optimisation loop
for iter in range(num_iters):

# update alpha (b-spline CPG) in image or sinogram space
# Image space
if registration_space == 'image':
grad_alpha = grad_alpha_image(
ref, flo, current_estimate,
cpg_2_dvf_converter, resampler, dvf)
# Sinogram space
else:
grad_alpha = grad_alpha_sino(
ref, flo, current_estimate, cpg_2_dvf_converter, resampler,
dvf, attn, emission_acq_model, attenuation_acq_model,
asm_norm, sino)

# update alpha
grad_alpha_max = grad_alpha.as_array().max()
if grad_alpha_max <= 0:
raise error("damn.")
print("Max alpha: " + str(alpha.as_array().max()))
print("Max grad_alpha: " + str(grad_alpha_max))
alpha -= grad_alpha / grad_alpha.as_array().max()

# Get current dvf and estimate
dvf = cpg_2_dvf_converter.forward(alpha)
resampler.forward(dvf, flo, current_estimate)

sirf.Reg.ImageData(current_estimate).write(
out_prefix + "_" + str(iter+1))
sirf.Reg.NiftiImageData3DDisplacement(dvf).write(
dvf_prefix + "_" + str(iter+1))


if __name__ == "__main__":
main()
6 changes: 5 additions & 1 deletion src/Registration/cReg/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@ SET(SOURCES
"NiftiImageData3D.cpp"
"NiftiImageData3DTensor.cpp"
"NiftiImageData3DDeformation.cpp"
"NiftiImageData3DDisplacement.cpp")
"NiftiImageData3DDisplacement.cpp"
"NiftiImageData3DBSpline.cpp"
"ControlPointGridToDeformationConverter.cpp"
"ImageGradientWRTDeformationTimesImage.cpp"
)

# If we're also wrapping to python or matlab, include the c-files
IF(BUILD_PYTHON OR BUILD_MATLAB)
Expand Down
Loading