Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Spdhg with stochastic sampler #1644

Open
wants to merge 139 commits into
base: master
Choose a base branch
from

Conversation

MargaretDuff
Copy link
Member

@MargaretDuff MargaretDuff commented Jan 10, 2024

Describe your changes

  • Allow SPDHG to take a sampler either from our sampler class or any class with a next(self) function
  • Deprecated prob from SPDHG, taking the probabilities instead from the sampler class or from a new argument prob_weights, choosing the default [1/num_subsets]*num_subsets if one is not provided in either place.
  • Created two setters for the step sizes. set_step_sizes_from_ratio resets the step sizes if the user provides one/both/none of gamma and rho - note that this closes SPDHG gamma parameter is applied incorrectly  #1860. step_sizes_custom takes in one/both/none of sigma and tau allowing the user to use a custom sigma and tau with those not provided calculated from the defaults. Calculating sigma from tau probably needs checking with someone else.
  • Added a check_convergence function that checks self._sigma[i] * self._tau * self.norms[i]**2 <= self.prob_weights[i] for all i. This probably needs checking with someone else.
  • Deprecated the kwarg "norms" to be replaced by the set_norms method in BlockOperator: added a function to return a list of norms and the ability to set this list of norms BlockOperator: added a function to return a list of norms and the ability to set this list of norms BlockOperator: added a function to return a list of norms and the ability to set this list of norms  #1513.
  • Unit tests for SPDHG setters and convergence check
  • Fixes BlockOperator.domain_geometry().allocate() not compatible with in place calls to BlockOperator.direct  #1863

Describe any testing you have performed

Please add any demo scripts to CIL-Demos/misc/
Test with SPDHG https://github.com/TomographicImaging/CIL-Demos/blob/main/misc/testing_sampling_SPDHG.ipynb

Similar results gained for all samplers for SPDHG, with 10 subsets
image

With 80 subsets:
image

Link relevant issues

Part of the stochastic work plan. Closes #1575. Closes #1576. Closes #1500. Closes #1496

Checklist when you are ready to request a review

  • I have performed a self-review of my code
  • I have added docstrings in line with the guidance in the developer guide
  • I have implemented unit tests that cover any new or modified functionality
  • CHANGELOG.md has been updated with any functionality change
  • Request review from all relevant developers
  • Change pull request label to 'Waiting for review'

Contribution Notes

Please read and adhere to the developer guide and local patterns and conventions.

  • The content of this Pull Request (the Contribution) is intentionally submitted for inclusion in CIL (the Work) under the terms and conditions of the Apache-2.0 License.
  • I confirm that the contribution does not violate any intellectual property rights of third parties

MargaretDuff and others added 30 commits August 2, 2023 10:36
Quick docstring

Signed-off-by: Margaret Duff <43645617+MargaretDuff@users.noreply.github.com>
Signed-off-by: Margaret Duff <43645617+MargaretDuff@users.noreply.github.com>
@lauramurgatroyd lauramurgatroyd removed this from the v24.2.0 milestone Sep 24, 2024
sampler: optional, an instance of a `cil.optimisation.utilities.Sampler` class or another class with the function __next__(self) implemented outputting an integer from {1,...,len(operator)}.
Method of selecting the next index for the SPDHG update. If None, a sampler will be created for random sampling with replacement and each index will have probability = 1/len(operator)
prob_weights: optional, list of floats of length num_indices that sum to 1. Defaults to [1/len(operator)]*len(operator)
Consider that the sampler is called a large number of times this argument holds the expected number of times each index would be called, normalised to 1. Note that this should not be passed if the provided sampler has it as an attribute.
Copy link
Contributor

@paskino paskino Oct 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we explain "Note that this should not be passed if the provided sampler has it as an attribute."?

Maybe:
Note: if the sampler has a prob_weight attribute it will take precedence on this parameter.

Comment on lines 369 to 370
else:
return False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you need this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed it to return a value error instead. Basically we can't check it for non-scalar values of tau

Comment on lines +378 to +379
self._zbar.sapyb(self._tau, self.x, -1., out=self._x_tmp )
self._x_tmp*=-1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if self._tau is a number I don't see the reason of this change as it forces you to have an additional loop

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but i think your point was that if self.tau was an array then changing these two lines means that you don't allocate memory doing -self.tau

Comment on lines 154 to 173
self._sampler = sampler

