from __future__ import annotations
import utils
import parallelproj
import array_api_compat.torch as torch
from array_api_compat import device
from layers import LinearSingleChannelOperator, AdjointLinearSingleChannelOperator
#----------------------------------------------------------------------------
#----------------------------------------------------------------------------
#----------------------------------------------------------------------------
#----------------------------------------------------------------------------
#----------------------------------------------------------------------------
# device variable ('cpu' or 'cuda') that determines whether calculations
# are performed on the CPU or on a CUDA GPU
if parallelproj.cuda_present:
    dev = 'cuda'
else:
    dev = 'cpu'
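
# report which device was selected
print(f'using device .: {dev}')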
#----------------------------------------------------------------------------
#----------------------------------------------------------------------------
#--- setup the scanner / LOR geometry ---------------------------------------
#----------------------------------------------------------------------------
#----------------------------------------------------------------------------
# setup a line of response descriptor that describes the LOR start / end points of
# a "narrow" clinical PET scanner with 2 rings
lor_descriptor = utils.DemoPETScannerLORDescriptor(torch,
                                                   dev,
                                                   num_rings=2,
                                                   radial_trim=201)
# image properties
voxel_size = (2.66, 2.66, 2.66)
img_shape = (10, 10, 2 * lor_descriptor.scanner.num_modules)
projector = utils.RegularPolygonPETProjector(
    lor_descriptor,
    img_shape,
    voxel_size,
    views=torch.arange(0, lor_descriptor.num_views, 34, device=dev))
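
# the projector exposes the expected image (input) and sinogram (output) shapes;
# printing them shows the dimensions the tensors below must have
print('projector input shape  .:', projector.in_shape)
print('projector output shape .:', projector.out_shape)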
#----------------------------------------------------------------------------
#----------------------------------------------------------------------------
#----------------------------------------------------------------------------
#----------------------------------------------------------------------------
batch_size = 2

# mini batch of random 3D images with an extra single channel dimension
x = torch.rand((batch_size, 1) + projector.in_shape,
               device=dev,
               dtype=torch.float32,
               requires_grad=True)

# mini batch of random sinograms (no channel dimension)
y = torch.rand((batch_size, ) + projector.out_shape,
               device=dev,
               dtype=torch.float32,
               requires_grad=True)
# the LinearSingleChannelOperator and AdjointLinearSingleChannelOperator classes,
# which subclass torch.autograd.Function, are defined in layers.py
# -> have a look at the code there to see how the forward and backward passes
#    are implemented
fwd_op_layer = LinearSingleChannelOperator.apply
adjoint_op_layer = AdjointLinearSingleChannelOperator.apply
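
# a minimal sketch (an assumption - see layers.py for the actual implementation)
# of how such a torch.autograd.Function couples the operator with its adjoint:
#
#   class LinearSingleChannelOperator(torch.autograd.Function):
#       @staticmethod
#       def forward(ctx, x, operator):
#           ctx.operator = operator
#           ...  # apply the operator to every single-channel element of the batch
#
#       @staticmethod
#       def backward(ctx, grad_output):
#           ...  # apply the adjoint of the stored operator to grad_output
#           return grad_x, None  # the operator argument itself gets no gradient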
f1 = fwd_op_layer(x, projector)
print('forward projection (Ax) .:', f1.shape, type(f1), device(f1))

b1 = adjoint_op_layer(y, projector)
print('back projection (A^T y) .:', b1.shape, type(b1), device(b1))

fb1 = adjoint_op_layer(fwd_op_layer(x, projector), projector)
print('back + forward projection (A^TAx) .:', fb1.shape, type(fb1),
      device(fb1))
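
# quick numerical sanity check (a sketch added here, assuming b1 has the same
# shape as x): for a linear operator A with adjoint A^T, the inner products
# <Ax, y> and <x, A^T y> must agree up to floating point precision
print('<Ax, y>    .:', float((f1 * y).sum()))
print('<x, A^T y> .:', float((x * b1).sum()))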
#----------------------------------------------------------------------------
#----------------------------------------------------------------------------
#----------------------------------------------------------------------------
#----------------------------------------------------------------------------
# define a dummy scalar loss as the squared norm of A^T A x
dummy_loss = (fb1**2).sum()

# trigger the backpropagation
dummy_loss.backward()
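
# the gradient of the loss w.r.t. the input image batch is now available in
# x.grad (the loss only depends on x, so y.grad stays None)
print('x.grad shape .:', x.grad.shape)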
#----------------------------------------------------------------------------
#----------------------------------------------------------------------------
#----------------------------------------------------------------------------
#----------------------------------------------------------------------------
if dev == 'cpu':
    print('skipping (slow) gradient checks on cpu')
else:
    # note: eps and the tolerances are relatively loose since the test tensors
    # are float32 (gradcheck is usually run in double precision)
    print('Running forward projection layer gradient test')
    grad_test_fwd = torch.autograd.gradcheck(fwd_op_layer, (x, projector),
                                             eps=1e-1,
                                             atol=1e-4,
                                             rtol=1e-4)

    print('Running adjoint projection layer gradient test')
    grad_test_adjoint = torch.autograd.gradcheck(adjoint_op_layer,
                                                 (y, projector),
                                                 eps=1e-1,
                                                 atol=1e-4,
                                                 rtol=1e-4)
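
    # gradcheck returns True if the analytical and numerical gradients agree
    # within the given tolerances (and raises an error otherwise)
    print('forward layer gradient test passed .:', grad_test_fwd)
    print('adjoint layer gradient test passed .:', grad_test_adjoint)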