From 7e7fa3d8997a1d8d029a2fff6e406c35662332b5 Mon Sep 17 00:00:00 2001 From: Margaret Duff Date: Mon, 18 Sep 2023 15:41:55 +0000 Subject: [PATCH] Formatting issues --- .../cil/optimisation/operators/Operator.py | 275 +++++++++--------- 1 file changed, 138 insertions(+), 137 deletions(-) diff --git a/Wrappers/Python/cil/optimisation/operators/Operator.py b/Wrappers/Python/cil/optimisation/operators/Operator.py index 8944441809..05a2d1fc46 100644 --- a/Wrappers/Python/cil/optimisation/operators/Operator.py +++ b/Wrappers/Python/cil/optimisation/operators/Operator.py @@ -22,6 +22,7 @@ import functools import warnings + class Operator(object): """ Operator that maps from a space X -> Y @@ -45,11 +46,11 @@ def __init__(self, domain_geometry, **kwargs): def is_linear(self): '''Returns if the operator is linear''' return False - def direct(self,x, out=None): + + def direct(self, x, out=None): '''Returns the application of the Operator on x''' raise NotImplementedError - def norm(self, **kwargs): '''Returns the norm of the Operator. On first call the norm will be calculated using the operator's calculate_norm method. Subsequent calls will return the cached norm. @@ -68,7 +69,7 @@ def norm(self, **kwargs): return self._norm - def set_norm(self,norm=None): + def set_norm(self, norm=None): '''Sets the norm of the operator to a custom value. ''' self._norm = norm @@ -76,44 +77,49 @@ def set_norm(self,norm=None): def calculate_norm(self): '''Calculates the norm of the Operator''' raise NotImplementedError + def range_geometry(self): '''Returns the range of the Operator: Y space''' return self._range_geometry + def domain_geometry(self): '''Returns the domain of the Operator: X space''' return self._domain_geometry + @property def domain(self): return self.domain_geometry() + @property def range(self): return self.range_geometry() + def __rmul__(self, scalar): '''Defines the multiplication by a scalar on the left returns a ScaledOperator''' return ScaledOperator(self, scalar) - + def compose(self, *other, **kwargs): - # TODO: check equality of domain and range of operators - #if self.operator2.range_geometry != self.operator1.domain_geometry: - # raise ValueError('Cannot compose operators, check domain geometry of {} and range geometry of {}'.format(self.operato1,self.operator2)) - - return CompositionOperator(self, *other, **kwargs) + # TODO: check equality of domain and range of operators + # if self.operator2.range_geometry != self.operator1.domain_geometry: + # raise ValueError('Cannot compose operators, check domain geometry of {} and range geometry of {}'.format(self.operato1,self.operator2)) + + return CompositionOperator(self, *other, **kwargs) def __add__(self, other): return SumOperator(self, other) def __mul__(self, scalar): - return self.__rmul__(scalar) - + return self.__rmul__(scalar) + def __neg__(self): """ Return -self """ - return -1 * self - + return -1 * self + def __sub__(self, other): """ Returns the subtraction of the operators.""" - return self + (-1) * other + return self + (-1) * other class LinearOperator(Operator): @@ -129,28 +135,30 @@ class LinearOperator(Operator): range_geometry : ImageGeometry or AcquisitionGeometry, optional, default None range of the operator """ + def __init__(self, domain_geometry, **kwargs): super(LinearOperator, self).__init__(domain_geometry, **kwargs) + def is_linear(self): '''Returns if the operator is linear''' return True - def adjoint(self,x, out=None): + + def adjoint(self, x, out=None): '''returns the adjoint/inverse operation - + only available to linear operators''' raise NotImplementedError - - @staticmethod - def PowerMethod(operator, max_iteration=10, initial=None, tolerance=1e-5, return_all=False,range_is_domain=None ): + @staticmethod + def PowerMethod(operator, max_iteration=10, initial=None, tolerance=1e-5, return_all=False, range_is_domain=None): r"""Power method or Power iteration algorithm - + The Power method computes the largest (dominant) eigenvalue of a square matrix in magnitude, e.g., absolute value in the real case and modulus in the complex case. If the matrix is not square. The algorithm computes the largest (dominant) eigenvalue of :math: A^{T}*A :math:, returning the square root of this value. - - + + Parameters ---------- @@ -161,12 +169,12 @@ def PowerMethod(operator, max_iteration=10, initial=None, tolerance=1e-5, retur Starting point for the Power method. tolerance: positive:`float`, default = 1e-5 Stopping criterion for the Power method. Check if two consecutive eigenvalue evaluations are below the tolerance. - return_all: boolean, default = False + return_all: `boolean`, default = False Toggles the verbosity of the return range_is_domain: `boolean`, default None Set this to `True` to apply the power method directly on the operator, :math: A :math:, and `False` to apply the power method to :math:A^TA:math: before taking the square root of the result. Leave as default `None` to determine this from domain and range geometry. - + Returns ------- @@ -195,18 +203,18 @@ def PowerMethod(operator, max_iteration=10, initial=None, tolerance=1e-5, retur 2.0005647295658866 """ - convergence_check=True + convergence_check = True if range_is_domain is None: square = False try: - if operator.domain_geometry()==operator.range_geometry(): + if operator.domain_geometry() == operator.range_geometry(): square = True except AssertionError: # catch AssertionError for SIRF objects https://github.com/SyneRBI/SIRF-SuperBuild/runs/5110228626?check_suite_focus=true#step:8:972 pass else: - square=range_is_domain - + square = range_is_domain + if initial is None: x0 = operator.domain_geometry().allocate('random') else: @@ -225,60 +233,52 @@ def PowerMethod(operator, max_iteration=10, initial=None, tolerance=1e-5, retur diff = numpy.finfo('d').max i = 0 while (i < max_iteration and diff > tolerance): - - operator.direct(x0, out = y_tmp) - - if square: - #swap datacontainer references + operator.direct(x0, out=y_tmp) + + if square: + # swap datacontainer references tmp = x0 x0 = y_tmp y_tmp = tmp else: - operator.adjoint(y_tmp,out=x0) - - + operator.adjoint(y_tmp, out=x0) + # Get eigenvalue using Rayleigh quotient: denominator=1, due to normalization - x0_norm=x0.norm() - if x0_norm