Skip to content

Commit

Permalink
Collect recursively and filter GPU tests using jax_test_gpu tag (#1091
Browse files Browse the repository at this point in the history
)
  • Loading branch information
andportnoy authored Oct 21, 2024
1 parent 0832fac commit cfc3f74
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions .github/container/test-jax.sh
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,13 @@ for t in $*; do
BAZEL_TARGET="${BAZEL_TARGET} $t"
done

TEST_TAG_FILTER_ARRAY=()
TEST_TAG_FILTER_ARRAY+=('-multiaccelerator')

COMMON_FLAGS=$(cat << EOF
--@local_config_cuda//:enable_cuda
--cache_test_results=${CACHE_TEST_RESULTS}
--test_timeout=600
--test_tag_filters=-multiaccelerator
--test_env=JAX_SKIP_SLOW_TESTS=1
--test_env=JAX_ACCELERATOR_COUNT=${NGPUS}
--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform
Expand All @@ -138,7 +140,11 @@ case "${BATTERY}" in
JOBS_PER_GPU=8
JOBS=$((NGPUS * JOBS_PER_GPU))
EXTRA_FLAGS="--local_test_jobs=${JOBS} --test_env=JAX_TESTS_PER_ACCELERATOR=${JOBS_PER_GPU} --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow"
BAZEL_TARGET="${BAZEL_TARGET} //tests:gpu_tests"
# collect from all tests subdirectories recursively,
# use jax_test_gpu tag generated by jax_multiplatform_test rule:
# https://github.com/jax-ml/jax/blob/d36afe4f7fe01fe5db16069d796600090db5a3ce/jaxlib/jax.bzl#L265
TEST_TAG_FILTER_ARRAY+=('jax_test_gpu')
BAZEL_TARGET="${BAZEL_TARGET} //tests/..."
;;
backend-independent)
JOBS_PER_GPU=4
Expand All @@ -157,6 +163,8 @@ case "${BATTERY}" in
;;
esac

TEST_TAG_FILTERS=$(IFS=, ; echo "--test_tag_filters=${TEST_TAG_FILTER_ARRAY[*]}")

print_var NCPUS
print_var NGPUS
print_var BATTERY
Expand All @@ -165,6 +173,7 @@ print_var JOBS_PER_GPU
print_var JOBS
print_var BUILD_JAXLIB
print_var BAZEL_TARGET
print_var TEST_TAG_FILTERS
print_var COMMON_FLAGS
print_var EXTRA_FLAGS

Expand All @@ -182,4 +191,4 @@ pip install matplotlib

cd `jax_source_dir`
python build/build.py --configure_only
bazel test ${BAZEL_TARGET} ${COMMON_FLAGS} ${EXTRA_FLAGS}
bazel test ${BAZEL_TARGET} ${TEST_TAG_FILTERS} ${COMMON_FLAGS} ${EXTRA_FLAGS}

0 comments on commit cfc3f74

Please sign in to comment.