diff --git a/.github/container/manifest.yaml b/.github/container/manifest.yaml index 7e8607dfb..b1f97fa0a 100644 --- a/.github/container/manifest.yaml +++ b/.github/container/manifest.yaml @@ -48,6 +48,7 @@ praxis: patches: pull/27/head: file://patches/praxis/PR-27.patch # This PR allows XLA:GPU to detect the MHA pattern more easily to call fused kernels from cublas. pull/36/head: file://patches/praxis/PR-36.patch # adds Transformer Engine support + pull/74/head: file://patches/praxis/PR-74.patch # experimental support for using TE FMHA in GQA lingvo: # Used only in ARM pax builds url: https://github.com/tensorflow/lingvo.git @@ -179,4 +180,4 @@ orbax-checkpoint: url: https://github.com/google/orbax.git tracking_ref: main latest_verified_commit: 16c2d409e365576284dbaf190ac002b24c1f927f - mode: pip-vcs \ No newline at end of file + mode: pip-vcs diff --git a/.github/container/patches/flax/PR-3340.patch b/.github/container/patches/flax/PR-3340.patch index d19f134be..5210ef986 100644 --- a/.github/container/patches/flax/PR-3340.patch +++ b/.github/container/patches/flax/PR-3340.patch @@ -1,16 +1,16 @@ -From d748ab4447dbb82ea9317f71211a3bbd9ba4207f Mon Sep 17 00:00:00 2001 +From 97ad32bf809e1e7b1715747c942213031f88fcc6 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Fri, 2 Jun 2023 15:01:21 -0700 Subject: [PATCH 1/3] add t5x sharding annotations to flax layers --- - flax/linen/attention.py | 34 +++++++++++++++++++++++------- + flax/linen/attention.py | 33 ++++++++++++++++++++++------- flax/linen/linear.py | 41 ++++++++++++++++++++++++++++--------- flax/linen/normalization.py | 25 ++++++++++++++++++---- - 3 files changed, 79 insertions(+), 21 deletions(-) + 3 files changed, 78 insertions(+), 21 deletions(-) diff --git a/flax/linen/attention.py b/flax/linen/attention.py -index efcf2b78..689ce4da 100644 +index 99b79f2d..bcabd554 100644 --- a/flax/linen/attention.py +++ b/flax/linen/attention.py @@ -30,6 +30,7 @@ from flax.linen.linear import ( @@ -27,12 +27,12 @@ index efcf2b78..689ce4da 100644 - def dot_product_attention_weights( - query: Array, - key: Array, -@@ -287,6 +287,17 @@ class MultiHeadDotProductAttention(Module): + query: Array, + key: Array, +@@ -313,6 +313,16 @@ class MultiHeadDotProductAttention(Module): num_heads, value_channels]`` - decode: whether to prepare and use an autoregressive cache. - normalize_qk: should QK normalization be applied (arxiv.org/abs/2302.05442). + decode: Whether to prepare and use an autoregressive cache. + normalize_qk: Should QK normalization be applied (arxiv.org/abs/2302.05442). + in_proj_kernel_axes: a tuple of axes over which to shard the kernel for + the attention in-projection. + in_proj_bias_axes: a tuple of axis names associated with the bias for @@ -43,11 +43,10 @@ index efcf2b78..689ce4da 100644 + the attention out-projection. + decode_axes: a tuple of axis names associated with auroregressive cache. + Only used when decode=True. 
-+ """ num_heads: int -@@ -309,6 +320,11 @@ class MultiHeadDotProductAttention(Module): +@@ -336,6 +346,11 @@ class MultiHeadDotProductAttention(Module): out_dot_general: Optional[DotGeneralT] = None qkv_dot_general_cls: Any = None out_dot_general_cls: Any = None @@ -59,7 +58,7 @@ index efcf2b78..689ce4da 100644 @overload def __call__( -@@ -447,6 +463,8 @@ class MultiHeadDotProductAttention(Module): +@@ -474,6 +489,8 @@ class MultiHeadDotProductAttention(Module): precision=self.precision, dot_general=self.qkv_dot_general, dot_general_cls=self.qkv_dot_general_cls, @@ -68,7 +67,7 @@ index efcf2b78..689ce4da 100644 ) # project inputs_q to multi-headed q/k/v # dimensions are then [batch..., length, n_heads, n_features_per_head] -@@ -477,14 +495,14 @@ class MultiHeadDotProductAttention(Module): +@@ -504,14 +521,14 @@ class MultiHeadDotProductAttention(Module): if self.decode: # detect if we're initializing by absence of existing cache data. is_initialized = self.has_variable('cache', 'cached_key') @@ -89,7 +88,7 @@ index efcf2b78..689ce4da 100644 ) if is_initialized: ( -@@ -580,6 +598,8 @@ class MultiHeadDotProductAttention(Module): +@@ -607,6 +624,8 @@ class MultiHeadDotProductAttention(Module): dot_general=self.out_dot_general, dot_general_cls=self.out_dot_general_cls, name='out', # type: ignore[call-arg] @@ -99,18 +98,18 @@ index efcf2b78..689ce4da 100644 return out diff --git a/flax/linen/linear.py b/flax/linen/linear.py -index 36365ea1..4656abf9 100644 +index e4901293..27e22325 100644 --- a/flax/linen/linear.py +++ b/flax/linen/linear.py -@@ -35,6 +35,7 @@ from flax.core import meta - from flax.linen import initializers +@@ -37,6 +37,7 @@ from flax.linen import initializers from flax.linen.dtypes import promote_dtype + from flax.linen import module from flax.linen.module import Module, compact +from flax.linen.partitioning import param_with_axes from flax.typing import ( Array, PRNGKey as PRNGKey, -@@ -97,6 +98,8 @@ class DenseGeneral(Module): +@@ -99,6 +100,8 @@ class DenseGeneral(Module): bias_init: initializer function for the bias. precision: numerical precision of the computation see ``jax.lax.Precision`` for details. @@ -119,7 +118,7 @@ index 36365ea1..4656abf9 100644 """ features: Union[int, Sequence[int]] -@@ -111,6 +114,8 @@ class DenseGeneral(Module): +@@ -113,6 +116,8 @@ class DenseGeneral(Module): # Deprecated. Will be removed. dot_general: Optional[DotGeneralT] = None dot_general_cls: Any = None @@ -128,7 +127,7 @@ index 36365ea1..4656abf9 100644 @compact def __call__(self, inputs: Array) -> Array: -@@ -159,8 +164,9 @@ class DenseGeneral(Module): +@@ -161,8 +166,9 @@ class DenseGeneral(Module): if ax not in axis ) kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features @@ -140,7 +139,7 @@ index 36365ea1..4656abf9 100644 ) batch_ind = tuple(range(n_batch_dims)) -@@ -178,9 +184,11 @@ class DenseGeneral(Module): +@@ -180,9 +186,11 @@ class DenseGeneral(Module): return meta.replace_boxed(bias, jnp.reshape(bias.unbox(), shape)) return jnp.reshape(bias, shape) @@ -154,7 +153,7 @@ index 36365ea1..4656abf9 100644 else: bias = None -@@ -228,6 +236,8 @@ class Dense(Module): +@@ -230,6 +238,8 @@ class Dense(Module): for details. kernel_init: initializer function for the weight matrix. bias_init: initializer function for the bias. @@ -163,7 +162,7 @@ index 36365ea1..4656abf9 100644 """ features: int -@@ -240,6 +250,8 @@ class Dense(Module): +@@ -242,6 +252,8 @@ class Dense(Module): # Deprecated. Will be removed. 
dot_general: Optional[DotGeneralT] = None dot_general_cls: Any = None @@ -172,7 +171,7 @@ index 36365ea1..4656abf9 100644 @compact def __call__(self, inputs: Array) -> Array: -@@ -251,15 +263,18 @@ class Dense(Module): +@@ -253,15 +265,18 @@ class Dense(Module): Returns: The transformed input. """ @@ -194,7 +193,7 @@ index 36365ea1..4656abf9 100644 ) else: bias = None -@@ -351,6 +366,8 @@ class _Conv(Module): +@@ -474,6 +489,8 @@ class _Conv(Module): for details. kernel_init: initializer for the convolutional kernel. bias_init: initializer for the bias. @@ -203,7 +202,7 @@ index 36365ea1..4656abf9 100644 """ features: int -@@ -370,6 +387,8 @@ class _Conv(Module): +@@ -493,6 +510,8 @@ class _Conv(Module): # Deprecated. Will be removed. conv_general_dilated: Optional[ConvGeneralDilatedT] = None conv_general_dilated_cls: Any = None @@ -212,7 +211,7 @@ index 36365ea1..4656abf9 100644 @property def shared_weights(self) -> bool: # type: ignore -@@ -511,8 +530,10 @@ class _Conv(Module): +@@ -634,8 +653,10 @@ class _Conv(Module): f'Shapes are: {self.mask.shape}, {kernel_shape}' ) @@ -225,7 +224,7 @@ index 36365ea1..4656abf9 100644 ) if self.mask is not None: -@@ -526,7 +547,7 @@ class _Conv(Module): +@@ -649,7 +670,7 @@ class _Conv(Module): # One bias weight per output entry, unshared betwen pixels. bias_shape = conv_output_shape[1:] @@ -235,7 +234,7 @@ index 36365ea1..4656abf9 100644 bias = None diff --git a/flax/linen/normalization.py b/flax/linen/normalization.py -index abfbfb5a..bab40243 100644 +index 0680737f..ed241d08 100644 --- a/flax/linen/normalization.py +++ b/flax/linen/normalization.py @@ -24,6 +24,7 @@ from jax import lax @@ -246,7 +245,7 @@ index abfbfb5a..bab40243 100644 from flax.typing import ( Array, PRNGKey as PRNGKey, -@@ -154,6 +155,7 @@ def _normalize( +@@ -159,6 +160,7 @@ def _normalize( use_scale: bool, bias_init: Initializer, scale_init: Initializer, @@ -254,7 +253,7 @@ index abfbfb5a..bab40243 100644 ): """Normalizes the input of a normalization layer and optionally applies a learned scale and bias. -@@ -173,6 +175,7 @@ def _normalize( +@@ -178,6 +180,7 @@ def _normalize( use_scale: If true, scale the output. bias_init: Initialization function for the bias term. scale_init: Initialization function for the scaling function. @@ -262,7 +261,7 @@ index abfbfb5a..bab40243 100644 Returns: The normalized input. -@@ -191,15 +194,17 @@ def _normalize( +@@ -196,15 +199,17 @@ def _normalize( mul = lax.rsqrt(var + epsilon) args = [x] if use_scale: @@ -284,7 +283,7 @@ index abfbfb5a..bab40243 100644 ).reshape(feature_shape) y += bias args.append(bias) -@@ -283,6 +288,7 @@ class BatchNorm(Module): +@@ -289,6 +294,7 @@ class BatchNorm(Module): more details. use_fast_variance: If true, use a faster, but less numerically stable, calculation for the variance. @@ -292,15 +291,15 @@ index abfbfb5a..bab40243 100644 """ use_running_average: Optional[bool] = None -@@ -298,6 +304,7 @@ class BatchNorm(Module): - axis_name: Optional[str] = None +@@ -305,6 +311,7 @@ class BatchNorm(Module): axis_index_groups: Any = None use_fast_variance: bool = True + force_float32_reductions: bool = True + pjit_axis_name: Tuple[str, ...] = None @compact def __call__( -@@ -377,6 +384,7 @@ class BatchNorm(Module): +@@ -385,6 +392,7 @@ class BatchNorm(Module): self.use_scale, self.bias_init, self.scale_init, @@ -308,7 +307,7 @@ index abfbfb5a..bab40243 100644 ) -@@ -439,6 +447,7 @@ class LayerNorm(Module): +@@ -448,6 +456,7 @@ class LayerNorm(Module): more details. 
use_fast_variance: If true, use a faster, but less numerically stable, calculation for the variance. @@ -316,15 +315,15 @@ index abfbfb5a..bab40243 100644 """ epsilon: float = 1e-6 -@@ -453,6 +462,7 @@ class LayerNorm(Module): - axis_name: Optional[str] = None +@@ -463,6 +472,7 @@ class LayerNorm(Module): axis_index_groups: Any = None use_fast_variance: bool = True + force_float32_reductions: bool = True + pjit_axis_name: Tuple[str, ...] = None @compact def __call__(self, x, *, mask: Optional[jax.Array] = None): -@@ -490,6 +500,7 @@ class LayerNorm(Module): +@@ -501,6 +511,7 @@ class LayerNorm(Module): self.use_scale, self.bias_init, self.scale_init, @@ -332,7 +331,7 @@ index abfbfb5a..bab40243 100644 ) -@@ -538,6 +549,7 @@ class RMSNorm(Module): +@@ -549,6 +560,7 @@ class RMSNorm(Module): more details. use_fast_variance: If true, use a faster, but less numerically stable, calculation for the variance. @@ -340,15 +339,15 @@ index abfbfb5a..bab40243 100644 """ epsilon: float = 1e-6 -@@ -550,6 +562,7 @@ class RMSNorm(Module): - axis_name: Optional[str] = None +@@ -562,6 +574,7 @@ class RMSNorm(Module): axis_index_groups: Any = None use_fast_variance: bool = True + force_float32_reductions: bool = True + pjit_axis_name: Tuple[str, ...] = None @compact def __call__(self, x, *, mask: Optional[jax.Array] = None): -@@ -588,6 +601,7 @@ class RMSNorm(Module): +@@ -601,6 +614,7 @@ class RMSNorm(Module): self.use_scale, initializers.zeros, self.scale_init, @@ -356,7 +355,7 @@ index abfbfb5a..bab40243 100644 ) -@@ -657,6 +671,7 @@ class GroupNorm(Module): +@@ -671,6 +685,7 @@ class GroupNorm(Module): more details. use_fast_variance: If true, use a faster, but less numerically stable, calculation for the variance. @@ -364,15 +363,15 @@ index abfbfb5a..bab40243 100644 """ num_groups: Optional[int] = 32 -@@ -672,6 +687,7 @@ class GroupNorm(Module): - axis_name: Optional[str] = None +@@ -687,6 +702,7 @@ class GroupNorm(Module): axis_index_groups: Any = None use_fast_variance: bool = True + force_float32_reductions: bool = True + pjit_axis_name: Tuple[str, ...] = None @compact def __call__(self, x, *, mask: Optional[jax.Array] = None): -@@ -885,6 +901,7 @@ class InstanceNorm(Module): +@@ -904,6 +920,7 @@ class InstanceNorm(Module): self.use_scale, self.bias_init, self.scale_init, @@ -381,10 +380,10 @@ index abfbfb5a..bab40243 100644 -- -2.25.1 +2.34.1 -From c945c2ff513282b4af2e956c9c09c784e6d48c44 Mon Sep 17 00:00:00 2001 +From bb04792f9d863f3d2f20e3635b28863f286a740f Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Mon, 2 Oct 2023 16:10:05 -0700 Subject: [PATCH 2/3] Added ConvTranspose sharding annotations (#3) @@ -395,10 +394,10 @@ Co-authored-by: sahilj 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/flax/linen/linear.py b/flax/linen/linear.py -index 4656abf9..187ab6f5 100644 +index 27e22325..d964eefb 100644 --- a/flax/linen/linear.py +++ b/flax/linen/linear.py -@@ -796,6 +796,21 @@ class ConvTranspose(Module): +@@ -922,6 +922,21 @@ class ConvTranspose(Module): bias_init: Initializer = initializers.zeros_init() transpose_kernel: bool = False @@ -420,7 +419,7 @@ index 4656abf9..187ab6f5 100644 @compact def __call__(self, inputs: Array) -> Array: """Applies a transposed convolution to the inputs. 
-@@ -852,8 +867,9 @@ class ConvTranspose(Module): +@@ -986,8 +1001,9 @@ class ConvTranspose(Module): f'Shapes are: {self.mask.shape}, {kernel_shape}' ) @@ -432,7 +431,7 @@ index 4656abf9..187ab6f5 100644 ) if self.mask is not None: -@@ -864,8 +880,8 @@ class ConvTranspose(Module): +@@ -998,8 +1014,8 @@ class ConvTranspose(Module): padding_lax = 'VALID' if self.use_bias: @@ -444,10 +443,10 @@ index 4656abf9..187ab6f5 100644 else: bias = None -- -2.25.1 +2.34.1 -From 8b184f603e31feabb7580f1a969e101a7fe9e992 Mon Sep 17 00:00:00 2001 +From f3de082c72dc99b20f04ece2484ad2d6f5cceb28 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 1 Feb 2024 09:54:25 -0800 Subject: [PATCH 3/3] Add missing import @@ -459,7 +458,7 @@ Subject: [PATCH 3/3] Add missing import 3 files changed, 3 insertions(+) diff --git a/flax/linen/attention.py b/flax/linen/attention.py -index 689ce4da..b19d795e 100644 +index bcabd554..b594add9 100644 --- a/flax/linen/attention.py +++ b/flax/linen/attention.py @@ -39,6 +39,7 @@ from flax.typing import ( @@ -471,10 +470,10 @@ index 689ce4da..b19d795e 100644 def dot_product_attention_weights( diff --git a/flax/linen/linear.py b/flax/linen/linear.py -index 187ab6f5..759406ed 100644 +index d964eefb..e1c2751e 100644 --- a/flax/linen/linear.py +++ b/flax/linen/linear.py -@@ -47,6 +47,7 @@ from flax.typing import ( +@@ -49,6 +49,7 @@ from flax.typing import ( ConvGeneralDilatedT, PaddingLike, LaxPadding, @@ -483,7 +482,7 @@ index 187ab6f5..759406ed 100644 diff --git a/flax/linen/normalization.py b/flax/linen/normalization.py -index bab40243..1e1169a0 100644 +index ed241d08..30a5b0fe 100644 --- a/flax/linen/normalization.py +++ b/flax/linen/normalization.py @@ -32,6 +32,7 @@ from flax.typing import ( @@ -495,5 +494,5 @@ index bab40243..1e1169a0 100644 field = dataclasses.field -- -2.25.1 +2.34.1 diff --git a/.github/container/patches/paxml/PR-46.patch b/.github/container/patches/paxml/PR-46.patch index 2b3abe03d..06a607cc4 100644 --- a/.github/container/patches/paxml/PR-46.patch +++ b/.github/container/patches/paxml/PR-46.patch @@ -1,7 +1,7 @@ -From 37461f7b414c3c40c8730b5c2c9318329b8bc2d6 Mon Sep 17 00:00:00 2001 +From b2de6bbc35272ba74fddc30676835026e9c0e8c2 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Tue, 18 Jul 2023 10:27:03 -0700 -Subject: [PATCH 1/9] add TE support +Subject: [PATCH 01/10] add TE support --- paxml/contrib/gpu/scripts_gpu/configs.py | 22 +- @@ -13,18 +13,18 @@ Subject: [PATCH 1/9] add TE support create mode 100644 paxml/contrib/gpu/scripts_gpu/te_helper.py diff --git a/paxml/contrib/gpu/scripts_gpu/configs.py b/paxml/contrib/gpu/scripts_gpu/configs.py -index 2aceb30..934530e 100644 +index cb5ad52..16d4125 100644 --- a/paxml/contrib/gpu/scripts_gpu/configs.py +++ b/paxml/contrib/gpu/scripts_gpu/configs.py -@@ -23,6 +23,7 @@ from paxml.contrib.gpu.scripts_gpu.llama_utils import BaseLLaMA - from paxml.contrib.gpu.scripts_gpu.tasks import BoolQDataset +@@ -28,6 +28,7 @@ from paxml.contrib.gpu.scripts_gpu.tasks import BoolQDataset from paxml.contrib.gpu.scripts_gpu.tasks import LambadaDataset from paxml.contrib.gpu.scripts_gpu.tasks import PileUnsupervisedDataset + from paxml.tasks.lm.model_params import maybe_setup_moe_params +from paxml.contrib.gpu.scripts_gpu.te_helper import TransformerEngineHelper from paxml.tasks.lm.params.c4 import TransformerLmSpmdAdam from paxml.tasks.lm.params.lm_cloud import SyntheticDataset from praxis import base_layer -@@ -111,7 +112,7 @@ class GPT126MBase(TransformerLmSpmdAdam): +@@ -116,7 +117,7 @@ class GPT126MBase(TransformerLmSpmdAdam): 
MAX_SEQ_LEN = 2048 VOCAB_SIZE = 50304 @@ -33,7 +33,7 @@ index 2aceb30..934530e 100644 PERCORE_BATCH_SIZE = 4 NUM_LAYERS = 12 -@@ -166,10 +167,21 @@ class GPT126MBase(TransformerLmSpmdAdam): +@@ -171,10 +172,21 @@ class GPT126MBase(TransformerLmSpmdAdam): fdl.get_callable(stacked_p), transformers.StackedTransformerRepeated ): stacked_p = stacked_p.block @@ -58,7 +58,7 @@ index 2aceb30..934530e 100644 model_p.params_init = WeightInit.Gaussian(self.INIT_STD) softmax_init = WeightInit.Gaussian(self.SOFTMAX_INIT_STD) -@@ -234,7 +246,7 @@ class GPT175BBase(GPT126MBase): +@@ -239,7 +251,7 @@ class GPT175BBase(GPT126MBase): # Known as MLP_DIM in t5x HIDDEN_DIMS = MODEL_DIMS * 4 # Defaults to MODEL_DIMS // NUM_HEADS. @@ -398,10 +398,10 @@ index 0000000..d44ca67 + finally: + pass diff --git a/paxml/main.py b/paxml/main.py -index d04332a..fd5228e 100644 +index ff3d89d..c6368d4 100644 --- a/paxml/main.py +++ b/paxml/main.py -@@ -50,6 +50,7 @@ from paxml import tf_data_service_lib +@@ -51,6 +51,7 @@ from paxml import tf_data_service_lib from paxml import train from paxml import trainer_lib from paxml import tuning_lib @@ -409,7 +409,7 @@ index d04332a..fd5228e 100644 from praxis import pax_fiddle from praxis import py_utils -@@ -510,39 +511,41 @@ def _main(argv: Sequence[str]) -> None: +@@ -512,39 +513,41 @@ def _main(argv: Sequence[str]) -> None: FLAGS.host_idx) ) @@ -654,13 +654,13 @@ index 6ec25e8..0342328 100644 train_state_partition_specs = ( -- -2.25.1 +2.34.1 -From 371b48043de072908aca80ba8b16f34008bd875c Mon Sep 17 00:00:00 2001 +From 33989ab49072915eaf70a127357ad87a2ab0185f Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Wed, 27 Sep 2023 10:46:53 +0800 -Subject: [PATCH 2/9] Adding dropout support when enabling TE. +Subject: [PATCH 02/10] Adding dropout support when enabling TE. --- paxml/contrib/gpu/scripts_gpu/te_helper.py | 10 ++++++++++ @@ -688,13 +688,13 @@ index d44ca67..2b9dba4 100644 assert self.packed_input == False assert len(self.moe_layers) == 0 -- -2.25.1 +2.34.1 -From 272e6352128962a9b2da10133737f3e1343bd36c Mon Sep 17 00:00:00 2001 +From 1993f5edd38a6523b12e56e533d22284a17ecc5a Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Tue, 24 Oct 2023 10:30:27 +0800 -Subject: [PATCH 3/9] Set deterministic=True for inference. +Subject: [PATCH 03/10] Set deterministic=True for inference. --- paxml/contrib/gpu/scripts_gpu/te_helper.py | 3 ++- @@ -715,13 +715,13 @@ index 2b9dba4..ef20305 100644 return x_out -- -2.25.1 +2.34.1 -From dfbf3a90cc0d93aa7d8e9c55c95ccf98c67f70bb Mon Sep 17 00:00:00 2001 +From 7edb0214176ec4169f2aa8b9aef2c89f41f1d0c4 Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Thu, 2 Nov 2023 22:04:58 -0700 -Subject: [PATCH 4/9] Fix the excluded list for excluded_for_learner +Subject: [PATCH 04/10] Fix the excluded list for excluded_for_learner Signed-off-by: Reese Wang --- @@ -742,13 +742,13 @@ index 0342328..2e9bfd6 100644 vars_with_opt = tasks_lib.filter_vars_for_grad_or_opt( mdl_vars, excluded_for_learner -- -2.25.1 +2.34.1 -From 041456a8d4eb39350349101e263cb09f80b2b88c Mon Sep 17 00:00:00 2001 +From b76901e24b0f4bf1add943def4b5b590253527d7 Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Tue, 7 Nov 2023 11:21:53 +0800 -Subject: [PATCH 5/9] Adapting to TE/JAX/Custom_partitioning. +Subject: [PATCH 05/10] Adapting to TE/JAX/Custom_partitioning. 
--- paxml/contrib/gpu/scripts_gpu/te_helper.py | 6 ++++-- @@ -779,13 +779,13 @@ index ef20305..fed1601 100644 finally: pass -- -2.25.1 +2.34.1 -From 7d976d6510d8d5f751fd566ed2703bfa2d0a89d0 Mon Sep 17 00:00:00 2001 +From 3d6ff0f9bb38217a7b04f48c36ee6892fa3c5b26 Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Tue, 7 Nov 2023 15:14:25 +0800 -Subject: [PATCH 6/9] Adding TE-compatiable PipelinedTransformer +Subject: [PATCH 06/10] Adding TE-compatiable PipelinedTransformer --- paxml/contrib/gpu/scripts_gpu/te_helper.py | 109 +++++++++++++++++++++ @@ -946,13 +946,13 @@ index fed1601..5914e54 100644 def update_fp8_metas_if_needed(mdl_vars, grads): return TransformerEngineHelper.get_helper().update_fp8_metas_if_needed(mdl_vars, grads) -- -2.25.1 +2.34.1 -From a1bb3c7d24817e1a77219f1cdfdb70b34157fda2 Mon Sep 17 00:00:00 2001 +From b289c455efd107076001029f78e1215409dcbeda Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Wed, 8 Nov 2023 10:06:49 +0800 -Subject: [PATCH 7/9] Apply OWG to TE's FP8 meta +Subject: [PATCH 07/10] Apply OWG to TE's FP8 meta --- paxml/contrib/gpu/scripts_gpu/te_helper.py | 59 ---------------------- @@ -1107,13 +1107,13 @@ index 2e9bfd6..270fb3d 100644 grads, states.opt_states[0], vars_with_opt, wps_with_opt ) -- -2.25.1 +2.34.1 -From 0d07668f96ea4e106388fbfdc47ae228918ec135 Mon Sep 17 00:00:00 2001 +From c91d86651daec782c4092898c7b9d01b45b76b56 Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Wed, 15 Nov 2023 14:43:17 +0800 -Subject: [PATCH 8/9] Remove Praxis related setup (Moving to Praxis TE/Patch) +Subject: [PATCH 08/10] Remove Praxis related setup (Moving to Praxis TE/Patch) --- paxml/contrib/gpu/scripts_gpu/configs.py | 9 - @@ -1121,10 +1121,10 @@ Subject: [PATCH 8/9] Remove Praxis related setup (Moving to Praxis TE/Patch) 2 files changed, 324 deletions(-) diff --git a/paxml/contrib/gpu/scripts_gpu/configs.py b/paxml/contrib/gpu/scripts_gpu/configs.py -index 934530e..fd01c74 100644 +index 16d4125..7d50a52 100644 --- a/paxml/contrib/gpu/scripts_gpu/configs.py +++ b/paxml/contrib/gpu/scripts_gpu/configs.py -@@ -173,15 +173,6 @@ class GPT126MBase(TransformerLmSpmdAdam): +@@ -178,15 +178,6 @@ class GPT126MBase(TransformerLmSpmdAdam): transformer_layer_p = stacked_p.transformer_layer_params_tpl transformer_layer_p.ln_tpl.reductions_in_fp32 = True transformer_layer_p.tr_fflayer_tpl.ln_tpl.reductions_in_fp32 = True @@ -1499,13 +1499,13 @@ index fd482df..b271258 100644 @contextmanager def fp8_autocast(dp_mesh_axis="replica", tp_mesh_axis="mdl", fsdp_mesh_axis="data"): -- -2.25.1 +2.34.1 -From f94e783f56454d758b8cbab81f5ce756835fd065 Mon Sep 17 00:00:00 2001 +From b6add3ccf50bffcfa87052830f1dbc52bdccf525 Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Wed, 15 Nov 2023 14:51:14 +0800 -Subject: [PATCH 9/9] Fix missing DEFAULT_INIT_MUTABLE_LIST +Subject: [PATCH 09/10] Fix missing DEFAULT_INIT_MUTABLE_LIST --- paxml/contrib/gpu/scripts_gpu/te_helper.py | 4 ++++ @@ -1534,5 +1534,32 @@ index b271258..cbac7cf 100644 class TransformerEngineHelperBase: -- -2.25.1 +2.34.1 + + +From 12adf91f75fcbc5055ddaf93a25ed43ee99918ac Mon Sep 17 00:00:00 2001 +From: Hemil Desai +Date: Mon, 12 Feb 2024 10:22:15 -0800 +Subject: [PATCH 10/10] Revert mutable kwarg in abstract_init_with_metadata in + init checkpoint rule + +--- + paxml/tasks_lib.py | 2 +- + 1 file changed, 1 insertion(+), 1 deletion(-) + +diff --git a/paxml/tasks_lib.py b/paxml/tasks_lib.py +index e475099..43e090c 100644 +--- a/paxml/tasks_lib.py ++++ b/paxml/tasks_lib.py +@@ -1787,7 +1787,7 @@ class 
SingleTask(base_task.BaseTask): + ) + # Initialize with a dummy seed + var_weight_hparams = ckpt_task.model.abstract_init_with_metadata( +- inputs_shape_dtype, mutable=DEFAULT_INIT_MUTABLE_LIST) ++ inputs_shape_dtype, extra_mutable_list=DEFAULT_INIT_MUTABLE_LIST) + ckpt_train_state = ckpt_task.create_train_state_padded_shapes( + var_weight_hparams) + train_state_pspecs = ckpt_task.create_train_state_partition_specs( +-- +2.34.1 diff --git a/.github/container/patches/praxis/PR-27.patch b/.github/container/patches/praxis/PR-27.patch index 516d2fe74..997d17d2e 100644 --- a/.github/container/patches/praxis/PR-27.patch +++ b/.github/container/patches/praxis/PR-27.patch @@ -34,5 +34,5 @@ index a35ce8b..52886bc 100644 self.add_summary('attention_mask', atten_mask) if self.attention_extra_logit is None: -- -2.25.1 +2.34.1 diff --git a/.github/container/patches/praxis/PR-36.patch b/.github/container/patches/praxis/PR-36.patch index 5b298b11a..134d1596e 100644 --- a/.github/container/patches/praxis/PR-36.patch +++ b/.github/container/patches/praxis/PR-36.patch @@ -1,7 +1,7 @@ From 41488517eb6d95eb7943681e706c8804e6102c41 Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Wed, 15 Nov 2023 11:38:27 +0800 -Subject: [PATCH 01/10] Adding TE support +Subject: [PATCH 01/15] Adding TE support --- praxis/contrib/gpu/scripts_gpu/te_helper.py | 176 ++++++++++++++++++++ @@ -247,13 +247,13 @@ index ab6cff3..c79dac9 100644 # Annotate the inputs before the pipeline to prevent unexpected # propagation from earlier layers. -- -2.25.1 +2.34.1 From ff1745796009cf1ec59f463f8e776c66f1286938 Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Fri, 17 Nov 2023 15:21:06 +0800 -Subject: [PATCH 02/10] Fix missing vars wiht PP. +Subject: [PATCH 02/15] Fix missing vars wiht PP. --- praxis/contrib/gpu/scripts_gpu/te_helper.py | 34 ++++++++++++--------- @@ -358,13 +358,13 @@ index e3b2f7c..b31526e 100644 trans_in_fn=_get_to_f32_converter(bf16_vars_to_convert), trans_out_fn=_get_to_bf16_converter(bf16_vars_to_convert), -- -2.25.1 +2.34.1 From 99e26aaf14131ca73501f08162be628b55c86a88 Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Wed, 17 Jan 2024 02:36:31 -0800 -Subject: [PATCH 03/10] Add checkpoint_policy checker for fused attn + dropout +Subject: [PATCH 03/15] Add checkpoint_policy checker for fused attn + dropout Signed-off-by: Reese Wang --- @@ -476,13 +476,13 @@ index c79dac9..e076530 100644 repeats.Repeat, sub_tpl=self.block, -- -2.25.1 +2.34.1 From ab12f857404d84ed423e095d59e0bd336b94f151 Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Mon, 15 Jan 2024 08:25:59 -0800 -Subject: [PATCH 04/10] Support more TE configurations +Subject: [PATCH 04/15] Support more TE configurations Signed-off-by: Reese Wang --- @@ -605,13 +605,13 @@ index 290b74c..9defcbd 100644 te_transformer_tpl.attention_dropout = stacked_transformer_obj.atten_dropout_prob or stacked_transformer_obj.dropout_prob te_transformer_tpl.hidden_dropout = stacked_transformer_obj.residual_dropout_prob or stacked_transformer_obj.dropout_prob -- -2.25.1 +2.34.1 From dc632f0b2bf8da8724e4959360da034b3a7b4075 Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Sun, 18 Feb 2024 01:14:41 -0800 -Subject: [PATCH 05/10] Change the gated activations orders +Subject: [PATCH 05/15] Change the gated activations orders Signed-off-by: Reese Wang --- @@ -643,13 +643,13 @@ index 9defcbd..1f8f6d6 100644 return te_tpl -- -2.25.1 +2.34.1 From f2e8560e6861a6dea981209b402b27dc6bc92022 Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Wed, 28 Feb 2024 22:37:46 -0800 -Subject: [PATCH 06/10] Remove 
RoPE restriction from DPA module +Subject: [PATCH 06/15] Remove RoPE restriction from DPA module Signed-off-by: Reese Wang --- @@ -677,13 +677,13 @@ index 1f8f6d6..7d83c08 100644 assert attn_tpl.attention_extra_logit is None assert attn_tpl.ngrammer_tpl is None -- -2.25.1 +2.34.1 From 3d6b5c34a64939dadcc751cf2e989cec1affc648 Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Wed, 28 Feb 2024 22:56:44 -0800 -Subject: [PATCH 07/10] Add rotary_pos_emb_group dispatch +Subject: [PATCH 07/15] Add rotary_pos_emb_group dispatch Signed-off-by: Reese Wang --- @@ -721,13 +721,13 @@ index 7d83c08..d187ba1 100644 raise ValueError(f'Unsupported {attn_tpl.cls=}') assert attn_tpl.atten_logit_cap <= 0., 'atten_logit_cap > 0. is not supported in TE' -- -2.25.1 +2.34.1 From 454f760a095562d995e7e9102f97c05158415312 Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Sun, 3 Mar 2024 01:16:46 -0800 -Subject: [PATCH 08/10] Fix the missing .cls +Subject: [PATCH 08/15] Fix the missing .cls Signed-off-by: Reese Wang --- @@ -757,13 +757,13 @@ index d187ba1..733e9bf 100644 else: raise ValueError(f'Unsupported {attn_tpl.cls=}') -- -2.25.1 +2.34.1 From 0bd4a531a3d8e4ddafb5ed680092fd5636aa9f8e Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Fri, 2 Feb 2024 10:58:40 +0800 -Subject: [PATCH 09/10] Fixed the unexpected input sharding pattern when TE +Subject: [PATCH 09/15] Fixed the unexpected input sharding pattern when TE enabled. Signed-off-by: Ming-Xu Huang @@ -872,13 +872,13 @@ index d8720ae..235d63e 100644 [data_axis, None, None, mdl_axis] if training_optimized -- -2.25.1 +2.34.1 From e3e785cfedb4f350cfcb2c43b093f94288dbc846 Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Wed, 7 Feb 2024 12:34:01 +0800 -Subject: [PATCH 10/10] Addind a comment to get_input_bld +Subject: [PATCH 10/15] Addind a comment to get_input_bld Signed-off-by: Ming-Xu Huang --- @@ -907,5 +907,318 @@ index ee0fc84..d0afc1a 100644 @staticmethod -- -2.25.1 +2.34.1 + + +From c45ae5b62a72cfce80f5c79b9f6f22ffb0dbfad1 Mon Sep 17 00:00:00 2001 +From: Ming Huang +Date: Thu, 14 Mar 2024 20:33:36 -0700 +Subject: [PATCH 11/15] Leverage the origianl input bld when ENABLE_SP=False + +Signed-off-by: Ming Huang +--- + praxis/contrib/gpu/scripts_gpu/te_helper.py | 4 ++-- + 1 file changed, 2 insertions(+), 2 deletions(-) + +diff --git a/praxis/contrib/gpu/scripts_gpu/te_helper.py b/praxis/contrib/gpu/scripts_gpu/te_helper.py +index d0afc1a..fa0375f 100644 +--- a/praxis/contrib/gpu/scripts_gpu/te_helper.py ++++ b/praxis/contrib/gpu/scripts_gpu/te_helper.py +@@ -236,10 +236,10 @@ class TEInstalledHelper(TransformerEngineHelperBase): + return te_transformer_tpl + + @staticmethod +- def get_input_bld(_, batch_axes, mdl_axis): ++ def get_input_bld(original_bld, batch_axes, mdl_axis): + if ENABLE_TE_SP: + return [batch_axes, mdl_axis, None] +- return [batch_axes, None, None] ++ return original_bld + + @staticmethod + def get_bld_mapping_for_pipelined_transformer(_): +-- +2.34.1 + + +From 6a8bfddccb1916389b9a3206e6b3f9e142c30792 Mon Sep 17 00:00:00 2001 +From: Reese Wang +Date: Mon, 11 Mar 2024 07:20:08 -0700 +Subject: [PATCH 12/15] Use causal_padding instead of padding + +Signed-off-by: Reese Wang +--- + praxis/contrib/gpu/scripts_gpu/te_helper.py | 2 +- + 1 file changed, 1 insertion(+), 1 deletion(-) + +diff --git a/praxis/contrib/gpu/scripts_gpu/te_helper.py b/praxis/contrib/gpu/scripts_gpu/te_helper.py +index fa0375f..e0f9fd1 100644 +--- a/praxis/contrib/gpu/scripts_gpu/te_helper.py ++++ b/praxis/contrib/gpu/scripts_gpu/te_helper.py +@@ -208,7 +208,7 @@ class 
TEInstalledHelper(TransformerEngineHelperBase): + assert (transformer_layer_tpl.tr_fflayer_tpl.has_bias == + transformer_layer_tpl.tr_atten_tpl.use_bias), "TE only allows same bias settings." + te_transformer_tpl.use_bias = transformer_layer_tpl.tr_fflayer_tpl.has_bias +- te_transformer_tpl.self_attn_mask_type = 'causal' \ ++ te_transformer_tpl.self_attn_mask_type = 'padding_causal' \ + if stacked_transformer_obj.mask_self_attention else 'padding' + + te_transformer_tpl.logical_axes_rules = te_flax.extend_logical_axis_rules(tuple()) +-- +2.34.1 + + +From 9c26058d7e285700a314c1306a9567ec497a9603 Mon Sep 17 00:00:00 2001 +From: Reese Wang +Date: Mon, 1 Apr 2024 01:07:12 -0700 +Subject: [PATCH 13/15] Fix rope convert + +Signed-off-by: Reese Wang +--- + praxis/contrib/gpu/scripts_gpu/te_helper.py | 4 ++-- + 1 file changed, 2 insertions(+), 2 deletions(-) + +diff --git a/praxis/contrib/gpu/scripts_gpu/te_helper.py b/praxis/contrib/gpu/scripts_gpu/te_helper.py +index e0f9fd1..df6b9b8 100644 +--- a/praxis/contrib/gpu/scripts_gpu/te_helper.py ++++ b/praxis/contrib/gpu/scripts_gpu/te_helper.py +@@ -177,7 +177,7 @@ class TEInstalledHelper(TransformerEngineHelperBase): + assert attn_tpl.attention_extra_logit is None + assert attn_tpl.ngrammer_tpl is None + te_tpl.enable_rotary_pos_emb = attn_tpl.use_rotary_position_emb +- if issubclass(attn_tpl.rotary_position_emb_tpl.cls, embedding_softmax.RotaryPositionalEmbedding): ++ if attn_tpl.rotary_position_emb_tpl.cls == embedding_softmax.RotaryPositionalEmbedding: + te_tpl.rotary_pos_emb_group_method = 'alternate' + elif issubclass(attn_tpl.cls, grouped_query_attention.GroupedQueryAttention): + te_tpl.num_gqa_groups = attn_tpl.num_kv_heads +@@ -188,7 +188,7 @@ class TEInstalledHelper(TransformerEngineHelperBase): + elif issubclass(attn_tpl.cls, multi_query_attention.MultiQueryDotProductAttention): + te_tpl.num_gqa_groups = attn_tpl.num_kv_heads + te_tpl.enable_rotary_pos_emb = attn_tpl.use_rotary_position_emb +- if issubclass(attn_tpl.rotary_position_emb_tpl.cls, embedding_softmax.RotaryPositionalEmbedding): ++ if attn_tpl.rotary_position_emb_tpl.cls == embedding_softmax.RotaryPositionalEmbedding: + te_tpl.rotary_pos_emb_group_method = 'alternate' + else: + raise ValueError(f'Unsupported {attn_tpl.cls=}') +-- +2.34.1 + + +From bc21964acbddf8e6a9c5e1c09a17cbcdf24b6a50 Mon Sep 17 00:00:00 2001 +From: Reese Wang +Date: Mon, 1 Apr 2024 01:08:52 -0700 +Subject: [PATCH 14/15] Add test_te_helper.py + +Signed-off-by: Reese Wang +--- + .../contrib/gpu/scripts_gpu/test_te_helper.py | 86 +++++++++++++++++++ + 1 file changed, 86 insertions(+) + create mode 100644 praxis/contrib/gpu/scripts_gpu/test_te_helper.py + +diff --git a/praxis/contrib/gpu/scripts_gpu/test_te_helper.py b/praxis/contrib/gpu/scripts_gpu/test_te_helper.py +new file mode 100644 +index 0000000..c65a25d +--- /dev/null ++++ b/praxis/contrib/gpu/scripts_gpu/test_te_helper.py +@@ -0,0 +1,86 @@ ++from praxis import base_hyperparams ++from praxis import layers ++from praxis import pax_fiddle ++from praxis.contrib.gpu.scripts_gpu.te_helper import TransformerEngineHelper ++from paxml.contrib.gpu.scripts_gpu.llama_utils import BaseLLaMA ++from paxml.contrib.gpu.scripts_gpu.configs import Synthetic5B ++from paxml.tasks.lm.params.lm_cloud import SyntheticDataset ++ ++import transformer_engine.jax.praxis as te_praxis ++ ++ ++class SyntheticLLaMA7B(BaseLLaMA, SyntheticDataset): ++ pass ++ ++ ++class TestGPT5B(): ++ ++ def test_te_tpl_convert(self): ++ task = Synthetic5B().task() ++ st_tpl = 
task.model.lm_tpl.stacked_transformer_tpl.block ++ te_tpl = TransformerEngineHelper().set_layer_params_to_stack_transformer(st_tpl, None, 0) ++ te_cls = base_hyperparams.instantiate(te_tpl) ++ assert te_cls.hidden_size == st_tpl.model_dims ++ assert te_cls.mlp_hidden_size == st_tpl.hidden_dims ++ assert te_cls.num_attention_heads == st_tpl.num_heads ++ assert te_cls.num_gqa_groups == te_cls.num_attention_heads ++ assert te_cls.layernorm_type == 'layernorm' ++ assert te_cls.layernorm_epsilon == 1e-5 ++ assert te_cls.zero_centered_gamma == True ++ assert te_cls.hidden_dropout == 0. ++ assert te_cls.hidden_dropout_dims == () ++ assert te_cls.attention_dropout == 0. ++ assert te_cls.intermediate_dropout == 0. ++ assert te_cls.intermediate_dropout_dims == () ++ assert te_cls.mlp_activations == ('gelu',) ++ assert te_cls.use_bias == True ++ assert te_cls.apply_residual_connection_post_layernorm == False ++ assert te_cls.output_layernorm == False ++ assert te_cls.float32_attention_logits == False ++ assert te_cls.layer_type == te_praxis.TransformerLayerType.ENCODER ++ assert te_cls.self_attn_mask_type == 'padding_causal' ++ assert te_cls.self_attn_bias_type == None ++ assert te_cls.enable_rotary_pos_emb == False ++ assert te_cls.rotary_pos_emb_windows == (1, 10000) ++ assert te_cls.enable_relative_embedding == False ++ assert te_cls.drop_path == 0. ++ assert te_cls.transpose_batch_sequence == False ++ assert te_cls.scale_attn_logits == True ++ assert te_cls.scaled_query_init == False ++ ++ ++class TestLLaMA7B(): ++ ++ def test_te_tpl_convert(self): ++ task = SyntheticLLaMA7B().task() ++ st_tpl = task.model.lm_tpl.stacked_transformer_tpl ++ te_tpl = TransformerEngineHelper().set_layer_params_to_stack_transformer(st_tpl, None, 0) ++ te_cls = base_hyperparams.instantiate(te_tpl) ++ assert te_cls.hidden_size == 4096 ++ assert te_cls.mlp_hidden_size == 16384 ++ assert te_cls.num_attention_heads == 32 ++ assert te_cls.num_gqa_groups == 32 ++ assert te_cls.layernorm_type == 'rmsnorm' ++ assert te_cls.layernorm_epsilon == 1e-5 ++ assert te_cls.zero_centered_gamma == False ++ assert te_cls.hidden_dropout == 0. ++ assert te_cls.hidden_dropout_dims == () ++ assert te_cls.attention_dropout == 0. ++ assert te_cls.intermediate_dropout == 0. ++ assert te_cls.intermediate_dropout_dims == () ++ assert te_cls.mlp_activations == ('linear', 'silu') ++ assert te_cls.use_bias == False ++ assert te_cls.apply_residual_connection_post_layernorm == False ++ assert te_cls.output_layernorm == False ++ assert te_cls.float32_attention_logits == False ++ assert te_cls.layer_type == te_praxis.TransformerLayerType.ENCODER ++ assert te_cls.self_attn_mask_type == 'padding_causal' ++ assert te_cls.self_attn_bias_type == None ++ assert te_cls.enable_rotary_pos_emb == True ++ assert te_cls.rotary_pos_emb_windows == (1, 10000) ++ assert te_cls.rotary_pos_emb_group_method == 'consecutive' ++ assert te_cls.enable_relative_embedding == False ++ assert te_cls.drop_path == 0. 
++ assert te_cls.transpose_batch_sequence == False ++ assert te_cls.scale_attn_logits == True ++ assert te_cls.scaled_query_init == False +-- +2.34.1 + + +From 4c2f09318b642918adc655bdeb348550a41e2ae3 Mon Sep 17 00:00:00 2001 +From: Hemil Desai +Date: Thu, 23 May 2024 12:15:21 -0700 +Subject: [PATCH 15/15] Add LoRA support in TE (#11) + +--- + praxis/contrib/gpu/scripts_gpu/te_helper.py | 75 +++++++++++++++++++++ + 1 file changed, 75 insertions(+) + +diff --git a/praxis/contrib/gpu/scripts_gpu/te_helper.py b/praxis/contrib/gpu/scripts_gpu/te_helper.py +index df6b9b8..d53b4e2 100644 +--- a/praxis/contrib/gpu/scripts_gpu/te_helper.py ++++ b/praxis/contrib/gpu/scripts_gpu/te_helper.py +@@ -1,5 +1,7 @@ + import os + ++from absl import logging ++ + from praxis import base_layer + from praxis import pax_fiddle + from praxis import pytypes +@@ -9,6 +11,11 @@ from praxis.layers import activations + from praxis.layers import attentions, grouped_query_attention, multi_query_attention + from praxis.layers import embedding_softmax + from praxis.layers import normalizations ++from praxis.contrib.gpu.scripts_gpu.lora_layers import ( ++ LoraAttentionProjection, ++ LoraCombinedQKVProjection, ++ LoraLinear, ++) + + try: + import transformer_engine.jax as te +@@ -233,6 +240,74 @@ class TEInstalledHelper(TransformerEngineHelperBase): + assert len(stacked_transformer_obj.moe_layers) == 0 + assert stacked_transformer_obj.ngrammer_tpls is None + ++ def update_lora_te_tpl(te_tpl, transformer_layer_tpl): ++ lora_enabled = False ++ te_lora_scope = "none" ++ lora_rank = None ++ if ( ++ transformer_layer_tpl.tr_fflayer_tpl.fflayer_tpl.linear_tpl.__fn_or_cls__ ++ is LoraLinear ++ ): ++ lora_enabled = True ++ mlp_included_in_lora = True ++ current_rank = ( ++ transformer_layer_tpl.tr_fflayer_tpl.fflayer_tpl.linear_tpl.rank ++ ) ++ lora_rank = ( ++ current_rank if lora_rank is None else lora_rank & current_rank ++ ) ++ ++ attention_included_in_lora = False ++ if ( ++ hasattr(transformer_layer_tpl.tr_atten_tpl, "combined_qkv_proj_tpl") ++ and transformer_layer_tpl.tr_atten_tpl.combined_qkv_proj_tpl.__fn_or_cls__ ++ is LoraCombinedQKVProjection ++ ): ++ lora_enabled = True ++ attention_included_in_lora = True ++ current_rank = ( ++ transformer_layer_tpl.tr_atten_tpl.combined_qkv_proj_tpl.rank ++ ) ++ lora_rank = ( ++ current_rank if lora_rank is None else lora_rank & current_rank ++ ) ++ ++ if ( ++ hasattr(transformer_layer_tpl.tr_atten_tpl, "proj_tpl") ++ and transformer_layer_tpl.tr_atten_tpl.proj_tpl.__fn_or_cls__ ++ is LoraAttentionProjection ++ ): ++ lora_enabled = True ++ attention_included_in_lora = True ++ current_rank = transformer_layer_tpl.tr_atten_tpl.proj_tpl.rank ++ lora_rank = ( ++ current_rank if lora_rank is None else lora_rank & current_rank ++ ) ++ ++ if lora_enabled: ++ assert ( ++ lora_rank > 0 ++ ), "LoRA rank should be the same for all layers and greater than 0." 
++ if attention_included_in_lora and mlp_included_in_lora: ++ te_lora_scope = "all" ++ elif attention_included_in_lora and not mlp_included_in_lora: ++ te_lora_scope = "exclude_mlp" ++ elif mlp_included_in_lora and not attention_included_in_lora: ++ te_lora_scope = "mlp" ++ ++ te_transformer_tpl.low_rank_adaptation_scope = te_lora_scope ++ te_transformer_tpl.low_rank_adaptation_dim = lora_rank ++ ++ return te_tpl ++ ++ try: ++ te_transformer_tpl = update_lora_te_tpl( ++ te_transformer_tpl, transformer_layer_tpl ++ ) ++ except Exception as e: ++ logging.warning(f"Unable to use LoRA with TE: {e}") ++ ++ + return te_transformer_tpl + + @staticmethod +-- +2.34.1 diff --git a/.github/container/patches/praxis/PR-74.patch b/.github/container/patches/praxis/PR-74.patch new file mode 100644 index 000000000..5f7a74357 --- /dev/null +++ b/.github/container/patches/praxis/PR-74.patch @@ -0,0 +1,227 @@ +From c003e294e8767c57d2c2839a03f9d2597568814a Mon Sep 17 00:00:00 2001 +From: Haixin Liu +Date: Fri, 7 Jun 2024 09:28:56 -0700 +Subject: [PATCH 1/3] use te dpa for grok mqa + +--- + praxis/layers/grok.py | 4 ++ + praxis/layers/multi_query_attention.py | 65 +++++++++++++++++--------- + 2 files changed, 46 insertions(+), 23 deletions(-) + +diff --git a/praxis/layers/grok.py b/praxis/layers/grok.py +index 265a313..ad1ff5f 100644 +--- a/praxis/layers/grok.py ++++ b/praxis/layers/grok.py +@@ -59,6 +59,7 @@ def GrokStackedTransformerHParams( + combine_qkv=False, + bidirectional=False, + use_fp8=False, ++ use_te_dpa=True, + ) -> pax_fiddle.Config[transformers.StackedTransformer]: + """Common setup for Grok-1 Transformer layers. + +@@ -169,6 +170,7 @@ def GrokStackedTransformerHParams( + p.transformer_layer_params_tpl.tr_atten_tpl = pax_fiddle.Config( + multi_query_attention.MultiQueryDotProductAttention, + num_kv_heads=attention_num_groups, ++ use_te_dpa=use_te_dpa, + ) + tr_atten_tpl = p.transformer_layer_params_tpl.tr_atten_tpl + tr_atten_tpl.combine_qkv = False +@@ -228,6 +230,7 @@ def GrokUniTransformerLmHParams( + model_type=LanguageModelType.CAUSAL, + checkpoint_policy=AutodiffCheckpointType.SAVE_NOTHING, + use_fp8=False, ++ use_te_dpa=True, + ) -> pax_fiddle.Config[transformer_models.TransformerLm]: + """Common setup for Grok-1 Decoder-only Transformer Model. + +@@ -331,6 +334,7 @@ def GrokUniTransformerLmHParams( + bidirectional=bidirectional, + moe_gating_embedding_level=moe_gating_embedding_level, + use_fp8=use_fp8, ++ use_te_dpa=use_te_dpa, + ) + num_blocks = num_transformer_layers + +diff --git a/praxis/layers/multi_query_attention.py b/praxis/layers/multi_query_attention.py +index acd6959..33b7ab1 100644 +--- a/praxis/layers/multi_query_attention.py ++++ b/praxis/layers/multi_query_attention.py +@@ -31,7 +31,7 @@ from praxis.layers import attentions + from praxis.layers import base_ops + from praxis.layers import embedding_softmax + from praxis.layers import stochastics +- ++from transformer_engine.jax.flax.transformer import DotProductAttention as TEDotProductAttention + + WeightInit = base_layer.WeightInit + WeightHParams = base_layer.WeightHParams +@@ -209,6 +209,7 @@ class MultiQueryDotProductAttention(base_layer.BaseLayer): + pv_einsum_tpl: LayerTpl = template_field(base_ops.EinsumOp) + scale_query_by_dim_per_head: bool = False + chunked_attn_num_seq_split: int = 1 ++ use_te_dpa: bool = False + + # SPMD partition related params. 
+ # +@@ -347,6 +348,20 @@ class MultiQueryDotProductAttention(base_layer.BaseLayer): + self.create_child('post', post_proj_p) + self.create_child('qk_einsum', self.qk_einsum_tpl.clone()) + self.create_child('pv_einsum', self.pv_einsum_tpl.clone()) ++ self.dpa_layer = TEDotProductAttention( ++ head_dim=dim_per_head, ++ num_attention_heads=self.num_heads, ++ num_gqa_groups=self.num_kv_heads, ++ attn_mask_type='causal', # 'causal' or 'padding' ++ attn_bias_type='no_bias', # 'no_bias', 'pre_scale_bias' or 'post_scale_bias' ++ attention_dropout=0., ++ dropout_rng_name='aqt', ++ dtype=jnp.bfloat16, ++ float32_logits=False, ++ qkv_layout='BSHD_BSHD_BSHD', # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD' ++ scale_factor=1.0/math.sqrt(self.num_heads), ++ transpose_batch_sequence=False ++ ) + + def _shard_bnh(self, x: JTensor) -> JTensor: + """Shards tensors of shape [b, n, h]. +@@ -828,29 +843,33 @@ class MultiQueryDotProductAttention(base_layer.BaseLayer): + else: + key_proj = self._shard_blnh(key_proj) + value_proj = self._shard_blnh(value_proj) +- b, t, n, h = query_proj.shape +- _, s, nk, _ = key_proj.shape +- assert n % nk == 0 +- v_q = jnp.reshape(query_proj, (b, t, nk, n // nk, h)) +- if relative_bias is not None: +- v_rb = jnp.reshape(relative_bias, (b, nk, n // nk, t, s)) ++ if self.use_te_dpa: ++ atten_probs = None ++ encoded = self.dpa_layer(query_proj, key_proj, value_proj) + else: +- v_rb = None +- with self._context_for_kv_vmap(): +- encoded, atten_probs = jax.vmap( +- self._dot_atten, +- in_axes=(2, 2, 2, None, 1), +- out_axes=(2, 1), +- )( +- v_q, +- key_proj, +- value_proj, +- atten_mask, +- v_rb, +- ) +- encoded = self._shard_blnh(jnp.reshape(encoded, (b, t, n, h))) +- if atten_probs is not None: +- atten_probs = jnp.reshape(atten_probs, (b, t, n, s)) ++ b, t, n, h = query_proj.shape ++ _, s, nk, _ = key_proj.shape ++ assert n % nk == 0 ++ v_q = jnp.reshape(query_proj, (b, t, nk, n // nk, h)) ++ if relative_bias is not None: ++ v_rb = jnp.reshape(relative_bias, (b, nk, n // nk, t, s)) ++ else: ++ v_rb = None ++ with self._context_for_kv_vmap(): ++ encoded, atten_probs = jax.vmap( ++ self._dot_atten, ++ in_axes=(2, 2, 2, None, 1), ++ out_axes=(2, 1), ++ )( ++ v_q, ++ key_proj, ++ value_proj, ++ atten_mask, ++ v_rb, ++ ) ++ encoded = self._shard_blnh(jnp.reshape(encoded, (b, t, n, h))) ++ if atten_probs is not None: ++ atten_probs = jnp.reshape(atten_probs, (b, t, n, s)) + + # Post projection + encoded = self.post(encoded) +-- +2.34.1 + + +From 99bff322268ccd38d8b18d031eba66605d342172 Mon Sep 17 00:00:00 2001 +From: Haixin Liu +Date: Fri, 7 Jun 2024 10:37:20 -0700 +Subject: [PATCH 2/3] add doc string and warning + +--- + praxis/layers/multi_query_attention.py | 7 +++++-- + 1 file changed, 5 insertions(+), 2 deletions(-) + +diff --git a/praxis/layers/multi_query_attention.py b/praxis/layers/multi_query_attention.py +index 33b7ab1..c11ffac 100644 +--- a/praxis/layers/multi_query_attention.py ++++ b/praxis/layers/multi_query_attention.py +@@ -17,7 +17,7 @@ + + import math + from typing import Callable, Mapping, Sequence +- ++from absl import logging + from flax import linen as nn + import jax + from jax import numpy as jnp +@@ -209,7 +209,7 @@ class MultiQueryDotProductAttention(base_layer.BaseLayer): + pv_einsum_tpl: LayerTpl = template_field(base_ops.EinsumOp) + scale_query_by_dim_per_head: bool = False + chunked_attn_num_seq_split: int = 1 +- use_te_dpa: bool = False ++ use_te_dpa: bool = False # Experimental way to use TE flash attention when can't use standard TE + + # SPMD partition 
related params. + # +@@ -844,6 +844,9 @@ class MultiQueryDotProductAttention(base_layer.BaseLayer): + key_proj = self._shard_blnh(key_proj) + value_proj = self._shard_blnh(value_proj) + if self.use_te_dpa: ++ logging.warning( ++ 'use_te_dpa is set to True, so TE dpa is used as an experimental way to use TE flash attention.' ++ ) + atten_probs = None + encoded = self.dpa_layer(query_proj, key_proj, value_proj) + else: +-- +2.34.1 + + +From 4eece5a90da2317994432c530ff37df42d6e617b Mon Sep 17 00:00:00 2001 +From: Haixin Liu +Date: Fri, 7 Jun 2024 16:47:13 -0700 +Subject: [PATCH 3/3] default to not use dpa in grok + +--- + praxis/layers/grok.py | 4 ++-- + 1 file changed, 2 insertions(+), 2 deletions(-) + +diff --git a/praxis/layers/grok.py b/praxis/layers/grok.py +index ad1ff5f..2c1fd72 100644 +--- a/praxis/layers/grok.py ++++ b/praxis/layers/grok.py +@@ -59,7 +59,7 @@ def GrokStackedTransformerHParams( + combine_qkv=False, + bidirectional=False, + use_fp8=False, +- use_te_dpa=True, ++ use_te_dpa=False, + ) -> pax_fiddle.Config[transformers.StackedTransformer]: + """Common setup for Grok-1 Transformer layers. + +@@ -230,7 +230,7 @@ def GrokUniTransformerLmHParams( + model_type=LanguageModelType.CAUSAL, + checkpoint_policy=AutodiffCheckpointType.SAVE_NOTHING, + use_fp8=False, +- use_te_dpa=True, ++ use_te_dpa=False, + ) -> pax_fiddle.Config[transformer_models.TransformerLm]: + """Common setup for Grok-1 Decoder-only Transformer Model. + +-- +2.34.1 + diff --git a/.github/container/patches/t5x/mirror-patch-dali-support.patch b/.github/container/patches/t5x/mirror-patch-dali-support.patch index b1c4e7ad1..912891d40 100644 --- a/.github/container/patches/t5x/mirror-patch-dali-support.patch +++ b/.github/container/patches/t5x/mirror-patch-dali-support.patch @@ -1,18 +1,18 @@ -From 01c2f2e6dedce10beb4ff7175a60f1526b026ae6 Mon Sep 17 00:00:00 2001 +From 600f935f45ad8f68770547e754c58f728d500461 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Tue, 16 May 2023 11:53:31 -0700 Subject: [PATCH 1/2] add support for DALI datasets --- - t5x/train.py | 88 +++++++++++++++++++++++------ + t5x/train.py | 90 ++++++++++++++++++++++++------ t5x/trainer.py | 146 +++++++++++++++++++++++++++++++++++++++++++++++++ - 2 files changed, 218 insertions(+), 16 deletions(-) + 2 files changed, 219 insertions(+), 17 deletions(-) diff --git a/t5x/train.py b/t5x/train.py -index e85d37c..e6027ce 100644 +index 8bbe866..2fd0da4 100644 --- a/t5x/train.py +++ b/t5x/train.py -@@ -116,10 +116,13 @@ def train( +@@ -117,10 +117,13 @@ def train( ], inference_evaluator_cls: utils.EvaluatorConstructor = seqio.Evaluator, get_dataset_fn: utils.GetDatasetCallable = utils.get_dataset, @@ -26,7 +26,7 @@ index e85d37c..e6027ce 100644 train_state_initializer_cls: Type[ utils.TrainStateInitializer ] = utils.TrainStateInitializer, -@@ -169,6 +172,8 @@ def train( +@@ -172,6 +175,8 @@ def train( evaluation, potentially with bound configuration args. get_dataset_fn: The callable use to get the train and train-eval datasets based on the DatasetConfig and shard information. @@ -35,7 +35,7 @@ index e85d37c..e6027ce 100644 concurrent_metrics: If True, allow metrics computation and logging to overlap with training. Will likely result in additional TPU memory usage. actions: A mapping of actions that runs after train, eval or infer_eval, to -@@ -179,8 +184,11 @@ def train( +@@ -182,8 +187,11 @@ def train( train_eval_get_dataset_fn: Optional callable use to get the train-eval datasets based on the DatasetConfig and shard information. 
If missing, it defaults to `utils.get_training_eval_datasets`. @@ -47,7 +47,7 @@ index e85d37c..e6027ce 100644 train_state_initializer_cls: t5x.utils.TrainStateInitializer class for initializing partitioned TrainState from checkpoints or scratch. use_orbax: if True, uses Orbax for checkpointing. Experimental feature. -@@ -299,12 +307,15 @@ def train( +@@ -270,12 +278,15 @@ def train( train_iter = get_dataset_fn( train_dataset_cfg, ds_shard_id, num_ds_shards, model.FEATURE_CONVERTER_CLS ) @@ -66,10 +66,10 @@ index e85d37c..e6027ce 100644 + data_layout=data_layout, + ) + - input_shapes = jax.tree_map( + input_shapes = jax.tree.map( lambda x: (data_layout.batch_size, *x.shape[1:]), train_iter.element_spec, -@@ -320,6 +331,12 @@ def train( +@@ -291,6 +302,12 @@ def train( eval_steps, model.FEATURE_CONVERTER_CLS, ) # type: Mapping[str, tf.data.Dataset] @@ -82,13 +82,14 @@ index e85d37c..e6027ce 100644 if not train_eval_datasets: logging.warning( 'No train_eval datasets loaded from config `train_eval_dataset_cfg`: ' -@@ -506,9 +523,16 @@ def train( +@@ -503,10 +520,17 @@ def train( def _run_training_eval(first_run: bool = False): if first_run: logging.info('Compiling training eval loop.') - trainer.compile_eval({ # pytype: disable=wrong-arg-types # jax-ndarray - task: utils.get_zeros_batch_like_dataset(ds) - for task, ds in train_eval_datasets.items() +- }) + if run_dali_eval: + trainer.compile_eval_dali({ + task: utils.get_zeros_batch_like_dataset(ds) @@ -99,10 +100,11 @@ index e85d37c..e6027ce 100644 + trainer.compile_eval({ # pytype: disable=wrong-arg-types # jax-ndarray + task: utils.get_zeros_batch_like_dataset(ds) + for task, ds in train_eval_datasets.items() - }) ++ }) logging.info('Computing training evaluation metrics.') eval_batch_iters = {} -@@ -518,13 +542,20 @@ def train( + for task, ds in train_eval_datasets.items(): +@@ -515,13 +539,20 @@ def train( else: eval_batch_iters[task] = ds @@ -130,7 +132,7 @@ index e85d37c..e6027ce 100644 def _run_inference_eval(): """Run prediction based inference eval.""" -@@ -553,6 +584,19 @@ def train( +@@ -550,6 +581,19 @@ def train( if train_eval_datasets: logging.info('Running training eval before training.') _run_training_eval(first_run=True) @@ -150,7 +152,7 @@ index e85d37c..e6027ce 100644 if evaluator is not None: logging.info('Running inference eval before training.') _run_inference_eval() -@@ -793,6 +837,18 @@ def train( +@@ -796,6 +840,18 @@ def train( # Maybe less if final step < period. first_run = step_offset // eval_period <= 1 _run_training_eval(first_run and not run_eval_before_training) @@ -170,7 +172,7 @@ index e85d37c..e6027ce 100644 # Inference Evaluation (i.e., with decoding or scoring). 
if is_eval_epoch and evaluator is not None: diff --git a/t5x/trainer.py b/t5x/trainer.py -index 3d592d8..7a321cd 100644 +index 965bd09..5e5ec34 100644 --- a/t5x/trainer.py +++ b/t5x/trainer.py @@ -35,6 +35,7 @@ import clu.metrics @@ -355,10 +357,10 @@ index 3d592d8..7a321cd 100644 def _warn_action_not_run(action, task, metric): logging.warning( -- -2.25.1 +2.34.1 -From 79d36a39921b83271ff75748a211185884744f8b Mon Sep 17 00:00:00 2001 +From 16e3abd9d5954ea406bba245acf84a850d787c8a Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 1 Nov 2023 11:17:53 -0700 Subject: [PATCH 2/2] fix bug in rebase @@ -368,10 +370,10 @@ Subject: [PATCH 2/2] fix bug in rebase 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/t5x/train.py b/t5x/train.py -index e6027ce..ec1c2fa 100644 +index 2fd0da4..a36b909 100644 --- a/t5x/train.py +++ b/t5x/train.py -@@ -309,7 +309,7 @@ def train( +@@ -280,7 +280,7 @@ def train( ) if prepare_train_iter_fn: @@ -381,5 +383,5 @@ index e6027ce..ec1c2fa 100644 checkpoint_cfg=checkpoint_cfg, partitioner=partitioner, -- -2.25.1 +2.34.1 diff --git a/.github/container/patches/t5x/mirror-patch-partial-checkpoint-restore.patch b/.github/container/patches/t5x/mirror-patch-partial-checkpoint-restore.patch index 9250d85e0..4b893d84a 100644 --- a/.github/container/patches/t5x/mirror-patch-partial-checkpoint-restore.patch +++ b/.github/container/patches/t5x/mirror-patch-partial-checkpoint-restore.patch @@ -25,5 +25,5 @@ index 61682ed..77e0860 100644 ] # 2. From a checkpoint specified by `checkpoint_cfg.restore.path`, if set. -- -2.25.1 +2.34.1 diff --git a/.github/container/patches/t5x/mirror-patch-t5x_te_in_contrib_noindent.patch b/.github/container/patches/t5x/mirror-patch-t5x_te_in_contrib_noindent.patch index 75f7330c3..dd66d9b8b 100644 --- a/.github/container/patches/t5x/mirror-patch-t5x_te_in_contrib_noindent.patch +++ b/.github/container/patches/t5x/mirror-patch-t5x_te_in_contrib_noindent.patch @@ -1,7 +1,7 @@ From 0bc3617a1befcd249f6a95584dba9634bd6b879c Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Mon, 24 Apr 2023 10:18:29 -0700 -Subject: [PATCH 01/18] Added transformer engine support and GPU optimizations +Subject: [PATCH 01/19] Added transformer engine support and GPU optimizations Co-authored-by: Sahil Jain Co-authored-by: Terry Kong @@ -2465,13 +2465,13 @@ index 965bd09..cebb815 100644 metrics["learning_rate"] = clu.metrics.Average.from_model_output( jnp.asarray([learning_rate]) -- -2.25.1 +2.34.1 From dcbbb37dea8cef5b0c2d798b57cdfd2d8600013d Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Tue, 11 Jul 2023 12:10:33 -0700 -Subject: [PATCH 02/18] UNINSTALL_TE in fine-tuning scripts now defaults to +Subject: [PATCH 02/19] UNINSTALL_TE in fine-tuning scripts now defaults to no-action --- @@ -2500,13 +2500,13 @@ index 388d2ec..135ecf6 100755 # Global batch size BSIZE=$(( GPUS_PER_NODE * BSIZE_PER_GPU * SLURM_JOB_NUM_NODES / TP_SIZE)) -- -2.25.1 +2.34.1 From db6fc5519e240eb74ce39668047327da323f323a Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Wed, 12 Jul 2023 20:26:28 -0700 -Subject: [PATCH 03/18] remove use_gda from LegacyCheckpointManager in train.py +Subject: [PATCH 03/19] remove use_gda from LegacyCheckpointManager in train.py for fp8 --- @@ -2528,13 +2528,13 @@ index e70bcfb..0e11482 100644 # Start warming up the input pipeline in the background. This must happen # after input pipeline checkpoints were restored. 
-- -2.25.1 +2.34.1 From a39a08ec58e47976facda6ff8e340d5d1cec1f97 Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Tue, 18 Jul 2023 15:55:01 -0700 -Subject: [PATCH 04/18] Allow singlenode scripts to tee to stdout for better +Subject: [PATCH 04/19] Allow singlenode scripts to tee to stdout for better indication of training status --- @@ -2565,13 +2565,13 @@ index def1a1a..0d12f30 100755 + 2>&1 | tee \ ${LOG_DIR}/${T5_SIZE}_gpu_${NUM_GPUS}_${PREC}_gbs_${BSIZE}_fp8_${ENABLE_FP8}_fuseqkv_${FUSE_QKV}_transbs_${TRANSPOSE_BS}.log -- -2.25.1 +2.34.1 From 39e637f3bb12dbe75dada4383f1d30805b946dc4 Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Fri, 14 Jul 2023 05:00:58 -0700 -Subject: [PATCH 05/18] Explicit specify self_attn_mask_type +Subject: [PATCH 05/19] Explicit specify self_attn_mask_type --- t5x/te_helper.py | 10 ++++++++-- @@ -2606,13 +2606,13 @@ index fb5f48f..f3750ca 100644 class TransformerEngineHelper(TransformerEngineHelperBase): -- -2.25.1 +2.34.1 From d016f83835aabd2966a50f37126138c885b5a3b0 Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Thu, 3 Aug 2023 14:25:22 -0700 -Subject: [PATCH 06/18] Disables check for packing by the te_helper util since +Subject: [PATCH 06/19] Disables check for packing by the te_helper util since not all dataset configs use packing (CV/Multimodal) --- @@ -2633,13 +2633,13 @@ index f3750ca..f585752 100644 "Transformer Engine does not support dataset.packing, please turn it off." -- -2.25.1 +2.34.1 From 83a2b20dfc6e6df59887cd44000c5f84608d0400 Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Sat, 26 Aug 2023 16:13:11 -0700 -Subject: [PATCH 07/18] Corrected T5x large baselines +Subject: [PATCH 07/19] Corrected T5x large baselines Updated T5x-large MNLI and SQUAD baselines --- @@ -2660,13 +2660,13 @@ index a9974e1..660df3a 100644 | [T5-v1.1-xl](../t5/t5_1_1/xl.gin) | A100 80G SXM | bf16 | 256 | 1 | 8 | ~4322 | 16.9 | 5.5 days | 1,408 | 91.15% | 89.36 / 95.29 | [log](https://tensorboard.dev/experiment/vuRoEYgkRgWiEtbvgxlOqw/#scalars&_smoothingWeight=0) |[pile](../t5/t5_1_1/examples/xl_pile_pretrain.gin) | [T5-v1.1-xxl](../t5/t5_1_1/xxl.gin) | A100 80G SXM | bf16 | 512 | 8 | 36 | ~1887 | 3.69 | 12.6 days | 6,431 |N/A(partial run) | N/A(partial run) | |[pile](../t5/t5_1_1/examples/xxl_pile_pretrain.gin) -- -2.25.1 +2.34.1 From 5944f07c1924783a90e1a61915a9d8f2ea7b2215 Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Fri, 8 Sep 2023 15:09:08 -0700 -Subject: [PATCH 08/18] Add t5-large FP8 logs +Subject: [PATCH 08/19] Add t5-large FP8 logs --- docs/usage/gpu-usage.md | 2 +- @@ -2686,13 +2686,13 @@ index 660df3a..c31094d 100644 | [T5-v1.1-xl](../t5/t5_1_1/xl.gin) | **H100 80G SXM** | TE-fp8 | 256 | 1 | 8 | ~9688 | **37.8** | **2.4 days** | **614** | N/A (perf test) | N/A (perf test) | |[pile](../t5/t5_1_1/examples/xl_pile_pretrain.gin) -- -2.25.1 +2.34.1 From 2d2fbe8e857990a336f9d949ad7e8d82351b1633 Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Fri, 20 Oct 2023 14:26:09 +0800 -Subject: [PATCH 09/18] Fix missing fp8_meta_collection in the eval stage. +Subject: [PATCH 09/19] Fix missing fp8_meta_collection in the eval stage. --- t5x/models.py | 2 +- @@ -2712,13 +2712,13 @@ index 5891c08..e13b810 100644 enable_dropout=False, method=self.module.encode, -- -2.25.1 +2.34.1 From 4a86f76013173545d1299feaeb0e3383e890bee9 Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Fri, 20 Oct 2023 14:48:40 +0800 -Subject: [PATCH 10/18] Remove redundant code. +Subject: [PATCH 10/19] Remove redundant code. 
--- t5x/models.py | 17 ----------------- @@ -2753,13 +2753,13 @@ index e13b810..cc3348f 100644 # `decoder_input_tokens`. The EOS is stripped to avoid decoding to stop # after the prompt by matching to `output_vocabulary.eos_id`. -- -2.25.1 +2.34.1 From 7b878db04ad1e424bf6ca330c3e6fc8f9e0c1c3b Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Fri, 20 Oct 2023 15:05:40 +0800 -Subject: [PATCH 11/18] Fix deprecating warning about TE. +Subject: [PATCH 11/19] Fix deprecating warning about TE. --- t5x/te_helper.py | 8 ++++---- @@ -2805,13 +2805,13 @@ index f585752..568f596 100644 name=name) -- -2.25.1 +2.34.1 From 4c604770bf02d3ff5fab6b1e795c5b6074032c7e Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Fri, 27 Oct 2023 09:08:10 -0700 -Subject: [PATCH 12/18] Updates TE api from te.extend_* to te.flax.extend_* +Subject: [PATCH 12/19] Updates TE api from te.extend_* to te.flax.extend_* (#7) Co-authored-by: NVIDIA @@ -2833,13 +2833,13 @@ index 568f596..05c5f6b 100644 @staticmethod def update_fp8_metas(grad_accum, flax_mutables): -- -2.25.1 +2.34.1 From a3f2ab9fdf0681e10ffbcb2588274941e52700ba Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Tue, 31 Oct 2023 21:52:22 -0700 -Subject: [PATCH 13/18] Adds ENABLE_TE env var and renames TEConfig.enabled -> +Subject: [PATCH 13/19] Adds ENABLE_TE env var and renames TEConfig.enabled -> TEConfig.enable_fp8 (#8) * Allows ENABLE_TE env var to control whether TE code path is invoked @@ -2996,13 +2996,13 @@ index 05c5f6b..7657c52 100644 return TENotInstalledHelper -- -2.25.1 +2.34.1 From 4abe3e591574b0b2e5bbcdcd5e96df4f8d4367de Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Tue, 7 Nov 2023 10:57:34 +0800 -Subject: [PATCH 14/18] Adapting to TE/JAX/Custom_partitioning. +Subject: [PATCH 14/19] Adapting to TE/JAX/Custom_partitioning. 
--- t5x/te_helper.py | 5 ++--- @@ -3032,13 +3032,13 @@ index 7657c52..b064d2b 100644 @staticmethod -- -2.25.1 +2.34.1 From bfa6313ce0bc9b67f6d1b48fca460e1bfea16670 Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Wed, 22 Nov 2023 13:53:49 +0800 -Subject: [PATCH 15/18] Running Partitioner.compile within Mesh context-manager +Subject: [PATCH 15/19] Running Partitioner.compile within Mesh context-manager --- t5x/partitioning.py | 9 ++++++++- @@ -3072,13 +3072,13 @@ index 847dc24..63873cf 100644 class PjitPartitioner(BasePjitPartitioner): -- -2.25.1 +2.34.1 From b4dbfdea8313b0f098d7583d80dee4eea6180cdb Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Tue, 14 Nov 2023 23:42:35 -0800 -Subject: [PATCH 16/18] Updates multiprocessing scripts to use SLURM output +Subject: [PATCH 16/19] Updates multiprocessing scripts to use SLURM output variables instead of input variables (#9) * Update multiprocess scripts @@ -3549,13 +3549,13 @@ index d083540..56919a5 100755 -set +x +echo Finished -- -2.25.1 +2.34.1 From 189868b45483c3796035ebdd23a173feabe2a4ce Mon Sep 17 00:00:00 2001 From: ashors1 <71393111+ashors1@users.noreply.github.com> Date: Wed, 6 Dec 2023 14:56:45 -0800 -Subject: [PATCH 17/18] Force initial flax mutables to be a frozen dict (#11) +Subject: [PATCH 17/19] Force initial flax mutables to be a frozen dict (#11) --- t5x/trainer.py | 2 +- @@ -3575,13 +3575,13 @@ index cebb815..8910ce8 100644 if num_microbatches is None or num_microbatches <= 1: -- -2.25.1 +2.34.1 From 06be7c2e50535630bebe023f6bca922ccfe93448 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 28 Dec 2023 07:24:33 -0800 -Subject: [PATCH 18/18] update rng dtype in predict_batch +Subject: [PATCH 18/19] update rng dtype in predict_batch --- t5x/models.py | 2 +- @@ -3601,5 +3601,32 @@ index cc3348f..eb7bd37 100644 """Predicts a batch of outputs from the model. -- -2.25.1 +2.34.1 + + +From 339b03461fce0caca5e838139289888086c46d15 Mon Sep 17 00:00:00 2001 +From: Reese Wang +Date: Fri, 22 Mar 2024 23:49:37 -0700 +Subject: [PATCH 19/19] Change decoder attn mask type to padding_causal + +Signed-off-by: Reese Wang +--- + t5x/te_helper.py | 2 +- + 1 file changed, 1 insertion(+), 1 deletion(-) + +diff --git a/t5x/te_helper.py b/t5x/te_helper.py +index b064d2b..9410aa4 100644 +--- a/t5x/te_helper.py ++++ b/t5x/te_helper.py +@@ -241,7 +241,7 @@ class TEInstalledHelper(TransformerEngineHelperBase): + relative_embedding=relative_embedding, + dtype=config.dtype, + layer_type=te.flax.TransformerLayerType.DECODER, +- self_attn_mask_type='causal', ++ self_attn_mask_type='padding_causal', + name=name) + + +-- +2.34.1
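
Aside on the final patch above (19/19), which switches the Transformer Engine decoder layer's self_attn_mask_type from 'causal' to 'padding_causal': the short sketch below uses plain jax.numpy rather than Transformer Engine and is only meant to illustrate what the combined mask type implies for decoder self-attention, assuming the usual convention that a padding mask is True for real tokens and False for padding. It is an editorial illustration of the masking semantics, not TE's internal implementation.

# Minimal sketch (plain jax.numpy, not Transformer Engine): why decoder
# self-attention wants a padding_causal mask rather than a causal mask alone.
import jax.numpy as jnp

def causal_mask(seq_len):
    # True where query position i may attend to key position j (j <= i).
    return jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))

def padding_causal_mask(pad_mask):
    # pad_mask: [batch, seq_len] bool, True for real tokens, False for padding.
    # Combine the causal constraint with key-side padding so padded key
    # positions can never be attended to.
    seq_len = pad_mask.shape[-1]
    causal = causal_mask(seq_len)[None, :, :]   # [1, q, k]
    keys_valid = pad_mask[:, None, :]           # [batch, 1, k]
    return causal & keys_valid                  # [batch, q, k]

# Example: one sequence with 4 real tokens followed by 2 padding tokens.
pad = jnp.array([[True, True, True, True, False, False]])
combined = padding_causal_mask(pad)
only_causal = causal_mask(6)

# With a purely causal mask, the last query still attends to the two padded
# key positions; padding_causal removes them as well.
print(only_causal[5])    # [ True  True  True  True  True  True]
print(combined[0, 5])    # [ True  True  True  True False False]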