Add sophia-h optimizer #979
base: main
Conversation
fixes #968
Thank you very much @evanatyourservice! And sorry for the delay.
I left you some comments we can discuss.
Hi Vincent, thank you for the notes! They all make perfect sense to me, and I'll get to updating the code and answering them tomorrow.
@evanatyourservice please ping us whenever you're ready for another round of reviews :-)
@fabianp Will do! Sorry, I've been moving, but I'll try to get this going ASAP.
there's no rush, just wanted to make sure you were not waiting on us :-)
Hello @evanatyourservice,
sounds good!
```python
raise ValueError("obj_fn must be provided to hutchinson update function.")
del updates
key, subkey = jax.random.split(state.key)
random_signs = otu.tree_random_like(
```
As far as I can tell from the paper (https://arxiv.org/pdf/2305.14342, section 2.3), it computes the Hutchinson estimator using a normal distribution, while here we use a Rademacher distribution. The Rademacher distribution should have lower variance, but perhaps it's worth adding a comment noting that this deviates from what is specified in the paper?
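For reference, both choices give an unbiased estimate of the Hessian diagonal: E[v ⊙ (Hv)] = diag(H) whenever E[vvᵀ] = I, which holds for standard normal and Rademacher probes alike. A minimal self-contained sketch of the two variants (function and variable names here are illustrative, not part of this PR):

```python
import jax
import jax.numpy as jnp

def hutchinson_diag(loss, params, key, distribution="rademacher"):
    """Single-probe Hutchinson estimate of diag(Hessian(loss)) at params."""
    if distribution == "rademacher":
        # +/-1 probes: v_i**2 == 1 exactly, so the diagonal terms
        # contribute no variance; only off-diagonal cross terms do.
        v = jax.random.rademacher(key, params.shape).astype(params.dtype)
    else:
        # Standard normal probes, as specified in section 2.3 of the paper.
        v = jax.random.normal(key, params.shape, dtype=params.dtype)
    # Hessian-vector product via forward-over-reverse differentiation.
    hvp = jax.jvp(jax.grad(loss), (params,), (v,))[1]
    return v * hvp  # unbiased estimate of diag(H)

loss = lambda p: jnp.sum(p ** 4)      # Hessian is diagonal: 12 * p**2
params = jnp.array([1.0, -2.0, 3.0])
keys = jax.random.split(jax.random.PRNGKey(0), 1000)
est = jnp.mean(
    jax.vmap(lambda k: hutchinson_diag(loss, params, k))(keys), axis=0
)
```

Since this toy Hessian is exactly diagonal, Rademacher probes recover 12 * p**2 with essentially no variance; with normal probes the per-component variance of the diagonal term scales with the squared diagonal entries, which is why Rademacher is a common low-variance choice even though the paper specifies a normal distribution.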
Hello @evanatyourservice,
Ok, sounds good! Sorry, I should get to this tomorrow.
PR to add the Sophia optimizer. It's mostly based on levanter's implementation, with some changes and added features here and there.
One note is that I had to change the contrib common test file a couple of times: once to return the loss_fn from the parabola and rosenbrock functions (which could be useful later for other optimizers that need the loss function), and a second time to bypass the check that update arguments are arrays (the loss function is not). Please advise if these changes are not OK or if there is a more correct approach.
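For readers unfamiliar with the optimizer, the Sophia update rule from the paper (https://arxiv.org/abs/2305.14342) can be sketched as below. This is a minimal illustration, not the API added in this PR; the function name and hyperparameter defaults are illustrative, and in the full algorithm the Hessian-diagonal estimate is refreshed only every k steps rather than on every call:

```python
import jax.numpy as jnp

def sophia_update(m, h, grad, hess_diag, b1=0.96, b2=0.99, rho=0.04, eps=1e-12):
    """One Sophia-style step: returns updated EMAs and the parameter update."""
    m = b1 * m + (1 - b1) * grad        # EMA of gradients
    h = b2 * h + (1 - b2) * hess_diag   # EMA of Hutchinson Hessian-diagonal estimates
    # Element-wise preconditioned step, clipped to [-1, 1] so that tiny or
    # negative curvature estimates cannot blow up the update.
    update = jnp.clip(m / jnp.maximum(rho * h, eps), -1.0, 1.0)
    return m, h, update
```

The clipping is the key difference from a plain diagonal-Newton step: wherever the curvature estimate is unreliable, the update degenerates to a sign-momentum step of magnitude at most the learning rate.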