-
Notifications
You must be signed in to change notification settings - Fork 0
/
layers.py
215 lines (168 loc) · 7.38 KB
/
layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
from __future__ import annotations
import array_api_compat.torch as torch
import parallelproj
from array_api_compat import device
class LinearSingleChannelOperator(torch.autograd.Function):
    """
    Autograd Function that applies a linear operator to every sample of a
    mini batch of single channel images.
    """

    # custom Function pattern, see
    # https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html
    @staticmethod
    def forward(ctx, x: torch.Tensor,
                operator: parallelproj.LinearOperator) -> torch.Tensor:
        """apply the linear operator to each sample of the mini batch

        Parameters
        ----------
        ctx : context object
            used to stash information needed by the backward pass
        x : torch.Tensor
            mini batch of 3D images with dimension (batch_size, 1, num_voxels_x, num_voxels_y, num_voxels_z)
        operator : parallelproj.LinearOperator
            linear operator that can act on a single 3D image

        Returns
        -------
        torch.Tensor
            mini batch with dimension (batch_size, operator.out_shape)
        """
        # https://pytorch.org/docs/stable/notes/extending.html#how-to-use
        ctx.set_materialize_grads(False)
        ctx.operator = operator

        num_samples = x.shape[0]
        result = torch.zeros((num_samples, ) + operator.out_shape,
                             dtype=x.dtype,
                             device=device(x))

        # apply the operator to channel 0 of every sample in the batch
        for sample in range(num_samples):
            result[sample, ...] = operator(x[sample, 0, ...].detach())

        return result

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]:
        """backward pass: apply the adjoint of the operator sample by sample

        Parameters
        ----------
        ctx : context object
            carries the operator stored during the forward pass
        grad_output : torch.Tensor
            mini batch of dimension (batch_size, operator.out_shape)

        Returns
        -------
        torch.Tensor, None
            mini batch of 3D images with dimension (batch_size, 1, operator.in_shape),
            plus None for the non-differentiable operator argument
        """
        # https://pytorch.org/docs/stable/notes/extending.html#how-to-use
        # forward took two inputs (x, operator), so backward must return
        # two gradients; the operator itself is not differentiable -> None
        if grad_output is None:
            return None, None

        operator = ctx.operator
        num_samples = grad_output.shape[0]
        grad_x = torch.zeros((num_samples, 1) + operator.in_shape,
                             dtype=grad_output.dtype,
                             device=device(grad_output))

        # adjoint of the operator applied to every sample in the batch,
        # written into channel 0 of the gradient image
        for sample in range(num_samples):
            grad_x[sample, 0, ...] = operator.adjoint(
                grad_output[sample, ...].detach())

        return grad_x, None
class AdjointLinearSingleChannelOperator(torch.autograd.Function):
    """
    Autograd Function that applies the adjoint of a linear operator to every
    sample of a mini batch of single channel images.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor,
                operator: parallelproj.LinearOperator) -> torch.Tensor:
        """apply the adjoint of the linear operator to each sample

        Parameters
        ----------
        ctx : context object
            used to stash information needed by the backward pass
        x : torch.Tensor
            mini batch of 3D images with dimension (batch_size, 1, operator.out_shape)
        operator : parallelproj.LinearOperator
            linear operator that can act on a single 3D image

        Returns
        -------
        torch.Tensor
            mini batch of 3D images with dimension (batch_size, 1, operator.in_shape)
        """
        ctx.set_materialize_grads(False)
        ctx.operator = operator

        num_samples = x.shape[0]
        result = torch.zeros((num_samples, 1) + operator.in_shape,
                             dtype=x.dtype,
                             device=device(x))

        # apply the adjoint to every sample, writing into channel 0
        for sample in range(num_samples):
            result[sample, 0, ...] = operator.adjoint(x[sample, ...].detach())

        return result

    @staticmethod
    def backward(ctx, grad_output):
        """backward pass: apply the (forward) operator sample by sample

        Parameters
        ----------
        ctx : context object
            carries the operator stored during the forward pass
        grad_output : torch.Tensor
            mini batch of dimension (batch_size, 1, operator.in_shape)

        Returns
        -------
        torch.Tensor, None
            mini batch with dimension (batch_size, operator.out_shape),
            plus None for the non-differentiable operator argument
        """
        # https://pytorch.org/docs/stable/notes/extending.html#how-to-use
        # forward took two inputs (x, operator), so backward must return
        # two gradients; the operator itself is not differentiable -> None
        if grad_output is None:
            return None, None

        operator = ctx.operator
        num_samples = grad_output.shape[0]
        grad_x = torch.zeros((num_samples, ) + operator.out_shape,
                             dtype=grad_output.dtype,
                             device=device(grad_output))

        # forward operator applied to channel 0 of every sample
        for sample in range(num_samples):
            grad_x[sample, ...] = operator(grad_output[sample, 0, ...].detach())

        return grad_x, None
class EMUpdateModule(torch.nn.Module):
    """torch module computing one EM (MLEM) update step on a mini batch"""

    def __init__(
        self,
        projector: parallelproj.LinearOperator,
    ) -> None:
        super().__init__()
        # projector shared by the forward and adjoint operator layers
        self._projector = projector
        self._fwd_op_layer = LinearSingleChannelOperator.apply
        self._adjoint_op_layer = AdjointLinearSingleChannelOperator.apply

    def forward(self, x: torch.Tensor, data: torch.Tensor,
                corrections: torch.Tensor, contamination: torch.Tensor,
                adjoint_ones: torch.Tensor) -> torch.Tensor:
        """compute one EM update for every sample in the mini batch

        Parameters
        ----------
        x : torch.Tensor
            mini batch of images with shape (batch_size, 1, *img_shape)
        data : torch.Tensor
            mini batch of emission data with shape (batch_size, *data_shape)
        corrections : torch.Tensor
            mini batch of multiplicative corrections with shape (batch_size, *data_shape)
        contamination : torch.Tensor
            mini batch of additive contamination with shape (batch_size, *data_shape)
        adjoint_ones : torch.Tensor
            mini batch of adjoint ones (back projection of multiplicative corrections) with shape (batch_size, 1, *img_shape)

        Returns
        -------
        torch.Tensor
            mini batch of EM updates with shape (batch_size, 1, *img_shape)
        """
        # every tensor holds a mini batch; the fwd / adjoint operator
        # layers act on mini batches directly
        expected_data = (corrections * self._fwd_op_layer(x, self._projector)
                         + contamination)
        ratio = data / expected_data
        back_projected = self._adjoint_op_layer(corrections * ratio,
                                                self._projector)
        return x * back_projected / adjoint_ones