Fix steps/sec computation in Pax tests (#581)
There are currently two issues affecting the reported step times in Pax:
1. The first `steps/sec` value includes compilation time.
2. The `steps/sec` computation includes eval step compilation and eval step
time.

To fix (1), we exclude the first `steps/sec` value from the average step
time. To fix (2), we disable eval.
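As an illustration of fix (1), here is a minimal sketch of the adjusted averaging, assuming the test collects a list of `steps/sec` values parsed from the training log; the helper name and example values are hypothetical and are not the actual test code:

```python
def average_steps_per_sec(step_rates):
    """Average steps/sec measurements, excluding the first one.

    The first measurement includes XLA compilation time, so it would
    skew the steady-state average.
    """
    if len(step_rates) < 2:
        raise ValueError("need at least two measurements to drop the compilation step")
    steady_state = step_rates[1:]  # drop the compile-tainted first value
    return sum(steady_state) / len(steady_state)


# Example: the first value (2.1 steps/sec) reflects compilation and is dropped.
print(average_steps_per_sec([2.1, 9.8, 10.1, 9.9]))  # -> ~9.93
```

Fix (2) is purely a change to the test configuration (eval is turned off for the benchmark run), so it does not affect the averaging itself.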

---------

Co-authored-by: Yu-Hang "Maxin" Tang <Tang.Maxin@gmail.com>
ashors1 and yhtang authored Feb 28, 2024
1 parent baffeee commit 6378796
Showing 6 changed files with 102 additions and 91 deletions.
42 changes: 21 additions & 21 deletions .github/container/manifest.yaml
@@ -1,31 +1,31 @@
jax:
url: https://github.com/google/jax.git
tracking_ref: main
latest_verified_commit: f2613387bd7421bd1fedbed15257b9841717b34a
latest_verified_commit: 75cdef7626b92b8b6563ea68ae4747fd6994db2e
mode: git-clone
xla:
url: https://github.com/openxla/xla.git
tracking_ref: main
latest_verified_commit: 5ea80986813c3eef73c4d53ee881c203839f2c16
latest_verified_commit: 831e9cef85493ff7ee2e24fd4cc64377d682aecc
mode: git-clone
flax:
url: https://github.com/google/flax.git
mirror_url: https://github.com/nvjax-svc-0/flax.git
tracking_ref: main
latest_verified_commit: d9585e0a6ba6c4a4ebc93d0707add573420703df
latest_verified_commit: aaf130c90eb46160a3234c258a48bf1b932d7829
mode: git-clone
patches:
pull/3340/head: file://patches/flax/PR-3340.patch # Add Sharding Annotations to Flax Modules
transformer-engine:
url: https://github.com/NVIDIA/TransformerEngine.git
tracking_ref: main
latest_verified_commit: d68028c872186cd3df604e2ee2f60e09784d955c
latest_verified_commit: 9b2fed514ea419141146f843ab2c84b22b86bfd7
mode: git-clone
t5x:
url: https://github.com/google-research/t5x.git
mirror_url: https://github.com/nvjax-svc-0/t5x.git
tracking_ref: main
latest_verified_commit: cd94b76f53244914660e83a8f813d8e809ba05e4
latest_verified_commit: ecb126e1f5c2aea648f39869d4e69fb4374a4868
mode: git-clone
patches:
mirror/patch/partial-checkpoint-restore: file://patches/t5x/mirror-patch-partial-checkpoint-restore.patch # pull/1392/head # https://github.com/google-research/t5x/pull/1392: Add support for partial checkpoint restore
@@ -35,7 +35,7 @@ paxml:
url: https://github.com/google/paxml.git
mirror_url: https://github.com/nvjax-svc-0/paxml.git
tracking_ref: main
latest_verified_commit: e6cbcb9990a14d6a166166d822dc8e929b6bab0d
latest_verified_commit: e5bebb78635d042b1212703d3a31d81ca61ca2fa
mode: git-clone
patches:
pull/46/head: file://patches/paxml/PR-46.patch # adds Transformer Engine support
@@ -44,7 +44,7 @@ praxis:
url: https://github.com/google/praxis.git
mirror_url: https://github.com/nvjax-svc-0/praxis.git
tracking_ref: main
latest_verified_commit: d2d3bb89770de2b3af666981295c23defb38bd2f
latest_verified_commit: 1358698fad63deb8306ef007331f1811776ad08f
mode: git-clone
patches:
pull/27/head: file://patches/praxis/PR-27.patch # This PR allows XLA:GPU to detect the MHA pattern more easily to call fused kernels from cublas.
@@ -54,7 +54,7 @@ lingvo:
# Used only in ARM pax builds
url: https://github.com/tensorflow/lingvo.git
tracking_ref: master
latest_verified_commit: 8bb7652184c6c895def36e9fb242a9751705fa0e
latest_verified_commit: 5bbe38c046519b86fa5c0488f813ffbf8b467d7e
mode: git-clone
tensorflow-text:
# Used only in ARM pax and t5x builds
@@ -69,18 +69,18 @@ pydantic:
fiddle:
url: https://github.com/google/fiddle.git
tracking_ref: main
latest_verified_commit: 7a12009bf9d07652759e8554fa93135b2d63fd41
latest_verified_commit: 2a17618c56eb99aa58aa898ae12cbac7cf5c3b30
mode: pip-vcs
# Used by t5x
airio:
url: https://github.com/google/airio.git
tracking_ref: main
latest_verified_commit: 899e6ec7cb1aa4239cb96ece7409d3cf19f6e6e4
latest_verified_commit: e4c682e691354d75a6bea521cd61709b1ab81d34
mode: pip-vcs
clu:
url: https://github.com/google/CommonLoopUtils.git
tracking_ref: main
latest_verified_commit: 1368e52d0876dd0c90894793e8e9e97fc6f98adc
latest_verified_commit: eed40a1facd526df0e0faa192525f357a3321dca
mode: pip-vcs
dllogger:
url: https://github.com/NVIDIA/dllogger.git
@@ -95,43 +95,43 @@ jestimator:
optax:
url: https://github.com/google-deepmind/optax.git
tracking_ref: main
latest_verified_commit: c86b9a99eea69c5f2bb81bd526a579b3f6b5f4d0
latest_verified_commit: 623609c7a77a19d48b021cbc300262308846317e
mode: pip-vcs
seqio:
url: https://github.com/google/seqio.git
tracking_ref: main
latest_verified_commit: 763998620c2071b3b05b3bbf94cf02305746ab9f
latest_verified_commit: e31af8c1a11f749edeac512f34d148b9933f863f
mode: pip-vcs
# used by Pallas
openxla-triton:
url: https://github.com/openxla/triton.git
tracking_ref: llvm-head
latest_verified_commit: cl601105910
latest_verified_commit: 3764c21d906d497256507bc25fa4135bb472cc13
mode: git-clone
jax-triton:
url: https://github.com/jax-ml/jax-triton.git
tracking_ref: main
latest_verified_commit: 28ad4766271a181587e6e17e17de7f729c1a03b5
latest_verified_commit: 708d3e8afe13b52e4191ad3b677c6f1238677c9e
mode: git-clone
maxtext:
url: https://github.com/google/maxtext.git
tracking_ref: main
latest_verified_commit: c94564a66a6b13b4da6248847f3799b463283454
latest_verified_commit: 5420bc5753fec4b3a811664cdb58f3c9e98d35fb
mode: git-clone
levanter:
url: https://github.com/stanford-crfm/levanter.git
tracking_ref: main
latest_verified_commit: c1e7b24aa47e073af0aa48606581f817b59e1b45
latest_verified_commit: 94a432e7999ae016645bc72e9dda55e724d0f834
mode: git-clone
haliax:
url: https://github.com/stanford-crfm/haliax.git
tracking_ref: main
latest_verified_commit: ae5f4ce74a429a9ae45e350099f2ecc0cd95004c
latest_verified_commit: 690623131e107972ec2ec67d6183c77649d4b7e0
mode: git-clone
mujoco:
url: https://github.com/google-deepmind/mujoco.git
tracking_ref: main
latest_verified_commit: ced37ffbb237584512311b041bce3124e3b2cc2a
latest_verified_commit: c6a41fbfe64ee7b2680a6bde90200ca660d08c2a
mode: git-clone
grain:
# Used only in ARM t5x builds
@@ -142,7 +142,7 @@ grain:
mujoco-mpc:
url: https://github.com/google-deepmind/mujoco_mpc.git
tracking_ref: main
latest_verified_commit: 73633d7da1900c428a7315d2ffe1120c5393a7d8
latest_verified_commit: 50a0159cbc70b38a7fee425b8bf5edbc04a1b62e
mode: git-clone
language-to-reward-2023:
url: https://github.com/google-deepmind/language_to_reward_2023.git
@@ -152,5 +152,5 @@ language-to-reward-2023:
mlperf-logging:
url: https://github.com/mlcommons/logging.git
tracking_ref: master
latest_verified_commit: 38709131757a786d7f66150243b49eb0365324d7
latest_verified_commit: c7b23b3d7aa1055c60e6513edebd138e1a597c97
mode: pip-vcs
52 changes: 26 additions & 26 deletions .github/container/patches/flax/PR-3340.patch
@@ -1,4 +1,4 @@
From a563a57ca081e87b537bf724e2e29b138a14bb3f Mon Sep 17 00:00:00 2001
From d748ab4447dbb82ea9317f71211a3bbd9ba4207f Mon Sep 17 00:00:00 2001
From: ashors1 <ashors@nvidia.com>
Date: Fri, 2 Jun 2023 15:01:21 -0700
Subject: [PATCH 1/3] add t5x sharding annotations to flax layers
@@ -10,7 +10,7 @@ Subject: [PATCH 1/3] add t5x sharding annotations to flax layers
3 files changed, 79 insertions(+), 21 deletions(-)

diff --git a/flax/linen/attention.py b/flax/linen/attention.py
index e10a02d1..2827119f 100644
index efcf2b78..689ce4da 100644
--- a/flax/linen/attention.py
+++ b/flax/linen/attention.py
@@ -30,6 +30,7 @@ from flax.linen.linear import (
@@ -68,7 +68,7 @@ index e10a02d1..2827119f 100644
)
# project inputs_q to multi-headed q/k/v
# dimensions are then [batch..., length, n_heads, n_features_per_head]
@@ -467,14 +485,14 @@ class MultiHeadDotProductAttention(Module):
@@ -477,14 +495,14 @@ class MultiHeadDotProductAttention(Module):
if self.decode:
# detect if we're initializing by absence of existing cache data.
is_initialized = self.has_variable('cache', 'cached_key')
Expand All @@ -89,7 +89,7 @@ index e10a02d1..2827119f 100644
)
if is_initialized:
(
@@ -568,6 +586,8 @@ class MultiHeadDotProductAttention(Module):
@@ -580,6 +598,8 @@ class MultiHeadDotProductAttention(Module):
dot_general=self.out_dot_general,
dot_general_cls=self.out_dot_general_cls,
name='out', # type: ignore[call-arg]
@@ -235,7 +235,7 @@ index 36365ea1..4656abf9 100644
bias = None

diff --git a/flax/linen/normalization.py b/flax/linen/normalization.py
index cec8f508..e815f955 100644
index abfbfb5a..bab40243 100644
--- a/flax/linen/normalization.py
+++ b/flax/linen/normalization.py
@@ -24,6 +24,7 @@ from jax import lax
@@ -299,80 +299,80 @@ index cec8f508..e815f955 100644
+ pjit_axis_name: Tuple[str, ...] = None

@compact
def __call__(self, x, use_running_average: Optional[bool] = None, mask=None):
@@ -371,6 +378,7 @@ class BatchNorm(Module):
def __call__(
@@ -377,6 +384,7 @@ class BatchNorm(Module):
self.use_scale,
self.bias_init,
self.scale_init,
+ self.pjit_axis_name,
)


@@ -433,6 +441,7 @@ class LayerNorm(Module):
@@ -439,6 +447,7 @@ class LayerNorm(Module):
more details.
use_fast_variance: If true, use a faster, but less numerically stable,
calculation for the variance.
+ pjit_axis_names: A tuple of axis names.
"""

epsilon: float = 1e-6
@@ -447,6 +456,7 @@ class LayerNorm(Module):
@@ -453,6 +462,7 @@ class LayerNorm(Module):
axis_name: Optional[str] = None
axis_index_groups: Any = None
use_fast_variance: bool = True
+ pjit_axis_name: Tuple[str, ...] = None

@compact
def __call__(self, x, mask=None):
@@ -484,6 +494,7 @@ class LayerNorm(Module):
def __call__(self, x, *, mask: Optional[jax.Array] = None):
@@ -490,6 +500,7 @@ class LayerNorm(Module):
self.use_scale,
self.bias_init,
self.scale_init,
+ self.pjit_axis_name,
)


@@ -530,6 +541,7 @@ class RMSNorm(Module):
example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over the
examples on the first two and last two devices. See ``jax.lax.psum`` for
@@ -538,6 +549,7 @@ class RMSNorm(Module):
more details.
use_fast_variance: If true, use a faster, but less numerically stable,
calculation for the variance.
+ pjit_axis_names: A tuple of axis names.
"""

epsilon: float = 1e-6
@@ -541,6 +553,7 @@ class RMSNorm(Module):
feature_axes: Axes = -1
@@ -550,6 +562,7 @@ class RMSNorm(Module):
axis_name: Optional[str] = None
axis_index_groups: Any = None
use_fast_variance: bool = True
+ pjit_axis_name: Tuple[str, ...] = None

@compact
def __call__(self, x, mask=None):
@@ -578,6 +591,7 @@ class RMSNorm(Module):
def __call__(self, x, *, mask: Optional[jax.Array] = None):
@@ -588,6 +601,7 @@ class RMSNorm(Module):
self.use_scale,
initializers.zeros,
self.scale_init,
+ self.pjit_axis_name,
)


@@ -647,6 +661,7 @@ class GroupNorm(Module):
@@ -657,6 +671,7 @@ class GroupNorm(Module):
more details.
use_fast_variance: If true, use a faster, but less numerically stable,
calculation for the variance.
+ pjit_axis_names: A tuple of axis names.
"""

num_groups: Optional[int] = 32
@@ -662,6 +677,7 @@ class GroupNorm(Module):
@@ -672,6 +687,7 @@ class GroupNorm(Module):
axis_name: Optional[str] = None
axis_index_groups: Any = None
use_fast_variance: bool = True
+ pjit_axis_name: Tuple[str, ...] = None

@compact
def __call__(self, x, mask=None):
@@ -875,6 +891,7 @@ class InstanceNorm(Module):
def __call__(self, x, *, mask: Optional[jax.Array] = None):
@@ -885,6 +901,7 @@ class InstanceNorm(Module):
self.use_scale,
self.bias_init,
self.scale_init,
@@ -384,7 +384,7 @@ index cec8f508..e815f955 100644
2.25.1


From 4ddde2bc878d9ff841f34e3826bbd1bce705766d Mon Sep 17 00:00:00 2001
From c945c2ff513282b4af2e956c9c09c784e6d48c44 Mon Sep 17 00:00:00 2001
From: Terry Kong <terrycurtiskong@gmail.com>
Date: Mon, 2 Oct 2023 16:10:05 -0700
Subject: [PATCH 2/3] Added ConvTranspose sharding annotations (#3)
@@ -447,7 +447,7 @@ index 4656abf9..187ab6f5 100644
2.25.1


From 9e68f4758d8a87553f3989e0a014974813e12390 Mon Sep 17 00:00:00 2001
From 8b184f603e31feabb7580f1a969e101a7fe9e992 Mon Sep 17 00:00:00 2001
From: ashors1 <ashors@nvidia.com>
Date: Thu, 1 Feb 2024 09:54:25 -0800
Subject: [PATCH 3/3] Add missing import
@@ -459,7 +459,7 @@ Subject: [PATCH 3/3] Add missing import
3 files changed, 3 insertions(+)

diff --git a/flax/linen/attention.py b/flax/linen/attention.py
index 2827119f..517ff2dc 100644
index 689ce4da..b19d795e 100644
--- a/flax/linen/attention.py
+++ b/flax/linen/attention.py
@@ -39,6 +39,7 @@ from flax.typing import (
Expand All @@ -483,7 +483,7 @@ index 187ab6f5..759406ed 100644


diff --git a/flax/linen/normalization.py b/flax/linen/normalization.py
index e815f955..9e59e2be 100644
index bab40243..1e1169a0 100644
--- a/flax/linen/normalization.py
+++ b/flax/linen/normalization.py
@@ -32,6 +32,7 @@ from flax.typing import (
