JAX unit test concurrency should be set by CPU count, not GPU count #228

Open
yhtang opened this issue Sep 13, 2023 · 1 comment

Comments

@yhtang
Collaborator

yhtang commented Sep 13, 2023

Pending confirmation from other developers, it seems that the //tests:gpu_tests tests are, counterintuitively, bottlenecked by CPU concurrency and performance. Hence, JAX unit tests may run fastest when the Bazel job count is set from the number of available CPU cores rather than the number of available GPUs.
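A minimal sketch of the CPU-based sizing proposed above, in Python. It assumes the job count is handed to Bazel through `--local_test_jobs` (the flag that limits how many tests Bazel runs concurrently on the local machine); the helper names `cpu_based_test_jobs` and `bazel_test_command` are hypothetical and not part of JAX's build scripts.

```python
import os
from typing import List, Optional


def cpu_based_test_jobs(cpu_count: Optional[int] = None) -> int:
    """Hypothetical helper: size the Bazel test-job count by CPU cores."""
    if cpu_count is None:
        if hasattr(os, "sched_getaffinity"):
            # Linux: count only the cores this process is allowed to run on.
            cpu_count = len(os.sched_getaffinity(0))
        else:
            # os.cpu_count() may return None if the count is unknown.
            cpu_count = os.cpu_count()
    # Always run at least one test job.
    return max(1, cpu_count or 1)


def bazel_test_command(target: str, jobs: int) -> List[str]:
    """Build (but do not run) a `bazel test` invocation with a job cap."""
    return ["bazel", "test", f"--local_test_jobs={jobs}", target]


if __name__ == "__main__":
    cmd = bazel_test_command("//tests:gpu_tests", cpu_based_test_jobs())
    print(" ".join(cmd))
```

The command is only printed here; running it (e.g. via `subprocess.run(cmd)`) is left to the caller.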

@nouiz
Collaborator

nouiz commented Sep 13, 2023

It is more complicated than this. Sizing the job count by CPU cores would create too many GPU contexts, which slows down some tests and crashes others.
We need to parallelize per GPU, but we can oversubscribe each GPU slightly if it has enough memory, e.g. 2 or 4 tests per GPU. Also, the more GPUs there are in the box, the fewer jobs per GPU we can run, because the driver limits the total number of contexts. Newer drivers allow more contexts to be created in total.

So mostly, we need a mix of both.
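One hypothetical way to express this "mix of both" is sketched below in Python. It sizes jobs per GPU from available memory, reduces oversubscription as the GPU count grows, caps the total at an assumed driver context limit, and finally caps it at the CPU count. The function name, its parameters, and the tier thresholds are illustrative assumptions, not values measured or proposed in this thread.

```python
import os
from typing import Optional


def mixed_test_jobs(
    num_gpus: int,
    gpu_mem_gib: float,
    mem_per_test_gib: float,
    max_total_contexts: int,
    cpu_count: Optional[int] = None,
) -> int:
    """Hypothetical heuristic combining GPU, driver, and CPU limits.

    All thresholds are illustrative assumptions, not measured values.
    """
    if num_gpus <= 0:
        raise ValueError("expected at least one GPU")
    if mem_per_test_gib <= 0:
        raise ValueError("mem_per_test_gib must be positive")
    if cpu_count is None:
        cpu_count = os.cpu_count() or 1

    # 1. Memory: how many tests fit on one GPU at the same time.
    by_memory = max(1, int(gpu_mem_gib // mem_per_test_gib))

    # 2. Oversubscription tier: fewer tests per GPU on boxes with more GPUs
    #    (assumed values, loosely following the "2 or 4 per GPU" remark).
    if num_gpus <= 2:
        tier = 4
    elif num_gpus <= 4:
        tier = 2
    else:
        tier = 1
    per_gpu = min(by_memory, tier)

    # 3. Driver: stay under the assumed total number of contexts.
    jobs = min(per_gpu * num_gpus, max_total_contexts)

    # 4. CPU: keep the original observation that tests are also CPU-bound.
    return max(1, min(jobs, cpu_count))
```

For example, `mixed_test_jobs(num_gpus=2, gpu_mem_gib=80, mem_per_test_gib=16, max_total_contexts=64, cpu_count=32)` allows 4 tests per GPU and returns 8. The result only sizes concurrency (e.g. for `--local_test_jobs`); it does not assign individual tests to particular GPUs.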