self._prob_weights = getattr(self._sampler, 'prob_weights', None)
if prob_weights is not None:
if self._prob_weights is None:
self._prob_weights = prob_weights
else:
raise ValueError(
' You passed a `prob_weights` argument and a sampler with attribute `prob_weights`, please remove the `prob_weights` argument.')

self._deprecated_kwargs(deprecated_kwargs)

if self._prob_weights is None:
self._prob_weights = [1/self._ndual_subsets]*self._ndual_subsets

if self._sampler is None:
self._sampler = Sampler.random_with_replacement(
len(operator), prob=self._prob_weights)

self._norms = operator.get_norms_as_list()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's simplify this part.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi Edo, I had a go and tried to explain the reasoning in the comments.

Copy link
Member

@gfardell gfardell left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly docstring clarification. I think it's very close.

Comment on lines 48 to 49
tau : positive float, optional, default=None
Step size parameter for Primal problem
Step size parameter for primal problem. If `None` see note.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove default=None I think it's fine to just say optional (we use that elsewhere in CIL).

The description should say it'll be computed.

How about:

tau : positive float, optional
        Step size parameter for the primal problem. If `None` will be computed by algorithm, see note for details.`

The same comment applies to all the arguments in the docstring.

Comment on lines 56 to 57
sampler: optional, an instance of a `cil.optimisation.utilities.Sampler` class or another class with the function __next__(self) implemented outputting an integer from {1,...,len(operator)}.
Method of selecting the next index for the SPDHG update. If None, a sampler will be created for random sampling with replacement and each index will have `probability = 1/len(operator)`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This probably needs to be clearer.

sampler: cil.optimisation.utilities.Sampler, optional
   A `Sampler` controllingthe selection of the next index for the SPDHG update. If `None`, a sampler will be created for uniform random sampling with replacement. See notes.

Note
-----
The `sampler` can be an instance of the `cil.optimisation.utilities.Sampler` class or a custom class with the `__next__(self)` method implemented, which outputs an integer index from {1, ..., len(operator)}. 

Note
-----
"Random sampling with replacement" will select the next index with equal probability from  `1 - len(operator)`.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

parameter controlling the trade-off between the primal and dual step sizes
gamma : float, optional
Parameter controlling the trade-off between the primal and dual step sizes
sampler: optional, an instance of a `cil.optimisation.utilities.Sampler` class or another class with the function __next__(self) implemented outputting an integer from {1,...,len(operator)}.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it really from {1,...,len(operator)}?

We should use zero indexing everywhere so this might be a typo or a bigger issue.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good spot, we do index from 0 so a typo and not a bigger issue

Comment on lines 58 to 59
prob_weights: optional, list of floats of length `num_indices` that sum to 1. Defaults to `[1/len(operator)]*len(operator)`
Consider that the sampler is called a large number of times this argument holds the expected number of times each index would be called, normalised to 1. Note that this should not be passed if the provided sampler has it as an attribute: if the sampler has a `prob_weight` attribute it will take precedence on this parameter.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the input type needs to be concise. list of floats, optional with the description expanding.

Beyond that, it's not clear why we need this and where it's used. ISn't this what sampler now controls? So either it's a docstring or implementation problem. Should it be moved to kwargs?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There was a design decision made that the sampler doesn't have to have prob_weights as in other stochastic algorithms they are not essential, just for reporting and plotting. However, this algorithm requires them to set sigma and tau. As the sampler might not have that argument, instead it can be passed seperately to the algorithm. But maybe we chat about it being a kwarg...



# Set up sampler and prob weights from deprecated "prob" argument
sampler = self._deprecated_set_prob(deprecated_kwargs, prob_weights, sampler)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this isn't very clear. I can see if I pass sampler=None and prob a sampler is created. Could this be handled directly by the if statement in ln 167 that creates a sampler?

Otherwise the naming implies it's just about setting prob, which I think would be sufficient. Then if Sampler is None it gets created as planned with self._prob_weights

I can see you're trying to split the deprecated set up from the main code, but it's confusing that it's creating the sampler in 2 places.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think sorted

self._deprecated_set_norms(deprecated_kwargs)
self._norms = operator.get_norms_as_list()
#Check for other kwargs
self._deprecated_else(deprecated_kwargs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this is worth having as a function, just a check on unused kwargs would maybe suffice if it's needed

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
No open projects
Status: Needs reviewing
Status: PRs to review
4 participants