Add an option to test-pax.sh to enable XLA cuDNN flash attention #1045

Merged · 1 commit · Sep 18, 2024
.github/container/test-pax.sh: 49 changes (30 additions, 19 deletions)
@@ -15,7 +15,8 @@ usage() {
echo " -a, --additional-args Additional fiddle args to pass to paxml/main.py"
echo " -b, --batch-per-gpu Batch size per GPU, defaults to 4."
echo " --dtype Batch size, defaults to bfloat16."
echo " --enable-te If set, will run with env var ENABLE_TE=1."
echo " --enable-te If set, will run with env var ENABLE_TE=1."
echo " --enable-cudnn-fa If set, will use cudnn fa."
echo " --enable-dropout If set, will set DROPOUT_PROB to 0.1."
echo " --disable-fused-attn Whether disable TE fused attention."
echo " --model-type One of 126M, 5B, LLaMA70BProxy. Defaults to 126M"
@@ -26,13 +27,13 @@ usage() {
echo " --data-parallel Data parallelism to use. Defaults to 1."
echo " --fsdp Fully-sharded data parallelism to use. Defaults to 1."
echo " --tensor-parallel Tensor parallelism to use. Defaults to 1."
echo " --pipeline-parallel Pipeline parallelism to use. Defaults to 1 for no pipelining."
echo " --pipeline-parallel Pipeline parallelism to use. Defaults to 1 for no pipelining."
echo " -n, --nodes Number of nodes."
echo " -h, --help Print usage."
exit $1
}

-args=$(getopt -o a:b:s:o:n:h --long additional-args:,batch-per-gpu:,dtype:,enable-te,enable-dropout,disable-fused-attn,model-type:,evaluate,steps:,help,multiprocess,output:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@")
+args=$(getopt -o a:b:s:o:n:h --long additional-args:,batch-per-gpu:,dtype:,enable-te,enable-cudnn-fa,enable-dropout,disable-fused-attn,model-type:,evaluate,steps:,help,multiprocess,output:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@")
if [[ $? -ne 0 ]]; then
exit $1
fi
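
For context, a hypothetical invocation exercising the new flag (every value besides --enable-cudnn-fa is an illustrative placeholder, not taken from this PR):

  # Run the synthetic 126M config with XLA cuDNN flash attention enabled.
  # --steps and --output values are placeholders.
  bash .github/container/test-pax.sh \
      --model-type 126M \
      --enable-cudnn-fa \
      --steps 100 \
      --output /tmp/pax_cudnn_fa_run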
@@ -50,6 +51,7 @@ TP=1
PP=1
NODES=1
ENABLE_TE=0
+ENABLE_CUDNN_FA=0
MODEL_TYPE=126M
NVTE_FUSED_ATTN=1
DROPOUT=0
@@ -75,6 +77,10 @@ while [ : ]; do
ENABLE_TE=1
shift 1
;;
+--enable-cudnn-fa)
+ENABLE_CUDNN_FA=1
+shift 1
+;;
--enable-dropout)
DROPOUT='0.1'
shift 1
@@ -128,7 +134,7 @@ while [ : ]; do
;;
--)
shift;
break
;;
*)
echo "UNKNOWN OPTION $1"
@@ -149,6 +155,7 @@ print_var NGPUS
print_var OUTPUT
print_var MULTIPROCESS
print_var ENABLE_TE
+print_var ENABLE_CUDNN_FA
print_var NVTE_FUSED_ATTN
print_var EVALUATE
print_var DROPOUT
@@ -196,10 +203,10 @@ if dcn_factor > 1:
if dp % dcn_factor == 0:
dcn_dp = dcn_factor
dp = int(dp / dcn_factor)
elif fsdp % dcn_factor == 0:
dcn_fsdp = dcn_factor
fsdp = int(fsdp / dcn_factor)
elif pp % dcn_factor == 0:
dcn_pp = dcn_factor
pp = int(pp / dcn_factor)
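
A worked example of the factoring above, assuming dcn_factor reflects the node count (how dcn_factor is computed sits outside this diff): on 2 nodes with dp=1, fsdp=16, pp=1, the dp branch fails (1 % 2 != 0), so the fsdp branch fires, yielding dcn_fsdp=2 and fsdp=8; the ICI mesh then shards within each node and the DCN mesh shards across nodes. An illustrative launch for that sharding (flag values are placeholders):

  bash .github/container/test-pax.sh \
      --nodes 2 \
      --fsdp 16 \
      --multiprocess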

@@ -209,12 +216,12 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam):
USE_REPEATED_LAYER = False
ICI_MESH_SHAPE = [64,1,1]
MAX_STEPS = 600000

MAX_SEQ_LEN = 2048
VOCAB_SIZE = 50304
PACKED_INPUT = True
PERCORE_BATCH_SIZE = 4

NUM_LAYERS = 12
NUM_HEADS = 12
MODEL_DIMS = 768
@@ -223,14 +230,14 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam):

TRAINABLE_POSITION_EMB = True
TRAINABLE_PE_MAX_SEQ_LEN = MAX_SEQ_LEN

USE_BIAS = True
LAYERNORM_EPSILON = 1e-5
ATTEN_LOGIT_CAP = -1.0
INIT_STD = 0.023
SOFTMAX_INIT_STD = 0.023
ACTIVATION_CLS = layers.GELU

## optimizer-related
ADAM_BETA1 = 0.9
ADAM_BETA2 = 0.95
@@ -255,15 +262,15 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam):
## disable eval to avoid including eval
## in steps/sec calculation
EVAL_INTERVAL_STEPS = 100000

def task(self):
task_p = super().task()
task_p = configure_gpt3_task(self, task_p)

task_p.train.num_train_steps = self.MAX_STEPS

model_p = task_p.model

### compute layernorm reductions in fp32. Needed for stable training on GPUs
stacked_p = model_p.lm_tpl.stacked_transformer_tpl
if stacked_p.cls == layers.PipelinedTransformer:
@@ -274,13 +281,13 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam):
transformer_layer_p.ln_tpl.reductions_in_fp32 = True
transformer_layer_p.tr_fflayer_tpl.ln_tpl.reductions_in_fp32 = True
task_p.model.lm_tpl.final_ln_tpl.reductions_in_fp32 = True

model_p.params_init = WeightInit.Gaussian(self.INIT_STD)
softmax_init = WeightInit.Gaussian(self.SOFTMAX_INIT_STD)
model_p.lm_tpl.softmax_tpl.params_init = softmax_init

model_p.apply_eval_sample_weights = True

## set input, residual, attention dropout to DROPOUT_PROB, remaining dropout to 0
stacked_p.dropout_prob = 0.0
stacked_p.input_dropout_prob = self.DROPOUT_PROB
@@ -316,14 +323,14 @@ class LLaMA70BSyntheticSmall(BaseLLaMA, SyntheticDataset):
if pp > 1:
@experiment_registry.register
class Synthetic126MCI(GPT126MPP, SyntheticDataset):

ICI_MESH_SHAPE = [pp, dp, fsdp, tp]
DCN_MESH_SHAPE = [dcn_pp, dcn_dp, dcn_fsdp, 1]
MICROBATCH_SIZE = 2
NUM_STAGES = pp
PERCORE_BATCH_SIZE = percore_batch_size
FPROP_DTYPE = dtype

def task(self):
task_p = super().task()
task_p.train.always_use_train_for_model_init=False
@@ -333,7 +340,7 @@ if pp > 1:
else:
@experiment_registry.register
class Synthetic126MCI(Synthetic126M):

ICI_MESH_SHAPE = [dp, fsdp, tp]
DCN_MESH_SHAPE = [dcn_dp, dcn_fsdp, 1]
PERCORE_BATCH_SIZE = percore_batch_size
@@ -343,7 +350,7 @@ else:

## disable eval
EVAL_INTERVAL_STEPS = 100000

def task(self):
task_p = super().task()

@@ -374,6 +381,10 @@ export ENABLE_TE=$ENABLE_TE
export NVTE_FUSED_ATTN=$NVTE_FUSED_ATTN
export VOCAB_PATH=${VOCAB_PATH:-gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model}

+if [[ ${ENABLE_CUDNN_FA} -ne 0 ]]; then
+  ADDITIONAL_ARGS="${ADDITIONAL_ARGS} --fdl.USE_CUDNN_FLASH_ATTENTION=True"
+fi

if [[ ${MODEL_TYPE} == "126M" ]]; then
CONFIG=ci_configs.Synthetic126MCI
elif [[ ${MODEL_TYPE} == "5B" ]]; then
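
End to end, --enable-cudnn-fa only appends a fiddle override to ADDITIONAL_ARGS; a minimal sketch of how that plausibly reaches paxml follows (the actual launch line sits outside this diff, so the exact command shape shown here is an assumption):

  # Assumed final command shape; CONFIG and OUTPUT are set earlier in the script.
  python -m paxml.main \
      --fdl_config=${CONFIG} \
      --job_log_dir=${OUTPUT} \
      ${ADDITIONAL_ARGS}  # carries --fdl.USE_CUDNN_FLASH_ATTENTION=True when --enable-cudnn-fa is set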