Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: changed the assertions to make sure the num_updates is a multiple of num_evaluations #1083

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from

Conversation

Louay-Ben-nessir
Copy link
Contributor

What?

Changed the assertions to make sure the num_updates is a multiple of num_evaluations.

Why?

Only num_evaluation * num_updates_per_eval are ran while training which can lead to some missed updates if the num_updates is not a multiple of num_evaluations.

How?

changed the assertions.

OmaymaMahjoub
OmaymaMahjoub previously approved these changes Jul 2, 2024
Copy link
Contributor

@OmaymaMahjoub OmaymaMahjoub left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @Louay-Ben-nessir, can you as well check if the other systems suffer from the same problem or not 🙏

@sash-a
Copy link
Contributor

sash-a commented Jul 3, 2024

A suggestion here that would remove the need for the assert and make mava easier to configure is to change the variable to evaluation frequency and then store num_evaluations in the config as config.arch.num_evals = config.system.num_updates // config.arch.eval_freq so that we still have that info when logging

@Louay-Ben-nessir
Copy link
Contributor Author

Thanks @Louay-Ben-nessir, can you as well check if the other systems suffer from the same problem or not 🙏

I think this issue is exclusive to ppo systems

A suggestion here that would remove the need for the assert and make mava easier to configure is to change the variable to evaluation frequency and then store num_evaluations in the config as config.arch.num_evals = config.system.num_updates // config.arch.eval_freq so that we still have that info when logging

This a huge improvement over the current implementation but it's still not exact in some cases. losing some updates is worth it for the flexibility tho so I'll change it.

@sash-a
Copy link
Contributor

sash-a commented Jul 3, 2024

Ah right we could lose some updates. @RuanJohn has had an issue with this in the past, so maybe a jnp.ceil is needed here, just double check with him

Copy link
Collaborator

@RuanJohn RuanJohn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The hard assert here is a bit too strict in my opinion.
Something we could do is to make a warning that says the number of timesteps someone is assuming their experiment will run for might not happen and then give the total number of timesteps that will run.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants