diff --git a/docs/api/optimizers.rst b/docs/api/optimizers.rst
index adab7372..2565a19a 100644
--- a/docs/api/optimizers.rst
+++ b/docs/api/optimizers.rst
@@ -6,6 +6,7 @@ Optimizers
 .. autosummary::
     adabelief
     adadelta
+    adan
     adafactor
     adagrad
     adam
@@ -40,6 +41,10 @@ AdaDelta
 ~~~~~~~~~
 .. autofunction:: adadelta
 
+Adan
+~~~~
+.. autofunction:: adan
+
 AdaGrad
 ~~~~~~~
 .. autofunction:: adagrad
diff --git a/docs/api/transformations.rst b/docs/api/transformations.rst
index 70733ea2..9d75ea21 100644
--- a/docs/api/transformations.rst
+++ b/docs/api/transformations.rst
@@ -41,6 +41,8 @@ Transformations
     ScaleState
     scale_by_adadelta
     ScaleByAdaDeltaState
+    scale_by_adan
+    ScaleByAdanState
     scale_by_adam
     scale_by_adamax
     ScaleByAdamState
@@ -172,6 +174,9 @@ Transformations and states
 .. autofunction:: scale_by_adadelta
 .. autoclass:: ScaleByAdaDeltaState
 
+.. autofunction:: scale_by_adan
+.. autoclass:: ScaleByAdanState
+
 .. autofunction:: scale_by_adam
 .. autofunction:: scale_by_adamax
 .. autoclass:: ScaleByAdamState
diff --git a/optax/__init__.py b/optax/__init__.py
index 52197224..d7467218 100644
--- a/optax/__init__.py
+++ b/optax/__init__.py
@@ -27,6 +27,7 @@
 from optax import tree_utils
 from optax._src.alias import adabelief
 from optax._src.alias import adadelta
+from optax._src.alias import adan
 from optax._src.alias import adafactor
 from optax._src.alias import adagrad
 from optax._src.alias import adam
@@ -115,6 +116,7 @@
 from optax._src.transform import normalize_by_update_norm
 from optax._src.transform import scale
 from optax._src.transform import scale_by_adadelta
+from optax._src.transform import scale_by_adan
 from optax._src.transform import scale_by_adam
 from optax._src.transform import scale_by_adamax
 from optax._src.transform import scale_by_amsgrad
@@ -139,6 +141,7 @@
 from optax._src.transform import scale_by_trust_ratio
 from optax._src.transform import scale_by_yogi
 from optax._src.transform import ScaleByAdaDeltaState
+from optax._src.transform import ScaleByAdanState
 from optax._src.transform import ScaleByAdamState
 from optax._src.transform import ScaleByAmsgradState
 from optax._src.transform import ScaleByBeliefState
@@ -288,6 +291,7 @@
 __all__ = (
     "adabelief",
     "adadelta",
+    "adan",
     "adafactor",
     "adagrad",
     "adam",
@@ -394,6 +398,7 @@
     "safe_root_mean_squares",
     "ScalarOrSchedule",
     "scale_by_adadelta",
+    "scale_by_adan",
     "scale_by_adam",
     "scale_by_adamax",
     "scale_by_amsgrad",
@@ -420,6 +425,7 @@
     "scale_gradient",
     "scale",
     "ScaleByAdaDeltaState",
+    "ScaleByAdanState",
     "ScaleByAdamState",
     "ScaleByAmsgradState",
     "ScaleByBacktrackingLinesearchState",
diff --git a/optax/_src/alias.py b/optax/_src/alias.py
index 072766b8..a54faeb0 100644
--- a/optax/_src/alias.py
+++ b/optax/_src/alias.py
@@ -177,6 +177,80 @@ def adadelta(
   )
 
 
+def adan(
+    learning_rate: base.ScalarOrSchedule,
+    b1: float = 0.98,
+    b2: float = 0.92,
+    b3: float = 0.99,
+    eps: float = 1e-8,
+    eps_root: float = 1e-8,
+    weight_decay: float = 0.0,
+    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
+) -> base.GradientTransformation:
+  """The ADAptive Nesterov momentum algorithm (Adan).
+
+  Adan is an Adam variant that uses Nesterov Momentum Estimation (NME) to
+  estimate the first- and second-order moments of the gradient, which
+  accelerates the convergence of adaptive gradient methods.
+
+  References:
+    Xie et al., 2022: https://arxiv.org/abs/2208.06677
+
+  Args:
+    learning_rate: A global scaling factor, either fixed or a schedule.
+    b1: Decay rate for the exponentially weighted average of gradients.
+    b2: Decay rate for the exponentially weighted average of gradient
+      differences.
+    b3: Decay rate for the exponentially weighted average of the squared term.
+    eps: A small constant applied to the denominator outside the square root
+      (as in the Adam paper) to avoid dividing by zero when rescaling.
+    eps_root: A small constant applied to the denominator inside the square
+      root (as in RMSProp) to avoid dividing by zero when rescaling.
+    weight_decay: Strength of the weight decay regularization.
+    mask: A tree with the same structure as (or a prefix of) the params PyTree,
+      or a Callable that returns such a pytree given the params/updates.
+      The leaves should be booleans, `True` for leaves/subtrees you want to
+      apply the weight decay to, and `False` for those you want to skip. Note
+      that the Adan gradient transformations are applied to all parameters.
+
+  Returns:
+    The corresponding `GradientTransformation`.
+
+  Examples:
+    >>> import optax
+    >>> import jax
+    >>> import jax.numpy as jnp
+    >>> f = lambda x: x @ x  # simple quadratic function
+    >>> solver = optax.adan(learning_rate=1e-1)
+    >>> params = jnp.array([1., 2., 3.])
+    >>> print('Objective function: ', f(params))
+    Objective function:  14.0
+    >>> opt_state = solver.init(params)
+    >>> for _ in range(5):
+    ...  grad = jax.grad(f)(params)
+    ...  updates, opt_state = solver.update(grad, opt_state, params)
+    ...  params = optax.apply_updates(params, updates)
+    ...  print('Objective function: {:.2E}'.format(f(params)))
+    Objective function: 1.28E+01
+    Objective function: 1.17E+01
+    Objective function: 1.07E+01
+    Objective function: 9.69E+00
+    Objective function: 8.77E+00
+
+  """
+  return combine.chain(
+      transform.scale_by_adan(
+          b1=b1,
+          b2=b2,
+          b3=b3,
+          eps=eps,
+          eps_root=eps_root,
+      ),
+      transform.add_decayed_weights(weight_decay, mask),
+      transform.scale_by_learning_rate(learning_rate),
+  )
+
+
 def adafactor(
     learning_rate: Optional[base.ScalarOrSchedule] = None,
     min_dim_size_to_factor: int = 128,
diff --git a/optax/_src/alias_test.py b/optax/_src/alias_test.py
index c34744bb..4496deb7 100644
--- a/optax/_src/alias_test.py
+++ b/optax/_src/alias_test.py
@@ -48,6 +48,7 @@
 _OPTIMIZERS_UNDER_TEST = (
     dict(opt_name='sgd', opt_kwargs=dict(learning_rate=1e-3, momentum=0.9)),
     dict(opt_name='adadelta', opt_kwargs=dict(learning_rate=0.1)),
+    dict(opt_name='adan', opt_kwargs=dict(learning_rate=1e-1)),
     dict(opt_name='adafactor', opt_kwargs=dict(learning_rate=5e-3)),
     dict(opt_name='adagrad', opt_kwargs=dict(learning_rate=1.0)),
     dict(opt_name='adam', opt_kwargs=dict(learning_rate=1e-1)),
@@ -131,6 +132,7 @@ def test_optimization(self, opt_name, opt_kwargs, target, dtype):
         'lion',
         'rprop',
         'adadelta',
+        'adan',
         'polyak_sgd',
         'sign_sgd',
     ) and jnp.iscomplexobj(dtype):
diff --git a/optax/_src/transform.py b/optax/_src/transform.py
index 8a239176..32115335 100644
--- a/optax/_src/transform.py
+++ b/optax/_src/transform.py
@@ -594,6 +594,92 @@
   return base.GradientTransformation(init_fn, update_fn)
 
 
+class ScaleByAdanState(NamedTuple):
+  count: chex.Array  # shape=(), dtype=jnp.int32.
+  mu: base.Updates  # First moment (EWA of gradients).
+  nu: base.Updates  # Second moment (EWA of squared corrected gradients).
+  delta: base.Updates  # EWA of gradient differences.
+  grad_tm1: base.Updates  # Gradient from the previous step.
+
+
+def scale_by_adan(
+    b1: float = 0.98,
+    b2: float = 0.92,
+    b3: float = 0.99,
+    eps: float = 1e-8,
+    eps_root: float = 0.0,
+) -> base.GradientTransformation:
+  """Rescale updates according to the Adan algorithm.
+
+  References:
+    [Xie et al., 2022](https://arxiv.org/abs/2208.06677)
+
+  Args:
+    b1: Decay rate for the exponentially weighted average of gradients.
+    b2: Decay rate for the exponentially weighted average of gradient
+      differences.
+    b3: Decay rate for the exponentially weighted average of the squared term.
+    eps: Term added to the denominator to improve numerical stability.
+    eps_root: Term added to the denominator inside the square-root to improve
+      numerical stability when backpropagating gradients through the rescaling.
+
+  Returns:
+    An (init_fn, update_fn) tuple.
+  """
+  def init_fn(params):
+    mu = otu.tree_zeros_like(params)  # 1st moment
+    nu = otu.tree_zeros_like(params)  # 2nd moment
+    delta = otu.tree_zeros_like(params)  # EWA of difference of gradients
+    grad_tm1 = otu.tree_zeros_like(params)  # previous gradient
+    return ScaleByAdanState(
+        count=jnp.zeros((), jnp.int32),
+        mu=mu,
+        nu=nu,
+        delta=delta,
+        grad_tm1=grad_tm1,
+    )
+
+  def update_fn(updates, state, params=None):
+    del params
+
+    # The gradient difference is defined to be zero at the first step.
+    diff = otu.tree_where(
+        state.count != 0,
+        otu.tree_sub(updates, state.grad_tm1),
+        otu.tree_zeros_like(updates)
+    )
+
+    grad_prime = otu.tree_add_scalar_mul(updates, b2, diff)
+
+    mu = otu.tree_update_moment(updates, state.mu, b1, 1)
+    delta = otu.tree_update_moment(diff, state.delta, b2, 1)
+    nu = otu.tree_update_moment_per_elem_norm(grad_prime, state.nu, b3, 2)
+
+    count_inc = numerics.safe_int32_increment(state.count)
+    mu_hat = otu.tree_bias_correction(mu, b1, count_inc)
+    delta_hat = otu.tree_bias_correction(delta, b2, count_inc)
+    nu_hat = otu.tree_bias_correction(nu, b3, count_inc)
+
+    new_updates = jax.tree_util.tree_map(
+        lambda m, d, n: (m + b2 * d) / (jnp.sqrt(n + eps_root) + eps),
+        mu_hat,
+        delta_hat,
+        nu_hat,
+    )
+
+    new_state = ScaleByAdanState(
+        count=count_inc,
+        mu=mu,
+        nu=nu,
+        delta=delta,
+        grad_tm1=updates,
+    )
+
+    return new_updates, new_state
+
+  return base.GradientTransformation(init_fn, update_fn)
+
+
 class ScaleByBeliefState(NamedTuple):
   """State for the rescaling by AdaBelief algorithm."""
   count: chex.Array  # shape=(), dtype=jnp.int32.
diff --git a/optax/_src/transform_test.py b/optax/_src/transform_test.py
index d8537df4..3e90eeae 100644
--- a/optax/_src/transform_test.py
+++ b/optax/_src/transform_test.py
@@ -43,6 +43,7 @@ def setUp(self):
   @chex.all_variants
   @parameterized.named_parameters([
       ('adadelta', transform.scale_by_adadelta),
+      ('adan', transform.scale_by_adan),
      ('adam', transform.scale_by_adam),
       ('adamax', transform.scale_by_adamax),
       ('lion', transform.scale_by_lion),
diff --git a/optax/contrib/_common_test.py b/optax/contrib/_common_test.py
index f044d7a1..d3ea89a0 100644
--- a/optax/contrib/_common_test.py
+++ b/optax/contrib/_common_test.py
@@ -101,6 +101,7 @@
     dict(opt_name='rmsprop', opt_kwargs=dict(learning_rate=1.0)),
     dict(opt_name='rmsprop', opt_kwargs=dict(learning_rate=1.0, momentum=0.9)),
     dict(opt_name='adabelief', opt_kwargs=dict(learning_rate=1.0)),
+    dict(opt_name='adan', opt_kwargs=dict(learning_rate=1.0)),
     dict(opt_name='radam', opt_kwargs=dict(learning_rate=1.0)),
     dict(opt_name='sm3', opt_kwargs=dict(learning_rate=1.0)),
     dict(opt_name='yogi', opt_kwargs=dict(learning_rate=1.0, b1=0.99)),
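
Usage sketch (not part of the patch): a minimal, hypothetical example of how the `optax.adan` alias added above could be driven once this change lands. The least-squares setup below (`loss_fn`, `true_w`, the 1e-1 learning rate and 1e-4 weight decay) is illustrative only and is not taken from the patch or its tests; the same behaviour can also be obtained by chaining `optax.scale_by_adan` with `optax.add_decayed_weights` and `optax.scale_by_learning_rate` by hand, which is exactly what the alias in `alias.py` does.

import jax
import jax.numpy as jnp
import optax  # assumes a version of optax that includes this patch


def loss_fn(params, x, y):
  # Simple least-squares objective, used only to exercise the optimizer.
  return jnp.mean((x @ params - y) ** 2)


key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 4))
true_w = jnp.array([1.0, -2.0, 0.5, 3.0])
y = x @ true_w

params = jnp.zeros(4)
# Decoupled weight decay is applied via `add_decayed_weights` inside the
# alias, mirroring the AdamW-style chain added to `alias.py` above.
solver = optax.adan(learning_rate=1e-1, weight_decay=1e-4)
opt_state = solver.init(params)


@jax.jit
def step(params, opt_state):
  loss, grad = jax.value_and_grad(loss_fn)(params, x, y)
  updates, opt_state = solver.update(grad, opt_state, params)
  params = optax.apply_updates(params, updates)
  return params, opt_state, loss


for _ in range(200):
  params, opt_state, loss = step(params, opt_state)
print(loss)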