diff --git a/.github/container/Dockerfile.base b/.github/container/Dockerfile.base
index 023576cb5..9f0851897 100644
--- a/.github/container/Dockerfile.base
+++ b/.github/container/Dockerfile.base
@@ -1,8 +1,8 @@
 # syntax=docker/dockerfile:1-labs
-ARG BASE_IMAGE=nvidia/cuda:12.5.0-devel-ubuntu22.04
+ARG BASE_IMAGE=nvidia/cuda:12.6.1-devel-ubuntu22.04
 ARG GIT_USER_NAME="JAX Toolbox"
 ARG GIT_USER_EMAIL=jax@nvidia.com
-ARG CLANG_VERSION=17
+ARG CLANG_VERSION=18
 
 ###############################################################################
 ## Obtain GCP's NCCL TCPx plugin
diff --git a/.github/container/Dockerfile.jax b/.github/container/Dockerfile.jax
index c85bee347..726656a7a 100644
--- a/.github/container/Dockerfile.jax
+++ b/.github/container/Dockerfile.jax
@@ -62,6 +62,7 @@ pip install ninja && rm -rf ~/.cache/pip
 # TransformerEngine now needs JAX at build time
 git-clone.sh ${URLREF_TRANSFORMER_ENGINE} ${SRC_PATH_TRANSFORMER_ENGINE}
 pushd ${SRC_PATH_TRANSFORMER_ENGINE}
+export NVTE_BUILD_THREADS_PER_JOB=8
 python setup.py bdist_wheel && rm -rf build
 ls "${SRC_PATH_TRANSFORMER_ENGINE}/dist"
 EOF
diff --git a/.github/container/build-jax.sh b/.github/container/build-jax.sh
index fa4c055b8..8ff65ca99 100755
--- a/.github/container/build-jax.sh
+++ b/.github/container/build-jax.sh
@@ -316,6 +316,9 @@ pip --disable-pip-version-check install -e ${BUILD_PATH_JAXLIB}/jaxlib -e ${BUIL
 # jaxlib  0.4.32.dev20240808  /opt/jaxlibs/jaxlib
 pip list | grep jax
 
+# Ensure directories are readable by all for non-root users
+chmod 755 $BUILD_PATH_JAXLIB/*
+
 ## Cleanup
 
 pushd $SRC_PATH_JAX
diff --git a/.github/container/install-nsight.sh b/.github/container/install-nsight.sh
index 73aee4163..f3e4e0715 100755
--- a/.github/container/install-nsight.sh
+++ b/.github/container/install-nsight.sh
@@ -17,11 +17,12 @@ apt-get clean
 
 rm -rf /var/lib/apt/lists/*
 
-NSYS202451=/opt/nvidia/nsight-systems-cli/2024.5.1
-if [[ -d "${NSYS202451}" ]]; then
-  # * can match at least sbsa-armv8 and x86
-  (cd ${NSYS202451}/target-linux-*/python/packages && git apply < /opt/nvidia/nsys-2024.5-tid-export.patch)
-fi
+for NSYS in /opt/nvidia/nsight-systems-cli/2024.5.1 /opt/nvidia/nsight-systems-cli/2024.6.1; do
+  if [[ -d "${NSYS}" ]]; then
+    # * can match at least sbsa-armv8 and x86
+    (cd ${NSYS}/target-linux-*/python/packages && git apply < /opt/nvidia/nsys-2024.5-tid-export.patch)
+  fi
+done
 
 # Install extra dependencies needed for `nsys recipe ...` commands. These are
 # used by the nsys-jax wrapper script.
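The install-nsight.sh change above hard-codes the two Nsight Systems versions currently shipped in the image. A glob-based variant (a sketch only, assuming the 2024.5 TID-export patch keeps applying cleanly to newer releases; it is not part of this patch) would avoid editing the list for each new version:

# Hypothetical alternative to the explicit version list in install-nsight.sh.
for NSYS in /opt/nvidia/nsight-systems-cli/*; do
  if [[ -d "${NSYS}" ]]; then
    # * can match at least sbsa-armv8 and x86
    (cd ${NSYS}/target-linux-*/python/packages && git apply < /opt/nvidia/nsys-2024.5-tid-export.patch)
  fi
done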
diff --git a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py
index 9e3aaee4f..4e72a33fb 100644
--- a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py
+++ b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py
@@ -6,7 +6,7 @@ import pathlib
 from typing import Any
 
-from .protobuf import HloProto, xla_module_metadata
+from .protobuf import HloProto, _host_memory_space, xla_module_metadata
 from .utils import make_child_mask, ProfilerData
 
 pd.options.mode.copy_on_write = True
 
@@ -38,6 +38,11 @@ def align_profiler_data_timestamps(
     # Determine which collective size will be used for the alignment
     num_profiled_devices = len(comm_df.index.get_level_values("Device").unique())
     max_collective_size = comm_df["CollectiveSize"].max()
+    if max_collective_size == 1:
+        print(
+            f"WARNING: cannot align {num_profiled_devices} devices because max collective size is 1"
+        )
+        return frames, {}
     assert (
         num_profiled_devices == max_collective_size
     ), f"Aligning {num_profiled_devices} using collectives of size {max_collective_size} is not implemented"
@@ -193,13 +198,51 @@ def _get_message_size(
             "all-to-all",
             "collective-broadcast",
             "collective-permute-start",
+            "dynamic-slice",
+            "dynamic-update-slice",
             "reduce-scatter",
         }
     ), f"{instruction}: message size calculation for {comm_inst.opcode} has not yet been validated"
+
+    def _byte_size(inst) -> int:
+        size_bits = math.prod(
+            inst.shape.dimensions,
+            start=element_type_width(inst.shape.element_type),
+        )
+        size_bytes, rem = divmod(size_bits, 8)
+        assert rem == 0
+        return size_bytes
+
     if comm_inst.opcode == "collective-permute-start":
         # See https://openxla.org/xla/operation_semantics#collectivepermute, which
         # generates pair-wise send+recv between devices
         collective_size = 2
+    elif comm_inst.opcode in {"dynamic-slice", "dynamic-update-slice"}:
+        # Label host-device transfers orchestrated by dynamic[-update]-slice as single
+        # device collectives.
+        collective_size = 1
+        if comm_inst.opcode == "dynamic-update-slice":
+            # For dynamic-update-slice the second operand is the one being copied
+            _, src_inst = module_proto.find_instruction_by_id(comm_inst.operand_ids[1])
+            transfer_size = _byte_size(src_inst.proto())
+        else:
+            # For dynamic-slice the return type size is the transfer size
+            assert comm_inst.opcode == "dynamic-slice"
+            _, src_inst = module_proto.find_instruction_by_id(comm_inst.operand_ids[0])
+            transfer_size = _byte_size(comm_inst)
+        dest_on_host = _host_memory_space(comm_inst)
+        src_on_host = _host_memory_space(src_inst.proto())
+        assert src_on_host != dest_on_host, (
+            'dynamic[-update]-slice is only considered "communication" if it '
+            "represents a host-device transfer"
+        )
+        return (
+            transfer_size,
+            "device-to-host" if dest_on_host else "host-to-device",
+            1,  # collective size
+            1.0,  # bw_correction
+            1.0,  # bus_correction
+        )
     else:
         # replica_groups is something like {{0,1},{4,5},{2,3},{6,7}}, if there are 8
         # devices that are doing pair-wise collectives
@@ -220,17 +263,12 @@ def _get_message_size(
     total_msg_size = 0
     for operand_id in comm_inst.operand_ids:
         _, operand = module_proto.find_instruction_by_id(operand_id)
-        msg_size_bits = math.prod(
-            operand.proto().shape.dimensions,
-            start=element_type_width(operand.proto().shape.element_type),
-        )
+        msg_size_bytes = _byte_size(operand.proto())
         if comm_inst.opcode == "reduce-scatter":
             # NCCL's convention is that the message size of a reduce-scatter is the size of output buffer:
             # https://github.com/NVIDIA/nccl/blob/ab2b89c4c339bd7f816fbc114a4b05d386b66290/src/collectives.cc#L122
-            msg_size_bits, rem = divmod(msg_size_bits, collective_size)
+            msg_size_bytes, rem = divmod(msg_size_bytes, collective_size)
             assert rem == 0
-            msg_size_bytes, rem = divmod(msg_size_bits, 8)
-            assert rem == 0
         total_msg_size += msg_size_bytes
 
     collective = comm_inst.opcode.removesuffix("-start")
diff --git a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/data_loaders.py b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/data_loaders.py
index 6c25cb2ee..d6e4464bd 100644
--- a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/data_loaders.py
+++ b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/data_loaders.py
@@ -103,6 +103,9 @@ def is_communication(row):
     return _calculate_overlap(thunk_df)
 
 
+compile_prefix = "XlaCompile:#module="
+
+
 def _load_nvtx_gpu_proj_trace_single(
     prefix: pathlib.Path,
     file: pathlib.Path,
@@ -305,10 +308,21 @@ def _load_nvtx_gpu_proj_trace_single(
         unique_pid_tid_pairs = module_df.loc[:, ("PID", "TID")].drop_duplicates()
         if len(unique_pid_tid_pairs) == 1:
             main_pid_tid_candidates.add(tuple(unique_pid_tid_pairs.iloc[0]))
+    # If the profile only includes N>1 modules, we may still be able to identify the
+    # main thread as the one responsible for XlaCompile ranges projected onto the GPU
+    # timeline
+    compile_ranges = df.loc[~all_thunks, "Name"].str.startswith(
+        tsl_prefix + compile_prefix
+    )
+    compile_range_ids = compile_ranges[compile_ranges].index
+    unique_pid_tid_pairs = df.loc[compile_range_ids, ("PID", "TID")].drop_duplicates()
+    if len(unique_pid_tid_pairs) == 1:
+        main_pid_tid_candidates.add(tuple(unique_pid_tid_pairs.iloc[0]))
     assert len(main_pid_tid_candidates) < 2
     if len(main_pid_tid_candidates) == 1:
         # Possibly not correct if len(device_by_pid_tid) > 1
         assert len(device_by_pid_tid) > 0
+        # Associate the main thread with the 0th device in device_by_pid_tid
         main_thread_df = device_by_pid_tid.iloc[:1]
         main_thread_df.index = pd.MultiIndex.from_tuples(
             main_pid_tid_candidates, names=["PID", "TID"]
@@ -425,16 +439,13 @@ def _load_nvtx_gpu_proj_trace(
     return output
 
 
-compile_prefix = "TSL:XlaCompile:#module="
-
-
 def _splice_parallel_ranges(compile_df: pd.DataFrame) -> pd.DataFrame:
     # When parallel compilation is enabled, we end up with worker threads that
     # emit NVTX ranges but which are not accounted for in the RangeStack tree.
     # Splice these in under the relevant XlaCompile ranges in the RangeStack tree and
     # drop everything else.
     retain_mask = pd.Series(False, index=compile_df.index)
-    compile_mask = compile_df["Name"].str.startswith(compile_prefix)
+    compile_mask = compile_df["Name"].str.startswith("TSL:" + compile_prefix)
     for compile_range in compile_df[compile_mask].itertuples():
         # Identify the slice of `compile_df` that overlaps in time with this XlaCompile
         # range
diff --git a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/protobuf.py b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/protobuf.py
index ef74165fd..4feae6038 100644
--- a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/protobuf.py
+++ b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/protobuf.py
@@ -1,10 +1,13 @@
-from collections import defaultdict
 import functools
 import lzma
 import pathlib
 import typing
 
 
+def _host_memory_space(inst):
+    return inst.shape.layout.memory_space == 5
+
+
 class StackFrame(typing.NamedTuple):
     column: int
     file: str
@@ -25,6 +28,35 @@ def __init__(self, wrapped_hlo_proto, proto):
         # proto representing the actual collective, which will be different if the
         # async launch is handled by an async-start op
         # TODO: can any of copy-start, custom-call, recv, send represent communication?
+        # This also aims to identify, and (for now) flag as communication, kernels that
+        # implement device-to-host and host-to-device copies for memory offloading.
+        # For example, a device-to-host offload might look like
+        #   computation {
+        #     ...
+        #     ROOT r1 = bf16[2,8,128,2048]{3,2,1,0:S(5)} dynamic-update-slice(...)
+        #   }
+        #   async_computation {
+        #     ...
+        #     ROOT r2 = bf16[2,8,128,2048]{3,2,1,0:S(5)} fusion(...), calls=computation
+        #   }
+        #   start = (...) async-start(...), calls=async_computation
+        # where the :S(5) annotation shows that a buffer is in host memory.
+        # A host-to-device load might look like
+        #   computation {
+        #     param_0 = bf16[2,8,128,2048]{3,2,1,0:S(5)} parameter(0)
+        #     ...
+        #     ROOT r1 = bf16[2,8,128,2048]{3,2,1,0} dynamic-slice(param_0, ...)
+        #   }
+        #   async_computation {
+        #     param_0 = bf16[2,8,128,2048]{3,2,1,0:S(5)} parameter(0)
+        #     ...
+        #     ROOT r2 = bf16[2,8,128,2048]{3,2,1,0} fusion(param_0, ...), calls=computation
+        #   }
+        #   start = (...) async-start(...), calls=async_computation
+        # where the :S(5) memory space annotation is in a parameter instead of in the
+        # return value.
+        # For now, handling host-device kernels as single-device "collective"
+        # communication should be sufficient.
         self._comm_proto = None
         comm_opcodes = {
             "all-gather",
@@ -39,25 +71,50 @@ def __init__(self, wrapped_hlo_proto, proto):
             "all-reduce-start",
             "collective-permute-start",
         }
+
+        def _is_offloading_instruction(inst):
+            host_dest = _host_memory_space(inst)
+
+            def _host_operand(i):
+                _, op = wrapped_hlo_proto.find_instruction_by_id(inst.operand_ids[i])
+                return _host_memory_space(op.proto())
+
+            if inst.opcode == "dynamic-slice" and host_dest != _host_operand(0):
+                return True
+            elif (
+                inst.opcode == "dynamic-update-slice"
+                and host_dest == _host_operand(0)
+                and host_dest != _host_operand(1)
+            ):
+                return True
+            return False
+
         if self._proto.opcode in comm_opcodes | comm_start_opcodes:
             self._comm_proto = self._proto
-        elif self._proto.opcode == "async-start":
+        elif self._proto.opcode in {"async-start", "fusion"}:
+            # fusion example:
+            #   computation {
+            #     param_0 = f32[...]{...:S(5)} parameter(0)
+            #     ...
+            #     ROOT dus = f32[...]{...:S(5)} dynamic-update-slice(param_0, ...)
+            #   }
+            #   inst = f32[256,128,128]{2,1,0:S(5)} fusion(...), calls=computation
             # This might be thinly wrapping an opcode in `comm_opcodes`
-            other_opcodes = defaultdict(int)
-            for called_id in self._proto.called_computation_ids:
-                for called_inst in wrapped_hlo_proto.find_computation(
-                    called_id
-                ).instructions:
-                    if called_inst.opcode in comm_opcodes:
+            def _visit_computation(computation_id):
+                computation = wrapped_hlo_proto.find_computation(computation_id)
+                for called_inst in computation.instructions:
+                    for called_id in called_inst.called_computation_ids:
+                        _visit_computation(called_id)
+                    if called_inst.opcode in comm_opcodes or _is_offloading_instruction(
+                        called_inst
+                    ):
                         assert (
                             self._comm_proto is None
                         ), f"Found {called_inst.opcode} child having already found {self._comm_proto.opcode}"
                         self._comm_proto = called_inst
-                    else:
-                        other_opcodes[called_inst.opcode] += 1
-            assert (
-                other_opcodes.keys() == {"parameter"}
-            ), f"async-start op {self._proto.name} wrapped too many opcode types ({dict(other_opcodes)}) in addition to {self._comm_proto}"
+
+            for called_id in self._proto.called_computation_ids:
+                _visit_computation(called_id)
 
     def communication_proto(self):
         return self._comm_proto
@@ -68,12 +125,7 @@ def is_communication(self) -> bool:
         a little more complicated than you might hope, because async communications
         are not handled uniformly.
""" - if self._comm_proto is None: - return False - assert ( - self._comm_proto.channel_id != 0 - ), f"Got channel_id={self._comm_proto.channel_id} for {self._comm_proto.name}" - return True + return self._comm_proto is not None def proto(self): """ diff --git a/.github/container/manifest.yaml b/.github/container/manifest.yaml index 60ef1a001..e9d30a3bc 100644 --- a/.github/container/manifest.yaml +++ b/.github/container/manifest.yaml @@ -177,3 +177,8 @@ orbax-checkpoint: tracking_ref: main latest_verified_commit: 16c2d409e365576284dbaf190ac002b24c1f927f mode: pip-vcs +pathwaysutils: + url: https://github.com/google/pathways-utils.git + tracking_ref: main + latest_verified_commit: 359776d454940ffaa337c36d1df16308d44a95a9 + mode: pip-vcs diff --git a/.github/container/test-maxtext.sh b/.github/container/test-maxtext.sh index 164fa5912..0dc26c8c1 100755 --- a/.github/container/test-maxtext.sh +++ b/.github/container/test-maxtext.sh @@ -223,7 +223,7 @@ export NVTE_FUSED_ATTN=${ENABLE_FUSED_ATTN} export XLA_PYTHON_CLIENT_MEM_FRACTION=${MEM_FRACTION} export CUDA_DEVICE_MAX_CONNECTIONS=1 -export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true +export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false --xla_gpu_graph_level=0 --xla_gpu_all_reduce_combine_threshold_bytes=1073741824 @@ -232,8 +232,7 @@ export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_schedule --xla_gpu_enable_pipelined_all_gather=true --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true - --xla_gpu_enable_while_loop_double_buffering=true - --xla_gpu_enable_triton_softmax_fusion=false + --xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false --xla_disable_hlo_passes=rematerialization} diff --git a/.github/container/test-pax.sh b/.github/container/test-pax.sh index 2b33f53f7..46ce6ae73 100755 --- a/.github/container/test-pax.sh +++ b/.github/container/test-pax.sh @@ -15,7 +15,8 @@ usage() { echo " -a, --additional-args Additional fiddle args to pass to paxml/main.py" echo " -b, --batch-per-gpu Batch size per GPU, defaults to 4." echo " --dtype Batch size, defaults to bfloat16." - echo " --enable-te If set, will run with env var ENABLE_TE=1." + echo " --enable-te If set, will run with env var ENABLE_TE=1." + echo " --enable-cudnn-fa If set, will use cudnn fa." echo " --enable-dropout If set, will set DROPOUT_PROB to 0.1." echo " --disable-fused-attn Whether disable TE fused attention." echo " --model-type One of 126M, 5B, LLaMA70BProxy. Defaults to 126M" @@ -26,13 +27,13 @@ usage() { echo " --data-parallel Data parallelism to use. Defaults to 1." echo " --fsdp Fully-sharded data parallelism to use. Defaults to 1." echo " --tensor-parallel Tensor parallelism to use. Defaults to 1." - echo " --pipeline-parallel Pipeline parallelism to use. Defaults to 1 for no pipelining." + echo " --pipeline-parallel Pipeline parallelism to use. Defaults to 1 for no pipelining." echo " -n, --nodes Number of nodes." echo " -h, --help Print usage." 
     exit $1
 }
 
-args=$(getopt -o a:b:s:o:n:h --long additional-args:,batch-per-gpu:,dtype:,enable-te,enable-dropout,disable-fused-attn,model-type:,evaluate,steps:,help,multiprocess,output:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@")
+args=$(getopt -o a:b:s:o:n:h --long additional-args:,batch-per-gpu:,dtype:,enable-te,enable-cudnn-fa,enable-dropout,disable-fused-attn,model-type:,evaluate,steps:,help,multiprocess,output:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@")
 if [[ $? -ne 0 ]]; then
     exit $1
 fi
@@ -50,6 +51,7 @@ TP=1
 PP=1
 NODES=1
 ENABLE_TE=0
+ENABLE_CUDNN_FA=0
 MODEL_TYPE=126M
 NVTE_FUSED_ATTN=1
 DROPOUT=0
@@ -75,6 +77,10 @@ while [ : ]; do
             ENABLE_TE=1
             shift 1
             ;;
+        --enable-cudnn-fa)
+            ENABLE_CUDNN_FA=1
+            shift 1
+            ;;
         --enable-dropout)
             DROPOUT='0.1'
             shift 1
@@ -128,7 +134,7 @@ while [ : ]; do
            ;;
        --)
            shift;
-           break
+           break
            ;;
        *)
            echo "UNKNOWN OPTION $1"
@@ -149,6 +155,7 @@ print_var NGPUS
 print_var OUTPUT
 print_var MULTIPROCESS
 print_var ENABLE_TE
+print_var ENABLE_CUDNN_FA
 print_var NVTE_FUSED_ATTN
 print_var EVALUATE
 print_var DROPOUT
@@ -196,10 +203,10 @@ if dcn_factor > 1:
   if dp % dcn_factor == 0:
     dcn_dp = dcn_factor
     dp = int(dp / dcn_factor)
-  elif fsdp % dcn_factor == 0:
+  elif fsdp % dcn_factor == 0:
     dcn_fsdp = dcn_factor
     fsdp = int(fsdp / dcn_factor)
-  elif pp % dcn_factor == 0:
+  elif pp % dcn_factor == 0:
     dcn_pp = dcn_factor
     pp = int(pp / dcn_factor)
 
@@ -209,12 +216,12 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam):
   USE_REPEATED_LAYER = False
   ICI_MESH_SHAPE = [64,1,1]
   MAX_STEPS = 600000
-
+
   MAX_SEQ_LEN = 2048
   VOCAB_SIZE = 50304
   PACKED_INPUT = True
   PERCORE_BATCH_SIZE = 4
-
+
   NUM_LAYERS = 12
   NUM_HEADS = 12
   MODEL_DIMS = 768
@@ -223,14 +230,14 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam):
 
   TRAINABLE_POSITION_EMB = True
   TRAINABLE_PE_MAX_SEQ_LEN = MAX_SEQ_LEN
-
+
   USE_BIAS = True
   LAYERNORM_EPSILON = 1e-5
   ATTEN_LOGIT_CAP = -1.0
   INIT_STD = 0.023
   SOFTMAX_INIT_STD = 0.023
   ACTIVATION_CLS = layers.GELU
-
+
   ## optimizer-related
   ADAM_BETA1 = 0.9
   ADAM_BETA2 = 0.95
@@ -255,7 +262,7 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam):
   ## disable eval to avoid including eval
   ## in steps/sec calculation
   EVAL_INTERVAL_STEPS = 100000
-
+
   def task(self):
     task_p = super().task()
     task_p = configure_gpt3_task(self, task_p)
@@ -263,7 +270,7 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam):
     task_p.train.num_train_steps = self.MAX_STEPS
 
     model_p = task_p.model
-
+
     ### compute layernorm reductions in fp32. Needed for stable training on GPUs
     stacked_p = model_p.lm_tpl.stacked_transformer_tpl
     if stacked_p.cls == layers.PipelinedTransformer:
@@ -274,13 +281,13 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam):
       transformer_layer_p.ln_tpl.reductions_in_fp32 = True
       transformer_layer_p.tr_fflayer_tpl.ln_tpl.reductions_in_fp32 = True
     task_p.model.lm_tpl.final_ln_tpl.reductions_in_fp32 = True
-
+
     model_p.params_init = WeightInit.Gaussian(self.INIT_STD)
     softmax_init = WeightInit.Gaussian(self.SOFTMAX_INIT_STD)
     model_p.lm_tpl.softmax_tpl.params_init = softmax_init
-
+
     model_p.apply_eval_sample_weights = True
-
+
     ## set input, residual, attention dropout to DROPOUT_PROB, remaining dropout to 0
     stacked_p.dropout_prob = 0.0
     stacked_p.input_dropout_prob = self.DROPOUT_PROB
@@ -316,14 +323,14 @@ class LLaMA70BSyntheticSmall(BaseLLaMA, SyntheticDataset):
 if pp > 1:
   @experiment_registry.register
   class Synthetic126MCI(GPT126MPP, SyntheticDataset):
-
+
     ICI_MESH_SHAPE = [pp, dp, fsdp, tp]
     DCN_MESH_SHAPE = [dcn_pp, dcn_dp, dcn_fsdp, 1]
     MICROBATCH_SIZE = 2
     NUM_STAGES = pp
     PERCORE_BATCH_SIZE = percore_batch_size
     FRPOP_DTYPE = dtype
-
+
     def task(self):
       task_p = super().task()
       task_p.train.always_use_train_for_model_init=False
@@ -333,7 +340,7 @@ if pp > 1:
 else:
   @experiment_registry.register
   class Synthetic126MCI(Synthetic126M):
-
+
     ICI_MESH_SHAPE = [dp, fsdp, tp]
     DCN_MESH_SHAPE = [dcn_dp, dcn_fsdp, 1]
     PERCORE_BATCH_SIZE = percore_batch_size
@@ -343,7 +350,7 @@ else:
 
     ## disable eval
     EVAL_INTERVAL_STEPS = 100000
-
+
     def task(self):
       task_p = super().task()
 
@@ -374,6 +381,10 @@ export ENABLE_TE=$ENABLE_TE
 export NVTE_FUSED_ATTN=$NVTE_FUSED_ATTN
 export VOCAB_PATH=${VOCAB_PATH:-gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model}
 
+if [[ ${ENABLE_CUDNN_FA} -ne 0 ]]; then
+    ADDITIONAL_ARGS="${ADDITIONAL_ARGS} --fdl.USE_CUDNN_FLASH_ATTENTION=True"
+fi
+
 if [[ ${MODEL_TYPE} == "126M" ]]; then
     CONFIG=ci_configs.Synthetic126MCI
 elif [[ ${MODEL_TYPE} == "5B" ]]; then
diff --git a/.github/workflows/_ci.yaml b/.github/workflows/_ci.yaml
index fc04b83ab..426764323 100644
--- a/.github/workflows/_ci.yaml
+++ b/.github/workflows/_ci.yaml
@@ -11,6 +11,11 @@ on:
         description: 'Build date in YYYY-MM-DD format'
         required: false
        default: NOT SPECIFIED
+      CUDA_IMAGE:
+        type: string
+        description: CUDA image to use as base, e.g. nvidia/cuda:X.Y.Z-devel-ubuntu22.04
+        default: 'latest'
+        required: false
       MANIFEST_ARTIFACT_NAME:
         type: string
         description: 'Artifact name in current run w/ manifest/patches. Leaving empty uses manifest/patches in current branch'
@@ -37,6 +42,7 @@ jobs:
     uses: ./.github/workflows/_build_base.yaml
     with:
       ARCHITECTURE: ${{ inputs.ARCHITECTURE }}
+      BASE_IMAGE: ${{ inputs.CUDA_IMAGE }}
       BUILD_DATE: ${{ inputs.BUILD_DATE }}
       MANIFEST_ARTIFACT_NAME: ${{ inputs.MANIFEST_ARTIFACT_NAME }}
     secrets: inherit
diff --git a/.github/workflows/baselines/test_maxtext_metrics.py b/.github/workflows/baselines/test_maxtext_metrics.py
index bd180ecfe..a130c86c6 100644
--- a/.github/workflows/baselines/test_maxtext_metrics.py
+++ b/.github/workflows/baselines/test_maxtext_metrics.py
@@ -19,7 +19,7 @@ def test_loss(baseline_filename):
     baseline_filepath = os.path.join(baselines_dir, baseline_filename)
     test_config = baseline_filename.split(".")[0]
-    event_file = os.path.join(results_dir, test_config, "logdir/tensorboard/events*")
+    event_file = os.path.join(results_dir, test_config, "logdir/tensorboard/logdir/events*")
     event_file = glob.glob(event_file)[0]
     with open(baseline_filepath, "r") as baseline_file:
         end_step = json.load(baseline_file)["end_step"]
@@ -31,7 +31,7 @@ def test_step_time(baseline_filename):
     baseline_filepath = os.path.join(baselines_dir, baseline_filename)
     test_config = baseline_filename.split(".")[0]
-    event_file = os.path.join(results_dir, test_config, "logdir/tensorboard/events*")
+    event_file = os.path.join(results_dir, test_config, "logdir/tensorboard/logdir/events*")
     event_file = glob.glob(event_file)[0]
     with open(baseline_filepath, "r") as baseline_file:
         step_time_avg_expected = json.load(baseline_file)["step_time_avg"]
diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 70aeff5ff..0c3c8bdb0 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -28,6 +28,11 @@ on:
        description: "(used if BUMP_MANIFEST=true) If true: attempt to PR/merge manifest branch"
        default: false
        required: false
+      CUDA_IMAGE:
+        type: string
+        description: CUDA image to use as base, e.g. nvidia/cuda:X.Y.Z-devel-ubuntu22.04
+        default: 'latest'
+        required: false
       SOURCE_OVERRIDES:
         type: string
         description: |
@@ -60,6 +65,7 @@ jobs:
       MANIFEST_ARTIFACT_NAME: ${{ steps.manifest-branch.outputs.MANIFEST_ARTIFACT_NAME }}
       MANIFEST_BRANCH: ${{ steps.manifest-branch.outputs.MANIFEST_BRANCH }}
       MERGE_BUMPED_MANIFEST: ${{ steps.manifest-branch.outputs.MERGE_BUMBED_MANIFEST }}
+      CUDA_IMAGE: ${{ steps.cuda-image.outputs.CUDA_IMAGE }}
     steps:
       - name: Cancel workflow run if the trigger is a draft PR
         id: cancel-if-draft
@@ -114,6 +120,17 @@ jobs:
             exit 1
           fi
 
+      - name: Determine CUDA image to use
+        id: cuda-image
+        shell: bash -x -e {0}
+        run: |
+          if [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then
+            CUDA_IMAGE="${{ inputs.CUDA_IMAGE }}"
+          else
+            CUDA_IMAGE="latest"
+          fi
+          echo "CUDA_IMAGE=${CUDA_IMAGE}" >> $GITHUB_OUTPUT
+
   bump-manifest:
     needs: metadata
     runs-on: ubuntu-22.04
@@ -177,6 +194,7 @@ jobs:
     with:
       ARCHITECTURE: amd64
       BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
+      CUDA_IMAGE: ${{ needs.metadata.outputs.CUDA_IMAGE }}
       MANIFEST_ARTIFACT_NAME: ${{ needs.metadata.outputs.MANIFEST_ARTIFACT_NAME }}
       SOURCE_URLREFS: ${{ needs.bump-manifest.outputs.SOURCE_URLREFS }}
     secrets: inherit
@@ -187,6 +205,7 @@ jobs:
     with:
       ARCHITECTURE: arm64
       BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
+      CUDA_IMAGE: ${{ needs.metadata.outputs.CUDA_IMAGE }}
       MANIFEST_ARTIFACT_NAME: ${{ needs.metadata.outputs.MANIFEST_ARTIFACT_NAME }}
       SOURCE_URLREFS: ${{ needs.bump-manifest.outputs.SOURCE_URLREFS }}
     secrets: inherit
diff --git a/README.md b/README.md
index 66d9b2a4e..1764c5f00 100644
--- a/README.md
+++ b/README.md
@@ -300,6 +300,8 @@ The [JAX image](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) is emb
 
 There are various other XLA flags users can set to improve performance. For a detailed explanation of these flags, please refer to the [GPU performance](./rosetta/docs/GPU_performance.md) doc. XLA flags can be tuned per workflow. For example, each script in [contrib/gpu/scripts_gpu](https://github.com/google/paxml/tree/main/paxml/contrib/gpu/scripts_gpu) sets its own [XLA flags](https://github.com/google/paxml/blob/93fbc8010dca95af59ab615c366d912136b7429c/paxml/contrib/gpu/scripts_gpu/benchmark_gpt_multinode.sh#L30-L33).
 
+For a list of previously used XLA flags that are no longer needed, please also refer to the [GPU performance](./rosetta/docs/GPU_performance.md#previously-used-xla-flags) page.
+
 ## Profiling JAX programs on GPU
 
 See [this page](./docs/profiling.md) for more information about how to profile JAX programs on GPU.
diff --git a/rosetta/docs/GPU_performance.md b/rosetta/docs/GPU_performance.md
index c5456e3c4..fabbc6963 100644
--- a/rosetta/docs/GPU_performance.md
+++ b/rosetta/docs/GPU_performance.md
@@ -128,6 +128,10 @@ Fine-grain control to improve performance by initializing a NCCL communicator to
 - --xla_gpu_enable_cudnn_fmha=false (enables XLA pattern matcher to detect multi-headed attention pattern in JAX)
 - --xla_disable_hlo_passes=<> (turns off specific HLO passes; can be used for debugging)
 
+## Previously used XLA Flags
-
+The following flags were previously used but are no longer required.
+- --xla_gpu_enable_async_reduce_scatter, --xla_gpu_enable_async_all_reduce, --xla_gpu_enable_async_all_gather ; Turned on by default, no longer needed
+- --xla_gpu_enable_highest_priority_async_stream ; Turned on by default
+- --xla_gpu_enable_triton_softmax_fusion ; Deprecated, no longer used
diff --git a/rosetta/docs/NATIVE_FP8.md b/rosetta/docs/NATIVE_FP8.md
index dd3aa1bae..069b06fdd 100644
--- a/rosetta/docs/NATIVE_FP8.md
+++ b/rosetta/docs/NATIVE_FP8.md
@@ -111,13 +111,11 @@ Enabling this feature is effortless. Users only need to include the option `--fd
 In addition to the suggested XLA flags mentioned in [this section](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta/rosetta/projects/pax/README.md#xla-flags), we also recommend setting these following XLA flags. The execution script should look like:
 ```bash
 export XLA_FLAGS=" \
-    --xla_gpu_enable_reduction_epilogue_fusion=false \
     --xla_gpu_enable_triton_gemm=false \
-    --xla_gpu_enable_cudnn_fmha=false \
-    --xla_gpu_enable_cudnn_layer_norm=true \
-    --xla_gpu_enable_cublaslt=true \
-    --xla_gpu_enable_latency_hiding_scheduler=true \
-    --xla_gpu_all_reduce_combine_threshold_bytes=51200 "
+    --xla_gpu_enable_pipelined_all_reduce=false \
+    --xla_gpu_enable_pipelined_all_gather=false \
+    --xla_gpu_enable_pipelined_reduce_scatter=false \
+"
 export ENABLE_TE=0
 python -m paxml.main \
     ...
@@ -125,8 +123,7 @@ python -m paxml.main \
     ...
 ```
 
-Please ensure you include the first two flags, `--xla_gpu_enable_reduction_epilogue_fusion=false` and `--xla_gpu_enable_triton_gemm=false`, as they are essential for enabling the FP8 functionality. The additional flags primarily focus on performance enhancement and should also prove beneficial for non-FP8 executions.
-
+Please note that disabling the Triton GEMM and pipelined collectives is essential for enabling the FP8 functionality and performance.
 ## Transformer Engine vs Native FP8 Support
 Native XLA-FP8 specifically targets matrix multiplication operations. In contrast, the Transformer Engine focuses on enhancing the overall performance of the entire transformer layer. This encompasses not only the FP8 matrix multiplication but also attention mechanisms, layer normalizations, and other components.
diff --git a/rosetta/docs/PGLE.md b/rosetta/docs/PGLE.md
index 02e5f5294..2425ddffe 100644
--- a/rosetta/docs/PGLE.md
+++ b/rosetta/docs/PGLE.md
@@ -70,7 +70,6 @@ export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
 --xla_gpu_enable_pipelined_reduce_scatter=true
 --xla_gpu_enable_pipelined_all_reduce=true
 --xla_gpu_enable_while_loop_double_buffering=true
---xla_gpu_enable_triton_softmax_fusion=false
 --xla_gpu_enable_all_gather_combine_by_dim=false
 --xla_gpu_enable_reduce_scatter_combine_by_dim=false
 --xla_disable_hlo_passes=rematerialization
diff --git a/rosetta/rosetta/projects/maxtext/README.md b/rosetta/rosetta/projects/maxtext/README.md
index fde5a9125..2320a7ed9 100644
--- a/rosetta/rosetta/projects/maxtext/README.md
+++ b/rosetta/rosetta/projects/maxtext/README.md
@@ -67,12 +67,9 @@ In order to obtain the best performance, please set the appropriate XLA flags. W
 The [GPU Performance document](../../../docs/GPU_performance.md) provides a detailed description of the XLA flags that can be set to optimize performance. These are the recommended XLA flags to get good performance for MaxText.
 
 ```
-XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
-           --xla_gpu_enable_async_all_gather=true
-           --xla_gpu_enable_async_reduce_scatter=true
+XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
            --xla_gpu_enable_triton_gemm=false
-           --xla_gpu_graph_level=0
-           --xla_gpu_enable_async_all_reduce=true
+           --xla_gpu_graph_level=0
            --xla_gpu_all_reduce_combine_threshold_bytes=1073741824
            --xla_gpu_all_gather_combine_threshold_bytes=1073741824
            --xla_gpu_reduce_scatter_combine_threshold_bytes=134217728
@@ -80,7 +77,6 @@ XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
            --xla_gpu_enable_pipelined_reduce_scatter=true
            --xla_gpu_enable_pipelined_all_reduce=true
            --xla_gpu_enable_while_loop_double_buffering=true
-           --xla_gpu_enable_triton_softmax_fusion=false
            --xla_gpu_enable_all_gather_combine_by_dim=false
            --xla_gpu_enable_reduce_scatter_combine_by_dim=false
            --xla_disable_hlo_passes=rematerialization"
diff --git a/rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub b/rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub
index e96eaa781..0ca3fd802 100644
--- a/rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub
+++ b/rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub
@@ -53,11 +53,8 @@ export NCCL_IB_SL=1
 
 # Set XLA Flags
 export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
-                  --xla_gpu_enable_async_all_gather=true
-                  --xla_gpu_enable_async_reduce_scatter=true
                   --xla_gpu_enable_triton_gemm=false
                   --xla_gpu_graph_level=0
-                  --xla_gpu_enable_async_all_reduce=true
                   --xla_gpu_all_reduce_combine_threshold_bytes=1073741824
                   --xla_gpu_all_gather_combine_threshold_bytes=1073741824
                   --xla_gpu_reduce_scatter_combine_threshold_bytes=134217728
@@ -65,12 +62,9 @@ export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
                   --xla_gpu_enable_pipelined_reduce_scatter=true
                   --xla_gpu_enable_pipelined_all_reduce=true
                   --xla_gpu_enable_while_loop_double_buffering=true
-                  --xla_gpu_enable_triton_softmax_fusion=false
                   --xla_gpu_enable_all_gather_combine_by_dim=false
                   --xla_gpu_enable_reduce_scatter_combine_by_dim=false
-                  --xla_disable_hlo_passes=rematerialization
-                  --xla_gpu_enable_custom_fusions=false
-                  --xla_gpu_enable_address_computation_fusion=false"
+                  --xla_disable_hlo_passes=rematerialization"
 
 # Make directories that may not exist
 mkdir -p $BASE_WORKSPACE_DIR
diff --git a/rosetta/rosetta/projects/maxtext/xla_flags/llama2-7b-1N8G.env b/rosetta/rosetta/projects/maxtext/xla_flags/llama2-7b-1N8G.env
new file mode 100644
index 000000000..d999f5b5e
--- /dev/null
+++ b/rosetta/rosetta/projects/maxtext/xla_flags/llama2-7b-1N8G.env
@@ -0,0 +1,24 @@
+set -x
+NUM_NODES=1
+NUM_GPUS=8
+THRESHOLD_BYTES=1073741824
+export XLA_FLAGS="\
+    --xla_gpu_enable_latency_hiding_scheduler=true \
+    --xla_gpu_enable_triton_gemm=false \
+    --xla_gpu_graph_level=0 \
+    --xla_gpu_enable_highest_priority_async_stream=true \
+    --xla_gpu_all_reduce_combine_threshold_bytes=${THRESHOLD_BYTES} \
+    --xla_gpu_all_gather_combine_threshold_bytes=$((THRESHOLD_BYTES/(NUM_NODES*NUM_GPUS))) \
+    --xla_gpu_reduce_scatter_combine_threshold_bytes=$((THRESHOLD_BYTES/(NUM_NODES*NUM_GPUS*2))) \
+    --xla_gpu_enable_pipelined_all_gather=true \
+    --xla_gpu_enable_pipelined_reduce_scatter=true \
+    --xla_gpu_enable_pipelined_all_reduce=true \
+    --xla_gpu_enable_while_loop_double_buffering=true \
+    --xla_gpu_enable_triton_softmax_fusion=false \
+    --xla_gpu_enable_all_gather_combine_by_dim=false \
+    --xla_gpu_enable_reduce_scatter_combine_by_dim=false \
+    --xla_disable_hlo_passes=rematerialization \
+    "
+export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9
+unset NUM_NODES NUM_GPUS THRESHOLD_BYTES
+set +x
diff --git a/rosetta/rosetta/projects/pax/README.md b/rosetta/rosetta/projects/pax/README.md
index 6ac4dc150..d1829b847 100644
--- a/rosetta/rosetta/projects/pax/README.md
+++ b/rosetta/rosetta/projects/pax/README.md
@@ -138,10 +138,10 @@ The [GPU Performance document](../../../docs/GPU_performance.md) provides a deta
 For the the 126M model, we recommend setting `--xla_gpu_all_reduce_combine_threshold_bytes=33554432`, which is different from the value recommended in `paxml/contrib/gpu/scripts_gpu/run_pile_multinode.sh`. To overwrite the default XLA flags set in the script, set the `BASE_XLA_FLAGS` environment variable prior to running `run_pile_multinode` as follows:
 
 ```
-BASE_XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false
-        --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true
-        --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_all_reduce_combine_threshold_bytes=33554432
-        --xla_gpu_graph_level=0 --xla_gpu_enable_async_all_reduce=true" bash run_pile_multinode.sh ...
+BASE_XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
+                --xla_gpu_enable_triton_gemm=false
+                --xla_gpu_all_reduce_combine_threshold_bytes=33554432
+                --xla_gpu_graph_level=0" bash run_pile_multinode.sh ...
 ```
 
 # Configs
diff --git a/rosetta/rosetta/projects/pax/xla_flags/common.env b/rosetta/rosetta/projects/pax/xla_flags/common.env
new file mode 100644
index 000000000..26c819143
--- /dev/null
+++ b/rosetta/rosetta/projects/pax/xla_flags/common.env
@@ -0,0 +1,13 @@
+set -x
+THRESHOLD_BYTES=51200
+export XLA_FLAGS="\
+    --xla_gpu_enable_latency_hiding_scheduler=true \
+    --xla_allow_excess_precision \
+    --xla_gpu_enable_highest_priority_async_stream=true \
+    --xla_gpu_enable_triton_softmax_fusion=false \
+    --xla_gpu_all_reduce_combine_threshold_bytes=${THRESHOLD_BYTES} \
+    --xla_gpu_graph_level=0 \
+    "
+export XLA_PYTHON_CLIENT_MEM_FRACTION=0.8
+unset THRESHOLD_BYTES
+set +x
diff --git a/rosetta/rosetta/projects/pax/xla_flags/glam-126m64e.env b/rosetta/rosetta/projects/pax/xla_flags/glam-126m64e.env
new file mode 100644
index 000000000..8b0237170
--- /dev/null
+++ b/rosetta/rosetta/projects/pax/xla_flags/glam-126m64e.env
@@ -0,0 +1,3 @@
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+source $SCRIPT_DIR/common.env
+unset SCRIPT_DIR
diff --git a/rosetta/rosetta/projects/pax/xla_flags/glam-64b64e.env b/rosetta/rosetta/projects/pax/xla_flags/glam-64b64e.env
new file mode 100644
index 000000000..8b0237170
--- /dev/null
+++ b/rosetta/rosetta/projects/pax/xla_flags/glam-64b64e.env
@@ -0,0 +1,3 @@
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+source $SCRIPT_DIR/common.env
+unset SCRIPT_DIR
diff --git a/rosetta/rosetta/projects/pax/xla_flags/gpt-126m.env b/rosetta/rosetta/projects/pax/xla_flags/gpt-126m.env
new file mode 100644
index 000000000..e5b97b466
--- /dev/null
+++ b/rosetta/rosetta/projects/pax/xla_flags/gpt-126m.env
@@ -0,0 +1,14 @@
+set -x
+THRESHOLD_BYTES=33554432
+export XLA_FLAGS="\
+    --xla_gpu_enable_latency_hiding_scheduler=true \
+    --xla_allow_excess_precision \
+    --xla_gpu_enable_highest_priority_async_stream=true \
+    --xla_gpu_enable_triton_softmax_fusion=false \
+    --xla_gpu_all_reduce_combine_threshold_bytes=${THRESHOLD_BYTES} \
+    --xla_gpu_graph_level=0 \
+    --xla_gpu_enable_cudnn_fmha=false \
+    "
+export XLA_PYTHON_CLIENT_MEM_FRACTION=0.8
+unset THRESHOLD_BYTES
+set +x
diff --git a/rosetta/rosetta/projects/pax/xla_flags/gpt-175b.env b/rosetta/rosetta/projects/pax/xla_flags/gpt-175b.env
new file mode 100644
index 000000000..8b0237170
--- /dev/null
+++ b/rosetta/rosetta/projects/pax/xla_flags/gpt-175b.env
@@ -0,0 +1,3 @@
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+source $SCRIPT_DIR/common.env
+unset SCRIPT_DIR
diff --git a/rosetta/rosetta/projects/pax/xla_flags/gpt-5b.env b/rosetta/rosetta/projects/pax/xla_flags/gpt-5b.env
new file mode 100644
index 000000000..8b0237170
--- /dev/null
+++ b/rosetta/rosetta/projects/pax/xla_flags/gpt-5b.env
@@ -0,0 +1,3 @@
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+source $SCRIPT_DIR/common.env
+unset SCRIPT_DIR
diff --git a/rosetta/rosetta/projects/pax/xla_flags/grok-proxy.env b/rosetta/rosetta/projects/pax/xla_flags/grok-proxy.env
new file mode 100644
index 000000000..e48b76dcf
--- /dev/null
+++ b/rosetta/rosetta/projects/pax/xla_flags/grok-proxy.env
@@ -0,0 +1,25 @@
+set -x
+ALL_REDUCE_THRESHOLD_BYTES=3221225472
+ALL_GATHER_THRESHOLD_BYTES=3221225472
+REDUCE_SCATTER_THRESHOLD_BYTES=402653184
+export XLA_FLAGS="\
+    --xla_gpu_enable_latency_hiding_scheduler=true \
+    --xla_allow_excess_precision \
+    --xla_gpu_enable_highest_priority_async_stream=true \
+    --xla_gpu_enable_triton_softmax_fusion=false \
+    --xla_gpu_all_reduce_combine_threshold_bytes=${ALL_REDUCE_THRESHOLD_BYTES} \
+    --xla_gpu_graph_level=0 \
+    --xla_gpu_all_gather_combine_threshold_bytes=${ALL_GATHER_THRESHOLD_BYTES} \
+    --xla_gpu_reduce_scatter_combine_threshold_bytes=${REDUCE_SCATTER_THRESHOLD_BYTES} \
+    --xla_gpu_enable_pipelined_all_gather=true \
+    --xla_gpu_enable_pipelined_reduce_scatter=true \
+    --xla_gpu_enable_pipelined_all_reduce=true \
+    --xla_gpu_enable_while_loop_double_buffering=true \
+    --xla_gpu_enable_all_gather_combine_by_dim=false \
+    --xla_gpu_enable_reduce_scatter_combine_by_dim=false \
+    --xla_disable_hlo_passes=rematerialization \
+    --xla_gpu_enable_custom_fusions=true
+    "
+export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9
+unset ALL_REDUCE_THRESHOLD_BYTES ALL_GATHER_THRESHOLD_BYTES REDUCE_SCATTER_THRESHOLD_BYTES
+set +x
diff --git a/rosetta/rosetta/projects/pax/xla_flags/llama-70b.env b/rosetta/rosetta/projects/pax/xla_flags/llama-70b.env
new file mode 100644
index 000000000..8b0237170
--- /dev/null
+++ b/rosetta/rosetta/projects/pax/xla_flags/llama-70b.env
@@ -0,0 +1,3 @@
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+source $SCRIPT_DIR/common.env
+unset SCRIPT_DIR
diff --git a/rosetta/rosetta/projects/pax/xla_flags/llama-7b-lora.env b/rosetta/rosetta/projects/pax/xla_flags/llama-7b-lora.env
new file mode 100644
index 000000000..d1568e92c
--- /dev/null
+++ b/rosetta/rosetta/projects/pax/xla_flags/llama-7b-lora.env
@@ -0,0 +1,4 @@
+set -x
+echo "$0 uses default XLA_FLAGS='${XLA_FLAGS:-}'"
+export XLA_PYTHON_CLIENT_MEM_FRACTION=0.85
+set +x
diff --git a/rosetta/rosetta/projects/pax/xla_flags/llama-7b.env b/rosetta/rosetta/projects/pax/xla_flags/llama-7b.env
new file mode 100644
index 000000000..bd4ae50d5
--- /dev/null
+++ b/rosetta/rosetta/projects/pax/xla_flags/llama-7b.env
@@ -0,0 +1,4 @@
+set -x
+echo "$0 uses default XLA_FLAGS='${XLA_FLAGS:-}'"
+export XLA_PYTHON_CLIENT_MEM_FRACTION=0.8
+set +x
diff --git a/rosetta/rosetta/projects/t5x/xla_flags/t5.env b/rosetta/rosetta/projects/t5x/xla_flags/t5.env
new file mode 100644
index 000000000..bd4ae50d5
--- /dev/null
+++ b/rosetta/rosetta/projects/t5x/xla_flags/t5.env
@@ -0,0 +1,4 @@
+set -x
+echo "$0 uses default XLA_FLAGS='${XLA_FLAGS:-}'"
+export XLA_PYTHON_CLIENT_MEM_FRACTION=0.8
+set +x
diff --git a/rosetta/rosetta/projects/vit/xla_flags/vit-base-highgbs.env b/rosetta/rosetta/projects/vit/xla_flags/vit-base-highgbs.env
new file mode 100644
index 000000000..45140ed88
--- /dev/null
+++ b/rosetta/rosetta/projects/vit/xla_flags/vit-base-highgbs.env
@@ -0,0 +1,4 @@
+set -x
+echo "$0 uses default XLA_FLAGS='${XLA_FLAGS:-}'"
+export XLA_PYTHON_CLIENT_MEM_FRACTION=0.75
+set +x
diff --git a/rosetta/rosetta/projects/vit/xla_flags/vit-base.env b/rosetta/rosetta/projects/vit/xla_flags/vit-base.env
new file mode 100644
index 000000000..882c9e9e8
--- /dev/null
+++ b/rosetta/rosetta/projects/vit/xla_flags/vit-base.env
@@ -0,0 +1,4 @@
+set -x
+echo "$0 uses default XLA_FLAGS='${XLA_FLAGS:-}'"
+export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9
+set +x