
Intended usage of the Sophia optimiser #968

Closed
vvvm23 opened this issue May 20, 2024 · 7 comments

Comments

@vvvm23 commented May 20, 2024

Hi all,

I noticed in this commit that the Sophia optimiser (see paper) has been integrated into optax.

However, the estimate of the diagonal Hessian is far simpler than either variant proposed in the paper.
[screenshot: the two diagonal Hessian estimators from the paper]

The first uses a Hessian-vector product and the second uses label resampling. However, in the merged version of the code, we only seem to use the square of the current gradient. Furthermore, this comment:

update_hessian_every: How often to update the second order terms. As these
      are very cheap to compute, in fact (just squaring gradients), we can leave
      these at 1 by default. Must be >= 1.

seems to further contradict the paper, which describes the Hessian update as relatively expensive and performs it only every ten steps or so. This is pretty confusing, as the API of optax makes it seem that all optimisers can be used as drop-in replacements for one another, but here additional work seems to be required to get the behaviour we want.

How should I use the optax version of the optimiser in order to replicate the behaviour in the paper? Can you share example code for doing this?
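
For reference, the first estimator only needs a Hessian-vector product, which JAX supports directly. A rough sketch of what I had in mind, assuming a flat parameter array and a scalar `loss_fn` (names are illustrative, not from the optax code):

```python
import jax
import jax.numpy as jnp

def hutchinson_diag_hessian(loss_fn, params, key):
    # Hutchinson-style estimate: for u ~ N(0, I), E[u * (H @ u)] = diag(H).
    u = jax.random.normal(key, params.shape)
    # Hessian-vector product via forward-over-reverse autodiff.
    _, hvp = jax.jvp(jax.grad(loss_fn), (params,), (u,))
    return u * hvp
```

In the paper this estimate is only refreshed every k steps (k = 10 by default) and smoothed with an EMA, which is where the cost saving comes from.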

@vroulet (Collaborator) commented May 22, 2024

Thanks for catching this. You are absolutely right.
The implementations of recent optimizers in the contrib folder are the responsibility of the authors who add them.

@jbausch-gdm (Contributor)

Good catch indeed! The current version implements only the special case of the GNB estimator for a squared loss (as mentioned in the paper on p. 9), for which we can use the squared gradients to estimate the diagonal of the Hessian. I'll take a look at the other estimators as well.
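
Concretely, the merged behaviour amounts to something like the following (a simplified sketch, not the actual optax source):

```python
import jax

def update_hessian_estimate(h_ema, grads, b2=0.99):
    # Squared-loss special case of the GNB estimator: the diagonal
    # Hessian estimate is just the elementwise squared gradient,
    # folded into an exponential moving average.
    h_hat = jax.tree_util.tree_map(lambda g: g * g, grads)
    return jax.tree_util.tree_map(
        lambda h, hh: b2 * h + (1.0 - b2) * hh, h_ema, h_hat)
```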

@vvvm23 (Author) commented May 24, 2024

Using the squared gradients at each step isn't too dissimilar to Adam, no? In my experiments I get pretty similar convergence to Adam with this GNB estimator.

It's nice to include, but I feel it is quite confusing to users who have read the Sophia paper and, given the name of the function, expect the convergence gains it reports.

That said, the additional model forward passes required by the paper's estimators are difficult to combine with optax's existing API. This repo https://github.com/stanford-crfm/levanter implements it quite well, if you are looking for inspiration.

@evanatyourservice commented May 25, 2024

Yeah, I was kind of surprised that PR was merged as well, since the squared gradient isn't really up to par with the main magic of Sophia. I see the PR to remove Sophia; either way, I'll add a PR for my Sophia-H implementation in a _sophia_h contrib file. It is very close to levanter's but has some other features. It does require the loss fn to be passed in, but optax has GradientTransformationExtraArgs, which allows for that nicely. I could add a well-described error in case a loss fn isn't passed in, to help people use it more easily.
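
Roughly what I have in mind (a sketch with placeholder names; sophia_h here is not the final API):

```python
import optax

def sophia_h() -> optax.GradientTransformationExtraArgs:
    def init_fn(params):
        del params
        return optax.EmptyState()

    def update_fn(updates, state, params=None, *, obj_fn=None, **extra_args):
        del extra_args
        if obj_fn is None:
            raise ValueError(
                'sophia_h requires the objective, e.g. '
                'opt.update(grads, state, params, obj_fn=loss_fn)')
        # ... use obj_fn here for the Hessian-vector products ...
        return updates, state

    return optax.GradientTransformationExtraArgs(init_fn, update_fn)
```

so callers would write `updates, new_state = opt.update(grads, state, params, obj_fn=loss_fn)`.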

@fabianp (Member) commented May 27, 2024

Hey there. Indeed, we plan to remove the current implementation of Sophia after this bug report. We would of course be happy to accept contributions of implementations of this solver that don't share the flaws of the current (soon-to-be-removed) implementation.

@evanatyourservice

OK, I need to make it play nicely with injected hyperparams, and then I'll submit a PR.

@fabianp (Member) commented Jun 27, 2024

closing this in favor of #979

@vroulet closed this as completed Jul 17, 2024