From 8e8320f7bf8b99a88cd72a0e156540bfde3b67ad Mon Sep 17 00:00:00 2001
From: Harshit Monish <143435143+hmonishN@users.noreply.github.com>
Date: Fri, 8 Mar 2024 09:44:30 -0800
Subject: [PATCH] Add FMHA T5x test (#442)

This adds a JAX T5X FMHA E2E system test that checks for FMHA lowering
support. The test is implemented as follows:

- The FMHA lowering flag is now enabled by default, so the test enables HLO
  dumping and scans the dumped HLO for FMHA forward and backward
  instructions.
- The test is added to the _ci.yaml file, and a nightly workflow file is
  added for it. We will later add this test to performance benchmarking and
  add its HLO to the baseline.
- The decoder sequence length is corrected, since it must be a multiple
  of 64.

The test was previously failing with an error related to
CUDNN_STATUS_BAD_PARAM ([bug](https://nvbugspro.nvidia.com/bug/4409713)).
The fix landed upstream in
[openxla/xla#6872](https://github.com/openxla/xla/pull/6872), which is now
merged, so the test passes.

Workflow run for these changes:
[workflow run link](https://github.com/NVIDIA/JAX-Toolbox/actions/runs/7894631992)

An example local invocation with FMHA enabled is sketched after the diff.

---------

Co-authored-by: Terry Kong
---
 .github/container/test-t5x.sh             | 55 +++++++++++++++-
 .github/workflows/_test_upstream_t5x.yaml | 76 ++++++++++++++++++++---
 2 files changed, 120 insertions(+), 11 deletions(-)

diff --git a/.github/container/test-t5x.sh b/.github/container/test-t5x.sh
index 573834b8d..942e4b2c4 100755
--- a/.github/container/test-t5x.sh
+++ b/.github/container/test-t5x.sh
@@ -20,13 +20,15 @@ usage() {
     echo "  -e, --epochs               Number of epochs to run, defaults to 7."
     echo "  --multiprocess             Enable the multiprocess GPU mode."
     echo "  -o, --output NAME          Name for the output folder, a temporary folder will be created if none specified."
+    echo "  --save-hlo {0, 1}          1 to save the dumped hlo, 0 to remove the hlo dumped folder"
     echo "  --seed INT                 Random seed for deterministim. Defaults to 42."
     echo "  -s, --steps-per-epoch INT  Steps per epoch. Detauls to 100"
+    echo "  --enable-fmha {0, 1}       1 to enable fmha testing, 0 to run test without fmha; default is 0"
     echo "  -h, --help                 Print usage."
     exit $1
 }
 
-args=$(getopt -o a:b:cd:e:ho:s: --long additional-args:,batch-size:,use-contrib-configs,dtype:,enable-te:,epochs:,help,multiprocess,output:,seed:,steps-per-epoch: -- "$@")
+args=$(getopt -o a:b:cd:e:ho:s: --long additional-args:,batch-size:,use-contrib-configs,dtype:,enable-te:,enable-fmha:,epochs:,help,multiprocess,output:,seed:,save-hlo:,steps-per-epoch: -- "$@")
 if [[ $? -ne 0 ]]; then
     exit 1
 fi
@@ -43,6 +45,8 @@ OUTPUT=$(mktemp -d)
 SEED=42
 STEPS_PER_EPOCH=100
 ENABLE_TE=${ENABLE_TE:-0}
+ENABLE_FMHA=${ENABLE_FMHA:-0}
+SAVE_HLO=${SAVE_HLO:-1}
 
 eval set -- "$args"
 while [ : ]; do
@@ -67,6 +71,10 @@ while [ : ]; do
             ENABLE_TE="$2"
             shift 2
             ;;
+        --enable-fmha)
+            ENABLE_FMHA="$2"
+            shift 2
+            ;;
         -e | --epochs)
             EPOCHS="$2"
             shift 2
@@ -82,6 +90,10 @@ while [ : ]; do
             OUTPUT="$2"
             shift 2
             ;;
+        --save-hlo)
+            SAVE_HLO="$2"
+            shift 2
+            ;;
         --seed)
             SEED="$2"
             shift 2
@@ -105,6 +117,20 @@ if [[ $BATCH_SIZE == 0 ]]; then
     usage 1
 fi
 
+# Set hlo dump folder after output folder is set.
+HLO_DIR=${OUTPUT}/hlo
+export BASE_XLA_FLAGS="${BASE_XLA_FLAGS:---xla_dump_hlo_as_text --xla_dump_to=${HLO_DIR}}"
+export XLA_FLAGS="${BASE_XLA_FLAGS} ${XLA_FLAGS:-}"
+echo "HLO will be dumped in ${HLO_DIR} dir."
+
+## Setting the env variables for FMHA
+if [[ "$ENABLE_FMHA" -eq "1" ]]; then
+    echo "Setting XLA FMHA Flags";
+    export BASE_XLA_FLAGS_FMHA="${BASE_XLA_FLAGS_FMHA:---xla_gpu_fused_attention_use_cudnn_rng=true --xla_gpu_enable_cudnn_fmha=true}"
+    export XLA_FLAGS="${BASE_XLA_FLAGS_FMHA} ${XLA_FLAGS:-}"
+fi
+
+echo "XLA FLAGS: $XLA_FLAGS"
 
 ## Set derived variables
 TRAIN_STEPS=$(($EPOCHS * $STEPS_PER_EPOCH))
@@ -114,11 +140,13 @@ print_var BATCH_SIZE
 print_var USE_CONTRIB_CONFIGS
 print_var DTYPE
 print_var ENABLE_TE
+print_var ENABLE_FMHA
 print_var EPOCHS
 print_var OUTPUT
 print_var MULTIPROCESS
 print_var STEPS_PER_EPOCH
 print_var TRAIN_STEPS
+print_var SAVE_HLO
 
 ## Enter T5X source folder
 T5X_DIR=$(dirname `python -c 'import t5x; print(*t5x.__path__)'`)
@@ -178,7 +206,7 @@
 $(
 import dummy_wikipedia
 MIXTURE_OR_TASK_NAME = "dummy_wikipedia"
-TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 114}
+TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 128}
 DROPOUT_RATE = 0.0
 USE_CACHED_TASKS = False
 TRAIN_STEPS = %gin.REQUIRED
@@ -206,3 +234,26 @@ ENABLE_TE=$ENABLE_TE python -m t5x.train \
     $ADDITIONAL_ARGS \
     $([[ $MULTIPROCESS != 0 ]] && echo --multiprocess_gpu)
 echo "Output at ${OUTPUT}"
+
+if [[ "$ENABLE_FMHA" -eq "1" ]]; then
+    ## Check if fmha instructions are present in the HLO dumped file or not.
+    fmha_regex="fmha[-bmm]?[-scale]?[-bias]?[-mask]?[-softmax]?[-dropout]?[-bmm]?[-backward]?*"
+    result=$(grep -irlnE "$fmha_regex" "${HLO_DIR}/"*.txt)
+
+    if [[ $SAVE_HLO -eq 0 ]]; then
+        rm -rf $HLO_DIR
+        echo "Removed dumped HLO directory!"
+    fi
+
+    if [ -z "$result" ]; then
+        echo "E: No FMHA instructions were found in the hlo files!"
+        exit 1
+    else
+        echo -e "Found FMHA instructions in the following HLO files: \n $result"
+    fi
+else
+    if [[ $SAVE_HLO -eq 0 ]]; then
+        rm -rf $HLO_DIR
+        echo "Removed dumped HLO directory!"
+    fi
+fi
diff --git a/.github/workflows/_test_upstream_t5x.yaml b/.github/workflows/_test_upstream_t5x.yaml
index 859290d40..c5211bbde 100644
--- a/.github/workflows/_test_upstream_t5x.yaml
+++ b/.github/workflows/_test_upstream_t5x.yaml
@@ -38,7 +38,24 @@ jobs:
   t5x-multi-gpu:
     strategy:
       matrix:
-        N_GPU: [1, 2, 4, 8]
+        include:
+          - TEST_NAME: "1P1G"
+            N_GPU: 1
+            ADDITIONAL_ARGS: ""
+          - TEST_NAME: "1P2G"
+            N_GPU: 2
+            ADDITIONAL_ARGS: ""
+          - TEST_NAME: "1P4G"
+            N_GPU: 4
+            ADDITIONAL_ARGS: ""
+          - TEST_NAME: "1P8G"
+            N_GPU: 8
+          - TEST_NAME: "1P1G_fmha"
+            N_GPU: 1
+            ADDITIONAL_ARGS: "--enable-fmha 1"
+          - TEST_NAME: "1P2G_fmha"
+            N_GPU: 2
+            ADDITIONAL_ARGS: "--enable-fmha 1"
       fail-fast: false
 
     runs-on: ubuntu-22.04
@@ -70,7 +87,7 @@ jobs:
        shell: bash -x -e {0}
        run: |
          IMAGE="$(echo ${{inputs.T5X_IMAGE}} | sed 's/\//#/')"
-          TEST_CASE_NAME=1P${{ matrix.N_GPU }}G
+          TEST_CASE_NAME=${{ matrix.TEST_NAME }}
          JOB_NAME=${{ inputs.FW_NAME }}-${GITHUB_RUN_ID}-${TEST_CASE_NAME}
          LOG_FILE=/nfs/cluster/${JOB_NAME}.log
          MODEL_PATH=/nfs/cluster/${JOB_NAME}
@@ -114,10 +131,11 @@ jobs:
                --dtype bfloat16 \
                --batch-size ${{ steps.meta.outputs.BATCH_SIZE }} \
                --epochs 7 \
-                --steps-per-epoch 100
+                --steps-per-epoch 100 \
+                ${{ matrix.ADDITIONAL_ARGS }}
          EOF
          )
-
+
          echo "SLURM_JOB_ID=${JOB}" >> $GITHUB_OUTPUT
          . .github/workflows/scripts/wait_for_slurm_job.sh
@@ -174,8 +192,47 @@ jobs:
   t5x-multi-node:
     strategy:
       matrix:
-        N_GPU: [1, 2, 4, 8]
-        N_NODE: [1, 2]
+        include:
+          - TEST_NAME: "1G1N"
+            N_GPU: 1
+            N_NODE: 1
+            ADDITIONAL_ARGS: ""
+          - TEST_NAME: "2G1N"
+            N_GPU: 2
+            N_NODE: 1
+            ADDITIONAL_ARGS: ""
+          - TEST_NAME: "4G1N"
+            N_GPU: 4
+            N_NODE: 1
+            ADDITIONAL_ARGS: ""
+          - TEST_NAME: "8G1N"
+            N_GPU: 8
+            N_NODE: 1
+            ADDITIONAL_ARGS: ""
+          - TEST_NAME: "1G2N"
+            N_GPU: 1
+            N_NODE: 2
+            ADDITIONAL_ARGS: ""
+          - TEST_NAME: "2G2N"
+            N_GPU: 2
+            N_NODE: 2
+            ADDITIONAL_ARGS: ""
+          - TEST_NAME: "4G2N"
+            N_GPU: 4
+            N_NODE: 2
+            ADDITIONAL_ARGS: ""
+          - TEST_NAME: "8G2N"
+            N_GPU: 8
+            N_NODE: 2
+            ADDITIONAL_ARGS: ""
+          - TEST_NAME: "2G2N_fmha"
+            N_GPU: 2
+            N_NODE: 2
+            ADDITIONAL_ARGS: "--enable-fmha 1"
+          - TEST_NAME: "8G2N_fmha"
+            N_GPU: 8
+            N_NODE: 2
+            ADDITIONAL_ARGS: "--enable-fmha 1"
       fail-fast: false
 
    runs-on: ubuntu-22.04
@@ -207,9 +264,9 @@ jobs:
        shell: bash -x -e {0}
        run: |
          IMAGE="$(echo ${{inputs.T5X_IMAGE}} | sed 's/\//#/')"
-          TEST_CASE_NAME=${{ matrix.N_GPU }}G${{ matrix.N_NODE }}N
+          TEST_CASE_NAME=${{ matrix.TEST_NAME }}
          TOTAL_TASKS=$((${{ matrix.N_GPU }} * ${{ matrix.N_NODE }}))
-          JOB_NAME=${{ inputs.FW_NAME }}-${GITHUB_RUN_ID}-${TEST_CASE_NAME}
+          JOB_NAME=${{ inputs.FW_NAME }}-${GITHUB_RUN_ID}-${TEST_CASE_NAME};
          LOG_FILE=/nfs/cluster/${JOB_NAME}.log
          MODEL_PATH=/nfs/cluster/${JOB_NAME}
          BATCH_SIZE=$((${{ inputs.BATCH_SIZE_PER_GPU }} * ${{ matrix.N_GPU }} * ${{ matrix.N_NODE }}))
@@ -254,7 +311,8 @@ jobs:
                --batch-size ${{ steps.meta.outputs.BATCH_SIZE }} \
                --epochs 7 \
                --steps-per-epoch 100 \
-                --multiprocess
+                --multiprocess \
+                ${{ matrix.ADDITIONAL_ARGS }}
          EOF
          )
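
Note: as a rough guide for reproducing the FMHA check outside CI, a local
invocation of the updated script might look like the sketch below. The flag
names (`--enable-fmha`, `--save-hlo`, `--batch-size`, `--epochs`,
`--steps-per-epoch`, `--output`) come from the option parsing added in this
patch; the output path and the batch/epoch/step values are illustrative
assumptions, and the script is expected to run in an environment with T5X and
an FMHA-capable XLA build (e.g. a JAX-Toolbox T5X container).

```bash
# Illustrative local run of the E2E test with the FMHA check enabled.
# Paths and sizes are example values only.
bash .github/container/test-t5x.sh \
    --output /tmp/t5x-fmha-test \
    --batch-size 8 \
    --epochs 2 \
    --steps-per-epoch 10 \
    --enable-fmha 1 \
    --save-hlo 1

# With --enable-fmha 1 the script prepends --xla_gpu_enable_cudnn_fmha=true and
# --xla_gpu_fused_attention_use_cudnn_rng=true to XLA_FLAGS, then greps the HLO
# dumped under <output>/hlo for fmha* instructions and exits non-zero if none
# are found. --save-hlo 1 keeps the dumped HLO directory after the check.
```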