Commit

Add Adan optimizer.

carlosgmartin committed Oct 4, 2024
1 parent 789212c commit bd8c558
Showing 8 changed files with 179 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/api/optimizers.rst
@@ -6,6 +6,7 @@ Optimizers
.. autosummary::
    adabelief
    adadelta
    adan
    adafactor
    adagrad
    adam
@@ -40,6 +41,10 @@ AdaDelta
~~~~~~~~~
.. autofunction:: adadelta

Adan
~~~~
.. autofunction:: adan

AdaGrad
~~~~~~~
.. autofunction:: adagrad
5 changes: 5 additions & 0 deletions docs/api/transformations.rst
@@ -41,6 +41,8 @@ Transformations
    ScaleState
    scale_by_adadelta
    ScaleByAdaDeltaState
    scale_by_adan
    ScaleByAdanState
    scale_by_adam
    scale_by_adamax
    ScaleByAdamState
@@ -172,6 +174,9 @@ Transformations and states
.. autofunction:: scale_by_adadelta
.. autoclass:: ScaleByAdaDeltaState

.. autofunction:: scale_by_adan
.. autoclass:: ScaleByAdanState

.. autofunction:: scale_by_adam
.. autofunction:: scale_by_adamax
.. autoclass:: ScaleByAdamState
6 changes: 6 additions & 0 deletions optax/__init__.py
@@ -27,6 +27,7 @@
from optax import tree_utils
from optax._src.alias import adabelief
from optax._src.alias import adadelta
from optax._src.alias import adan
from optax._src.alias import adafactor
from optax._src.alias import adagrad
from optax._src.alias import adam
@@ -115,6 +116,7 @@
from optax._src.transform import normalize_by_update_norm
from optax._src.transform import scale
from optax._src.transform import scale_by_adadelta
from optax._src.transform import scale_by_adan
from optax._src.transform import scale_by_adam
from optax._src.transform import scale_by_adamax
from optax._src.transform import scale_by_amsgrad
@@ -139,6 +141,7 @@
from optax._src.transform import scale_by_trust_ratio
from optax._src.transform import scale_by_yogi
from optax._src.transform import ScaleByAdaDeltaState
from optax._src.transform import ScaleByAdanState
from optax._src.transform import ScaleByAdamState
from optax._src.transform import ScaleByAmsgradState
from optax._src.transform import ScaleByBeliefState
@@ -288,6 +291,7 @@
__all__ = (
"adabelief",
"adadelta",
"adan",
"adafactor",
"adagrad",
"adam",
@@ -394,6 +398,7 @@
"safe_root_mean_squares",
"ScalarOrSchedule",
"scale_by_adadelta",
"scale_by_adan",
"scale_by_adam",
"scale_by_adamax",
"scale_by_amsgrad",
@@ -420,6 +425,7 @@
"scale_gradient",
"scale",
"ScaleByAdaDeltaState",
"ScaleByAdanState",
"ScaleByAdamState",
"ScaleByAmsgradState",
"ScaleByBacktrackingLinesearchState",
74 changes: 74 additions & 0 deletions optax/_src/alias.py
@@ -177,6 +177,80 @@ def adadelta(
  )


def adan(
    learning_rate: base.ScalarOrSchedule,
    b1: float = 0.98,
    b2: float = 0.92,
    b3: float = 0.99,
    eps: float = 1e-8,
    eps_root: float = 1e-8,
    weight_decay: float = 0.0,
    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
) -> base.GradientTransformation:
  """The ADAptive Nesterov momentum algorithm (Adan).

  Adan is an Adam variant that uses Nesterov Momentum Estimation (NME) to
  estimate the first- and second-order moments of the gradient, accelerating
  the convergence of adaptive gradient algorithms.

  References:
    Xie et al, 2022: https://arxiv.org/abs/2208.06677

  Args:
    learning_rate: A global scaling factor, either fixed or a schedule.
    b1: Decay rate for the exponentially weighted average of gradients.
    b2: Decay rate for the exponentially weighted average of differences of
      gradients.
    b3: Decay rate for the exponentially weighted average of the squared term.
    eps: A small constant applied to the denominator outside of the square
      root (as in the Adam paper) to avoid dividing by zero when rescaling.
    eps_root: A small constant applied to the denominator inside the square
      root (as in RMSProp) to avoid dividing by zero when rescaling.
    weight_decay: Strength of the weight decay regularization.
    mask: A tree with the same structure as (or a prefix of) the params
      pytree, or a callable that returns such a pytree given the
      params/updates. The leaves should be booleans: `True` for
      leaves/subtrees you want to apply the weight decay to, and `False` for
      those you want to skip. Note that the Adan gradient transformations are
      applied to all parameters.

  Returns:
    The corresponding `GradientTransformation`.

  Examples:
    >>> import optax
    >>> import jax
    >>> import jax.numpy as jnp
    >>> f = lambda x: x @ x  # simple quadratic function
    >>> solver = optax.adan(learning_rate=1e-1)
    >>> params = jnp.array([1., 2., 3.])
    >>> print('Objective function: ', f(params))
    Objective function:  14.0
    >>> opt_state = solver.init(params)
    >>> for _ in range(5):
    ...   grad = jax.grad(f)(params)
    ...   updates, opt_state = solver.update(grad, opt_state, params)
    ...   params = optax.apply_updates(params, updates)
    ...   print('Objective function: {:.2E}'.format(f(params)))
    Objective function: 1.28E+01
    Objective function: 1.17E+01
    Objective function: 1.07E+01
    Objective function: 9.69E+00
    Objective function: 8.77E+00
  """
  return combine.chain(
      transform.scale_by_adan(
          b1=b1,
          b2=b2,
          b3=b3,
          eps=eps,
          eps_root=eps_root,
      ),
      transform.add_decayed_weights(weight_decay, mask),
      transform.scale_by_learning_rate(learning_rate),
  )


def adafactor(
    learning_rate: Optional[base.ScalarOrSchedule] = None,
    min_dim_size_to_factor: int = 128,
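
As a quick usage sketch (not part of this diff), the new `optax.adan` alias can be driven like any other optax optimizer; the parameter tree, loss, and mask below are purely illustrative:

import jax
import jax.numpy as jnp
import optax

# Illustrative parameters; the mask applies weight decay to 'w' only.
params = {'w': jnp.ones((3,)), 'b': jnp.zeros((3,))}
mask = {'w': True, 'b': False}

def loss(p):
  return jnp.sum((p['w'] * jnp.arange(3.0) + p['b'] - 1.0) ** 2)

solver = optax.adan(learning_rate=1e-2, weight_decay=1e-4, mask=mask)
opt_state = solver.init(params)

for _ in range(10):
  grads = jax.grad(loss)(params)
  updates, opt_state = solver.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)
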
2 changes: 2 additions & 0 deletions optax/_src/alias_test.py
@@ -48,6 +48,7 @@
_OPTIMIZERS_UNDER_TEST = (
    dict(opt_name='sgd', opt_kwargs=dict(learning_rate=1e-3, momentum=0.9)),
    dict(opt_name='adadelta', opt_kwargs=dict(learning_rate=0.1)),
    dict(opt_name='adan', opt_kwargs=dict(learning_rate=1e-1)),
    dict(opt_name='adafactor', opt_kwargs=dict(learning_rate=5e-3)),
    dict(opt_name='adagrad', opt_kwargs=dict(learning_rate=1.0)),
    dict(opt_name='adam', opt_kwargs=dict(learning_rate=1e-1)),
@@ -131,6 +132,7 @@ def test_optimization(self, opt_name, opt_kwargs, target, dtype):
        'lion',
        'rprop',
        'adadelta',
        'adan',
        'polyak_sgd',
        'sign_sgd',
    ) and jnp.iscomplexobj(dtype):
85 changes: 85 additions & 0 deletions optax/_src/transform.py
@@ -594,6 +594,91 @@ def update_fn(updates, state, params=None):
  return base.GradientTransformation(init_fn, update_fn)


class ScaleByAdanState(NamedTuple):
  """State for the Adan algorithm."""
  count: chex.Array
  mu: base.Updates
  nu: base.Updates
  delta: base.Updates
  grad_tm1: base.Updates


def scale_by_adan(
    b1: float = 0.98,
    b2: float = 0.92,
    b3: float = 0.99,
    eps: float = 1e-8,
    eps_root: float = 0.0,
) -> base.GradientTransformation:
  """Rescale updates according to the Adan algorithm.

  References:
    [Xie et al, 2022](https://arxiv.org/abs/2208.06677)

  Args:
    b1: Decay rate for the exponentially weighted average of gradients.
    b2: Decay rate for the exponentially weighted average of differences of
      gradients.
    b3: Decay rate for the exponentially weighted average of the squared term.
    eps: Term added to the denominator to improve numerical stability.
    eps_root: Term added to the denominator inside the square root to improve
      numerical stability when backpropagating gradients through the
      rescaling.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(params):
    mu = otu.tree_zeros_like(params)  # 1st moment
    nu = otu.tree_zeros_like(params)  # 2nd moment
    delta = otu.tree_zeros_like(params)  # EWA of difference of gradients
    grad_tm1 = otu.tree_zeros_like(params)  # previous gradient
    return ScaleByAdanState(
        count=jnp.zeros((), jnp.int32),
        mu=mu,
        nu=nu,
        delta=delta,
        grad_tm1=grad_tm1,
    )

  def update_fn(updates, state, params=None):
    del params

    # The gradient difference is taken to be zero on the very first step.
    diff = otu.tree_where(
        state.count != 0,
        otu.tree_sub(updates, state.grad_tm1),
        otu.tree_zeros_like(updates),
    )

    # Nesterov-corrected gradient used for the second-moment estimate.
    grad_prime = otu.tree_add_scalar_mul(updates, b2, diff)

    mu = otu.tree_update_moment(updates, state.mu, b1, 1)
    delta = otu.tree_update_moment(diff, state.delta, b2, 1)
    nu = otu.tree_update_moment_per_elem_norm(grad_prime, state.nu, b3, 2)

    count_inc = numerics.safe_int32_increment(state.count)
    mu_hat = otu.tree_bias_correction(mu, b1, count_inc)
    delta_hat = otu.tree_bias_correction(delta, b2, count_inc)
    nu_hat = otu.tree_bias_correction(nu, b3, count_inc)

    new_updates = jax.tree_map(
        lambda m, d, n: (m + b2 * d) / (jnp.sqrt(n + eps_root) + eps),
        mu_hat,
        delta_hat,
        nu_hat,
    )

    new_state = ScaleByAdanState(
        count=count_inc,
        mu=mu,
        nu=nu,
        delta=delta,
        grad_tm1=updates,
    )

    return new_updates, new_state

  return base.GradientTransformation(init_fn, update_fn)


class ScaleByBeliefState(NamedTuple):
"""State for the rescaling by AdaBelief algorithm."""
count: chex.Array # shape=(), dtype=jnp.int32.
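
For reference, the recursions implemented by update_fn above (writing g_t for the incoming gradient, with the gradient difference taken to be zero on the first step) are:

  mu_t    = b1 * mu_{t-1}    + (1 - b1) * g_t
  delta_t = b2 * delta_{t-1} + (1 - b2) * (g_t - g_{t-1})
  nu_t    = b3 * nu_{t-1}    + (1 - b3) * (g_t + b2 * (g_t - g_{t-1}))^2

Each moment is then bias-corrected by its (1 - b^t) factor, and the emitted update is

  (mu_hat_t + b2 * delta_hat_t) / (sqrt(nu_hat_t + eps_root) + eps)

which the `adan` alias subsequently combines with weight decay and the learning rate.
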
1 change: 1 addition & 0 deletions optax/_src/transform_test.py
@@ -43,6 +43,7 @@ def setUp(self):
  @chex.all_variants
  @parameterized.named_parameters([
      ('adadelta', transform.scale_by_adadelta),
      ('adan', transform.scale_by_adan),
      ('adam', transform.scale_by_adam),
      ('adamax', transform.scale_by_adamax),
      ('lion', transform.scale_by_lion),
1 change: 1 addition & 0 deletions optax/contrib/_common_test.py
@@ -101,6 +101,7 @@
    dict(opt_name='rmsprop', opt_kwargs=dict(learning_rate=1.0)),
    dict(opt_name='rmsprop', opt_kwargs=dict(learning_rate=1.0, momentum=0.9)),
    dict(opt_name='adabelief', opt_kwargs=dict(learning_rate=1.0)),
    dict(opt_name='adan', opt_kwargs=dict(learning_rate=1.0)),
    dict(opt_name='radam', opt_kwargs=dict(learning_rate=1.0)),
    dict(opt_name='sm3', opt_kwargs=dict(learning_rate=1.0)),
    dict(opt_name='yogi', opt_kwargs=dict(learning_rate=1.0, b1=0.99)),
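
For completeness, a sketch of composing the same pieces through the public API, mirroring what the `adan` alias chains internally; the hyperparameter values here are illustrative:

import optax

# Adan rescaling, decoupled weight decay, then scaling by the (negated)
# learning rate, as in the `adan` alias above.
tx = optax.chain(
    optax.scale_by_adan(b1=0.98, b2=0.92, b3=0.99, eps=1e-8, eps_root=1e-8),
    optax.add_decayed_weights(1e-4),
    optax.scale_by_learning_rate(1e-3),
)
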
