
Add example of reading learning rate from optimizer state #363

Open

atgctg wants to merge 2 commits into main

Conversation

@atgctg (Contributor) commented Jun 18, 2022

Closes #312.

I extended the Optax 101 notebook to also include an example of extracting the learning rate, based on #206.


This example uses the inject_hyperparams wrapper.
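For reference, a minimal sketch of that pattern (the schedule, optimizer, and parameter values here are illustrative, not the exact notebook cell):

import jax.numpy as jnp
import optax

# Any schedule works here; exponential decay is just an example.
schedule = optax.exponential_decay(
    init_value=1e-2, transition_steps=100, decay_rate=0.99)

# inject_hyperparams lifts the hyperparameters (here, the learning rate)
# into the optimizer state instead of baking them into the update rule.
optimizer = optax.inject_hyperparams(optax.adamw)(learning_rate=schedule)

params = {'w': jnp.zeros(3)}
opt_state = optimizer.init(params)

# The current learning rate can now be read directly from the state:
lr = opt_state.hyperparams['learning_rate']

Unlike reaching into inner_state, the hyperparams dict is part of the wrapper's public state, so reading the learning rate does not depend on the internal ordering of the chained transformations.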

Some additional changes include:

  • better formatting of the print outputs
  • extracting the step() function so it can be reused later on
  • making param generation deterministic so the outputs stay the same

I also tried passing the count from inner_state to the scheduler, but found that approach quite hacky because it requires indexing into inner_state[0], so I have not included the following way of doing it:

def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)

  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    params, opt_state, loss_value = step(params, opt_state, batch, labels)
    if i % 100 == 0:
      # Indexing into inner_state[0] depends on the internal layout of the
      # chained transformation, which is not obvious from the code.
      count = opt_state.inner_state[0].count
      lr = schedule(count)
      print(f'Step {i:3}, Loss: {loss_value:.3f}, Learning rate: {lr:.9f}')

  return params

params = fit(initial_params, optimizer)

Please let me know what you think. Happy to make any changes!

google-cla bot commented Jun 18, 2022

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up-to-date status, view the checks section at the bottom of the pull request.

@mtthss (Collaborator) commented Oct 10, 2023

Huge apologies for the long delay here. If you can sync, we can get this submitted?

@atgctg (Contributor, Author) commented Oct 10, 2023

No worries, synced!

@fabianp marked this pull request as ready for review on December 5, 2023

@fabianp (Member) commented Dec 5, 2023

@mtthss: this looks good to me. Do we have your green light to merge?
