
Add sophia-h optimizer #979

Open · evanatyourservice wants to merge 8 commits into main

Conversation

evanatyourservice

PR to add sophia optimizer. It's mostly based on levanter's implementation with some changes/added features here and there.

One note: I had to change the contrib common test file a couple of times, once to pass the loss_fn out of the parabola and rosenbrock functions (which could be useful later for other optimizers that need the loss function), and a second time to bypass the check that update arguments are values (the loss function is not). Please advise if these changes aren't OK or if there's a more correct approach.
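Roughly, the test-harness change described above looks like this (a hypothetical sketch, not the PR's actual diff; the parameter values are illustrative): the target function returns its `loss_fn` alongside the parameters so the harness can forward it to optimizers such as Sophia whose `update` needs the objective.

```python
import jax
import jax.numpy as jnp

def parabola():
    # Illustrative target in the style of the contrib common tests.
    initial_params = jnp.array([-1.0, 10.0, 1.0])
    final_params = jnp.array([1.0, -1.0, 1.0])

    def loss_fn(params):
        return jnp.sum(jnp.square(params - final_params))

    # The change described in the PR: also return loss_fn, so the test
    # can pass it into update() for optimizers that need the objective.
    return initial_params, final_params, loss_fn

initial_params, final_params, loss_fn = parabola()
grads = jax.grad(loss_fn)(initial_params)
```

A test would then thread `loss_fn` through as an extra `update` argument, which is why the "arguments must be values" check had to be bypassed for it.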

@evanatyourservice
Author

fixes #968

@vroulet (Collaborator) left a comment

Thank you very much @evanatyourservice! And sorry for the delay.
I left some comments we can discuss.

[8 resolved review comments on optax/contrib/_sophia_h.py, now outdated]
@evanatyourservice
Author

Hi Vincent, thank you for the notes! They all make perfect sense to me, and I'll get to updating the code and answering them tomorrow.

@fabianp
Member

fabianp commented Jun 27, 2024

@evanatyourservice please ping us whenever you're ready for another round of reviews :-)

@evanatyourservice
Copy link
Author

@fabianp Will do! Sorry, I've been moving, but I'll try to get this going ASAP.

@fabianp
Member

fabianp commented Jun 28, 2024

there's no rush, just wanted to make sure you were not waiting on us :-)

@evanatyourservice
Author

@vroulet @fabianp Got some updates pushed; let me know if anything needs to be changed! Thanks.

@vroulet
Collaborator

vroulet commented Sep 14, 2024

Hello @evanatyourservice,
Sorry for the very long delay on our end. I think your code looks great!
If it's ok with you, can you merge head into your branch once #1060 lands? (#1060 adds more tests for the contrib optimizers to ensure compatibility.) Then I should be able to approve and finish things on our side if there are still minor details to fine-tune.
Thank you again!

@evanatyourservice
Author

sounds good!

Excerpt from optax/contrib/_sophia_h.py:

```python
raise ValueError("obj_fn must be provided to hutchinson update function.")
del updates
key, subkey = jax.random.split(state.key)
random_signs = otu.tree_random_like(
```
@fabianp (Member) left a comment

As far as I can tell from the paper (https://arxiv.org/pdf/2305.14342, section 2.3), it computes the Hutchinson estimator using a Normal distribution, while here we use a Rademacher distribution. The Rademacher distribution should have lower variance, but perhaps it's worth adding a comment noting that this deviates from what is specified in the paper?
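For context, the distinction can be sketched in a few lines of JAX (a toy illustration, not the PR's code; names like `hutchinson_trace` are made up here). Hutchinson's estimator approximates tr(H) by E[vᵀHv], and for a fixed quadratic the Rademacher probe has visibly lower variance than the Gaussian one the paper specifies:

```python
import jax
import jax.numpy as jnp

def loss(w):
    # Toy quadratic: Hessian is diag(2, 6), so the true trace is 8.
    return w[0] ** 2 + 3.0 * w[1] ** 2

def hvp(w, v):
    # Hessian-vector product via forward-over-reverse differentiation.
    return jax.jvp(jax.grad(loss), (w,), (v,))[1]

def hutchinson_trace(w, key, n_samples=1000, rademacher=True):
    # Estimate tr(H) as the mean of v^T H v over random probes v.
    def one_sample(k):
        if rademacher:
            v = jax.random.rademacher(k, (2,)).astype(jnp.float32)
        else:
            v = jax.random.normal(k, (2,))
        return v @ hvp(w, v)
    keys = jax.random.split(key, n_samples)
    return jnp.mean(jax.vmap(one_sample)(keys))

w = jnp.array([1.0, 1.0])
# For a diagonal Hessian, Rademacher probes give 2*v0^2 + 6*v1^2 = 8
# exactly every sample (zero variance); Gaussian probes only match in
# expectation and fluctuate around 8.
print(hutchinson_trace(w, jax.random.PRNGKey(0), rademacher=True))
print(hutchinson_trace(w, jax.random.PRNGKey(0), rademacher=False))
```

This is why a short comment in `_sophia_h.py` noting the deliberate deviation from the paper's Normal probes would be useful to future readers.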

@vroulet
Collaborator

vroulet commented Sep 23, 2024

Hello @evanatyourservice,
#1060 got merged. You may merge the tests, address Fabian's comment, and I can approve (and fix minor issues on our end if there are any).
Thank you again!

@evanatyourservice
Author

Ok, sounds good! Sorry, I should get to this tomorrow.

3 participants