Add Adan optimizer. #1090

Status: Open. Wants to merge 1 commit into main.
5 changes: 5 additions & 0 deletions docs/api/optimizers.rst
@@ -6,6 +6,7 @@ Optimizers
.. autosummary::
adabelief
adadelta
adan
adafactor
adagrad
adam
@@ -40,6 +41,10 @@ AdaDelta
~~~~~~~~~
.. autofunction:: adadelta

Adan
~~~~
.. autofunction:: adan

AdaGrad
~~~~~~~
.. autofunction:: adagrad
5 changes: 5 additions & 0 deletions docs/api/transformations.rst
@@ -41,6 +41,8 @@ Transformations
ScaleState
scale_by_adadelta
ScaleByAdaDeltaState
scale_by_adan
ScaleByAdanState
scale_by_adam
scale_by_adamax
ScaleByAdamState
@@ -172,6 +174,9 @@ Transformations and states
.. autofunction:: scale_by_adadelta
.. autoclass:: ScaleByAdaDeltaState

.. autofunction:: scale_by_adan
.. autoclass:: ScaleByAdanState

.. autofunction:: scale_by_adam
.. autofunction:: scale_by_adamax
.. autoclass:: ScaleByAdamState
6 changes: 6 additions & 0 deletions optax/__init__.py
@@ -27,6 +27,7 @@
from optax import tree_utils
from optax._src.alias import adabelief
from optax._src.alias import adadelta
from optax._src.alias import adan
from optax._src.alias import adafactor
from optax._src.alias import adagrad
from optax._src.alias import adam
@@ -115,6 +116,7 @@
from optax._src.transform import normalize_by_update_norm
from optax._src.transform import scale
from optax._src.transform import scale_by_adadelta
from optax._src.transform import scale_by_adan
from optax._src.transform import scale_by_adam
from optax._src.transform import scale_by_adamax
from optax._src.transform import scale_by_amsgrad
@@ -139,6 +141,7 @@
from optax._src.transform import scale_by_trust_ratio
from optax._src.transform import scale_by_yogi
from optax._src.transform import ScaleByAdaDeltaState
from optax._src.transform import ScaleByAdanState
from optax._src.transform import ScaleByAdamState
from optax._src.transform import ScaleByAmsgradState
from optax._src.transform import ScaleByBeliefState
@@ -288,6 +291,7 @@
__all__ = (
"adabelief",
"adadelta",
"adan",
"adafactor",
"adagrad",
"adam",
@@ -394,6 +398,7 @@
"safe_root_mean_squares",
"ScalarOrSchedule",
"scale_by_adadelta",
"scale_by_adan",
"scale_by_adam",
"scale_by_adamax",
"scale_by_amsgrad",
@@ -420,6 +425,7 @@
"scale_gradient",
"scale",
"ScaleByAdaDeltaState",
"ScaleByAdanState",
"ScaleByAdamState",
"ScaleByAmsgradState",
"ScaleByBacktrackingLinesearchState",
119 changes: 119 additions & 0 deletions optax/_src/alias.py
@@ -177,6 +177,125 @@ def adadelta(
)


def adan(
learning_rate: base.ScalarOrSchedule,
b1: float = 0.98,
b2: float = 0.92,
b3: float = 0.99,
eps: float = 1e-8,
eps_root: float = 1e-8,
weight_decay: float = 0.0,
mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
) -> base.GradientTransformation:
r"""The ADAptive Nesterov momentum algorithm (Adan).

Adan first reformulates the vanilla Nesterov acceleration to develop a new
Nesterov momentum estimation (NME) method, which avoids the extra overhead of
computing the gradient at the extrapolation point. Adan then adopts NME to
estimate the gradient's first- and second-order moments in adaptive gradient
algorithms for convergence acceleration.

The algorithm is as follows. First, we define the following parameters:

- :math:`\eta > 0`: the step size.
- :math:`\beta_1 \in [0, 1]`: the decay rate for the exponentially weighted
average of gradients.
- :math:`\beta_2 \in [0, 1]`: the decay rate for the exponentially weighted
average of differences of gradients.
- :math:`\beta_3 \in [0, 1]`: the decay rate for the exponentially weighted
average of the squared term.
- :math:`\varepsilon > 0`: a small constant for numerical stability.
- :math:`\bar{\varepsilon} \geq 0`: a small constant added inside the square
root for numerical stability (`eps_root`).
- :math:`\lambda \geq 0`: the weight decay strength.

Second, we define the following variables:

- :math:`\theta_t`: the parameters.
- :math:`g_t`: the incoming stochastic gradient.
- :math:`m_t`: the exponentially weighted average of gradients.
- :math:`v_t`: the exponentially weighted average of differences of gradients.
- :math:`n_t`: the exponentially weighted average of the squared term.
- :math:`u_t`: the outgoing update vector.
- :math:`S_t`: the saved state of the optimizer.

Third, we initialize these variables as follows:

- :math:`m_0 = g_0`
- :math:`v_0 = 0`
- :math:`v_1 = g_1 - g_0`
- :math:`n_0 = g_0^2`

Finally, on each iteration, we update the variables as follows:

.. math::

\begin{align*}
m_t &\gets (1 - \beta_1) m_{t-1} + \beta_1 g_t \\
v_t &\gets (1 - \beta_2) v_{t-1} + \beta_2 (g_t - g_{t-1}) \\
n_t &\gets (1 - \beta_3) n_{t-1} + \beta_3 (g_t + (1 - \beta_2)
(g_t - g_{t-1}))^2 \\
\eta_t &\gets \eta / ({\sqrt{n_t + \bar{\varepsilon}} + \varepsilon}) \\
u_t &\gets (\theta_t - \eta_t \circ (m_t + (1 - \beta_2) v_t))
/ (1 + \lambda \eta) \\
S_t &\leftarrow (m_t, v_t, n_t).
\end{align*}

References:
Xie et al, `Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing
Deep Models
<https://arxiv.org/abs/2208.06677>`_, 2022

Args:
learning_rate: A global scaling factor, either fixed or evolving along
iterations with a scheduler.
b1: Decay rate for the EWMA of gradients.
b2: Decay rate for the EWMA of differences of gradients.
b3: Decay rate for the EWMA of the algorithm's squared term.
eps: Term added to the denominator to improve numerical stability.
eps_root: Term added to the denominator inside the square-root to improve
numerical stability when backpropagating gradients through the rescaling.
weight_decay: Strength of the weight decay regularization.
mask: A tree with the same structure as (or a prefix of) the params PyTree,
or a Callable that returns such a pytree given the params/updates. The
leaves should be booleans: `True` for leaves/subtrees you want to apply the
weight decay to, and `False` for those you want to skip.

Returns:
The corresponding :class:`optax.GradientTransformation`.

Examples:
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> f = lambda x: x @ x # simple quadratic function
>>> solver = optax.adan(learning_rate=1e-1)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function: 14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
... grad = jax.grad(f)(params)
... updates, opt_state = solver.update(grad, opt_state, params)
... params = optax.apply_updates(params, updates)
... print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.28E+01
Objective function: 1.17E+01
Objective function: 1.07E+01
Objective function: 9.68E+00
Objective function: 8.76E+00

"""
return combine.chain(
transform.scale_by_adan(
b1=b1,
b2=b2,
b3=b3,
eps=eps,
eps_root=eps_root,
),
transform.add_decayed_weights(weight_decay, mask),
transform.scale_by_learning_rate(learning_rate),
)


def adafactor(
learning_rate: Optional[base.ScalarOrSchedule] = None,
min_dim_size_to_factor: int = 128,
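Because `learning_rate` is typed as `base.ScalarOrSchedule`, the new `adan` alias should compose with Optax schedules just like the other aliases. A minimal sketch of that usage, assuming the alias is exported as `optax.adan` as in this PR's `__init__.py` changes; the warmup-cosine schedule and all hyperparameter values below are illustrative, not taken from the diff:

import jax
import jax.numpy as jnp
import optax

# Illustrative schedule; the specific values are placeholders.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=1e-2, warmup_steps=100, decay_steps=1_000)
solver = optax.adan(learning_rate=schedule, weight_decay=1e-2)

params = jnp.array([1.0, 2.0, 3.0])
opt_state = solver.init(params)
grad = jax.grad(lambda x: x @ x)(params)
# Params are passed to update() because add_decayed_weights needs them.
updates, opt_state = solver.update(grad, opt_state, params)
params = optax.apply_updates(params, updates)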
2 changes: 2 additions & 0 deletions optax/_src/alias_test.py
@@ -48,6 +48,7 @@
_OPTIMIZERS_UNDER_TEST = (
dict(opt_name='sgd', opt_kwargs=dict(learning_rate=1e-3, momentum=0.9)),
dict(opt_name='adadelta', opt_kwargs=dict(learning_rate=0.1)),
dict(opt_name='adan', opt_kwargs=dict(learning_rate=1e-1)),
dict(opt_name='adafactor', opt_kwargs=dict(learning_rate=5e-3)),
dict(opt_name='adagrad', opt_kwargs=dict(learning_rate=1.0)),
dict(opt_name='adam', opt_kwargs=dict(learning_rate=1e-1)),
@@ -131,6 +132,7 @@ def test_optimization(self, opt_name, opt_kwargs, target, dtype):
'lion',
'rprop',
'adadelta',
'adan',
'polyak_sgd',
'sign_sgd',
) and jnp.iscomplexobj(dtype):
80 changes: 80 additions & 0 deletions optax/_src/transform.py
@@ -594,6 +594,86 @@ def update_fn(updates, state, params=None):
return base.GradientTransformation(init_fn, update_fn)


class ScaleByAdanState(NamedTuple):
"""State for the rescaling by the Adan algorithm."""
m: base.Updates # EWMA of gradients.
v: base.Updates # EWMA of differences of gradients.
n: base.Updates # EWMA of the squared term.
g: base.Updates # Previous gradient.
t: chex.Array # shape=(), dtype=jnp.int32.


def scale_by_adan(
b1: float = 0.98,
b2: float = 0.92,
b3: float = 0.99,
eps: float = 1e-8,
eps_root: float = 0.0,
) -> base.GradientTransformation:
"""Rescale updates according to the Adan algorithm.

References:
Xie et al, `Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing
Deep Models
<https://arxiv.org/abs/2208.06677>`_, 2022

Args:
b1: Decay rate for the EWMA of gradients.
b2: Decay rate for the EWMA of differences of gradients.
b3: Decay rate for the EWMA of the algorithm's squared term.
eps: Term added to the denominator to improve numerical stability.
eps_root: Term added to the denominator inside the square-root to improve
numerical stability when backpropagating gradients through the rescaling.

Returns:
An :class:`optax.GradientTransformation` object.
"""
def init_fn(params):
return ScaleByAdanState(
m=otu.tree_zeros_like(params),
v=otu.tree_zeros_like(params),
n=otu.tree_zeros_like(params),
g=otu.tree_zeros_like(params),
t=jnp.zeros([], jnp.int32),
)

def update_fn(updates, state, params=None):
"""Based on Algorithm 1 in https://arxiv.org/pdf/2208.06677v4#page=6."""
del params
g = updates

diff = otu.tree_where(
state.t == 0,
otu.tree_zeros_like(g),
otu.tree_sub(g, state.g),
)
m = otu.tree_update_moment(g, state.m, b1, 1)
v = otu.tree_update_moment(diff, state.v, b2, 1)

sq = otu.tree_add_scalar_mul(g, 1 - b2, diff)
n = otu.tree_update_moment_per_elem_norm(sq, state.n, b3, 2)

t = numerics.safe_increment(state.t)
m_hat = otu.tree_bias_correction(m, b1, t)
v_hat = otu.tree_bias_correction(v, b2, t)
n_hat = otu.tree_bias_correction(n, b3, t)

u = otu.tree_add_scalar_mul(m_hat, 1 - b2, v_hat)
denom = jax.tree.map(lambda n_hat: jnp.sqrt(n_hat + eps_root) + eps, n_hat)
u = otu.tree_div(u, denom)

new_state = ScaleByAdanState(
m=m,
v=v,
n=n,
g=g,
t=t,
)

return u, new_state

return base.GradientTransformation(init_fn, update_fn)


class ScaleByBeliefState(NamedTuple):
"""State for the rescaling by AdaBelief algorithm."""
count: chex.Array # shape=(), dtype=jnp.int32.
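For readers who want the transformation rather than the alias, `scale_by_adan` can also be chained by hand, mirroring the `combine.chain` call in `alias.py` above. A sketch, with illustrative hyperparameter values:

import optax

# Hand-assembled equivalent of the optax.adan alias (values are illustrative).
optimizer = optax.chain(
    optax.scale_by_adan(b1=0.98, b2=0.92, b3=0.99, eps=1e-8, eps_root=1e-8),
    optax.add_decayed_weights(weight_decay=1e-2),
    optax.scale_by_learning_rate(1e-1),  # also flips the sign of the updates
)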
1 change: 1 addition & 0 deletions optax/_src/transform_test.py
@@ -43,6 +43,7 @@ def setUp(self):
@chex.all_variants
@parameterized.named_parameters([
('adadelta', transform.scale_by_adadelta),
('adan', transform.scale_by_adan),
('adam', transform.scale_by_adam),
('adamax', transform.scale_by_adamax),
('lion', transform.scale_by_lion),
1 change: 1 addition & 0 deletions optax/contrib/_common_test.py
@@ -101,6 +101,7 @@
dict(opt_name='rmsprop', opt_kwargs=dict(learning_rate=1.0)),
dict(opt_name='rmsprop', opt_kwargs=dict(learning_rate=1.0, momentum=0.9)),
dict(opt_name='adabelief', opt_kwargs=dict(learning_rate=1.0)),
dict(opt_name='adan', opt_kwargs=dict(learning_rate=1.0)),
dict(opt_name='radam', opt_kwargs=dict(learning_rate=1.0)),
dict(opt_name='sm3', opt_kwargs=dict(learning_rate=1.0)),
dict(opt_name='yogi', opt_kwargs=dict(learning_rate=1.0, b1=0.99)),