Skip to content

Commit

Permalink
Merge branch 'main' into add-cuda-image-arg
Browse files Browse the repository at this point in the history
  • Loading branch information
yhtang authored Sep 17, 2024
2 parents 1836b70 + f116054 commit 99d059b
Show file tree
Hide file tree
Showing 37 changed files with 1,124 additions and 202 deletions.
4 changes: 3 additions & 1 deletion .github/container/Dockerfile.base
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ COPY --from=tcpx-installer /var/lib/tcpx/lib64 ${TCPX_LIBRARY_PATH}
###############################################################################

ADD install-nsight.sh /usr/local/bin
ADD nsys-2024.5-tid-export.patch /opt/nvidia
RUN install-nsight.sh

###############################################################################
Expand Down Expand Up @@ -180,6 +181,7 @@ ENV PATH=/opt/amazon/efa/bin:${PATH}
ADD install-nccl-sanity-check.sh /usr/local/bin
ADD nccl-sanity-check.cu /opt
RUN install-nccl-sanity-check.sh
ADD jax-nccl-test parallel-launch /usr/local/bin

###############################################################################
## Add the systemcheck to the entrypoint.
Expand All @@ -203,7 +205,7 @@ COPY check-shm.sh /opt/nvidia/entrypoint.d/

ADD nsys-jax nsys-jax-combine /usr/local/bin/
ADD jax_nsys/ /opt/jax_nsys
ADD requirements-nsys-jax.in /opt/pip-tools.d/
RUN echo "-e /opt/jax_nsys/python/jax_nsys" > /opt/pip-tools.d/requirements-nsys-jax.in
RUN ln -s /opt/jax_nsys/install-protoc /usr/local/bin/

###############################################################################
Expand Down
1 change: 1 addition & 0 deletions .github/container/Dockerfile.jax
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ pip install ninja && rm -rf ~/.cache/pip
# TransformerEngine now needs JAX at build time
git-clone.sh ${URLREF_TRANSFORMER_ENGINE} ${SRC_PATH_TRANSFORMER_ENGINE}
pushd ${SRC_PATH_TRANSFORMER_ENGINE}
export NVTE_BUILD_THREADS_PER_JOB=8
python setup.py bdist_wheel && rm -rf build
ls "${SRC_PATH_TRANSFORMER_ENGINE}/dist"
EOF
Expand Down
3 changes: 3 additions & 0 deletions .github/container/build-jax.sh
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,9 @@ pip --disable-pip-version-check install -e ${BUILD_PATH_JAXLIB}/jaxlib -e ${BUIL
# jaxlib 0.4.32.dev20240808 /opt/jaxlibs/jaxlib
pip list | grep jax

# Ensure directories are readable by all for non-root users
chmod 755 $BUILD_PATH_JAXLIB/*

## Cleanup

pushd $SRC_PATH_JAX
Expand Down
19 changes: 5 additions & 14 deletions .github/container/install-nsight.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,15 @@ export DEBIAN_FRONTEND=noninteractive
export TZ=America/Los_Angeles

apt-get update
# TODO: revert to nsight-systems-cli instead of explicitly pinning
apt-get install -y nsight-compute nsight-systems-cli-2024.4.1
apt-get install -y nsight-compute nsight-systems-cli
apt-get clean

rm -rf /var/lib/apt/lists/*

# "Wrong event order has been detected when adding events to the collection"
# workaround during nsys report post-processing with 2024.1.1 and CUDA 12.3
NSYS202411=/opt/nvidia/nsight-systems-cli/2024.1.1
if [[ "${UBUNTU_ARCH}" == "amd64" && -d "${NSYS202411}" ]]; then
LIBCUPTI123=/opt/nvidia/nsight-compute/2023.3.0/host/target-linux-x64/libcupti.so.12.3
if [[ ! -f "${LIBCUPTI123}" ]]; then
echo "2024.1.1 workaround expects to be running inside CUDA 12.3 container"
exit 1
fi
# Use libcupti.so.12.3 because this is a CUDA 12.3 container
ln -s "${LIBCUPTI123}" "${NSYS202411}/target-linux-x64/libcupti.so.12.3"
mv "${NSYS202411}/target-linux-x64/libcupti.so.12.4" "${NSYS202411}/target-linux-x64/_libcupti.so.12.4"
NSYS202451=/opt/nvidia/nsight-systems-cli/2024.5.1
if [[ -d "${NSYS202451}" ]]; then
# * can match at least sbsa-armv8 and x86
(cd ${NSYS202451}/target-linux-*/python/packages && git apply < /opt/nvidia/nsys-2024.5-tid-export.patch)
fi

# Install extra dependencies needed for `nsys recipe ...` commands. These are
Expand Down
253 changes: 253 additions & 0 deletions .github/container/jax-nccl-test
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
#!/usr/bin/env python
import argparse
from ctypes import byref, cdll, c_int, POINTER
from functools import partial
import jax
from jax.experimental.multihost_utils import sync_global_devices
from jax.experimental.shard_map import shard_map
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
import os
import time


# Load the CUDA runtime and bind the entry points this script calls directly:
# the device-count query and the profiler start/stop markers.
libcudart = cdll.LoadLibrary("libcudart.so")
cudaGetDeviceCount, cudaProfilerStart, cudaProfilerStop = (
    getattr(libcudart, symbol)
    for symbol in ("cudaGetDeviceCount", "cudaProfilerStart", "cudaProfilerStop")
)
# cudaGetDeviceCount takes an int* out-parameter and returns an int status code.
cudaGetDeviceCount.restype = c_int
cudaGetDeviceCount.argtypes = [POINTER(c_int)]


def visible_device_count() -> int:
    """
    Query the number of local devices visible to this process.

    Returns:
        The device count written by ``cudaGetDeviceCount``.

    Raises:
        RuntimeError: if ``cudaGetDeviceCount`` returns a non-zero status.
    """
    count = c_int()
    # Call the runtime outside of any `assert`: assert statements are removed
    # under `python -O`, which would skip the call and silently report 0.
    status = cudaGetDeviceCount(byref(count))
    if status != 0:
        raise RuntimeError(f"cudaGetDeviceCount failed with status {status}")
    return count.value


def int_or_env(value) -> int:
    """
    Interpret ``value`` as an integer, or as the name of an environment
    variable that holds an integer.

    Intended for use as an argparse ``type=`` callable.

    Args:
        value: an integer literal (e.g. ``"4"``) or an environment variable name.

    Returns:
        The parsed integer.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not an integer and either
            no environment variable of that name is set, or its content is not
            an integer. argparse turns this into a normal usage error instead
            of an uncaught traceback.
    """
    try:
        return int(value)
    except ValueError:
        pass  # not a literal integer; fall back to an environment lookup
    try:
        raw = os.environ[value]
    except KeyError:
        raise argparse.ArgumentTypeError(
            f"{value!r} is neither an integer nor the name of a set "
            "environment variable"
        ) from None
    try:
        return int(raw)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"environment variable {value} has non-integer value {raw!r}"
        ) from None


if __name__ == "__main__":
    # Script outline:
    #   1. parse the command line,
    #   2. optionally initialise jax.distributed,
    #   3. build a 1D device mesh and the sharded inputs,
    #   4. run one warm-up pass with host-side timing,
    #   5. run 10 passes between cudaProfilerStart() and cudaProfilerStop().
    parser = argparse.ArgumentParser(
        description="Pure-JAX implementation of a NCCL performance test"
    )
    parser.add_argument(
        "--coordinator-address",
        help="Distributed coordinator address:port; used if --distributed is passed.",
    )
    parser.add_argument(
        "--distributed",
        action="store_true",
        help="Run jax.distributed.initialize()",
    )
    parser.add_argument(
        "--gpus-per-process",
        help=(
            "Number of GPUs driven by each controller process. "
            "Defaults to 1 with --distributed and all of them otherwise."
        ),
        type=int,
    )
    parser.add_argument(
        "--process-count",
        help=(
            "When --distributed is passed this gives the total number of processes. "
            "This can either be an integer or the name of an environment variable."
        ),
        type=int_or_env,
    )
    parser.add_argument(
        "--process-id",
        help=(
            "When --distributed is passed this gives the global index of this process. "
            "This can either be an integer or the name of an environment variable."
        ),
        type=int_or_env,
    )
    args = parser.parse_args()

    assert (
        args.process_id is None or args.distributed
    ), "--process-id is only relevant with --distributed"
    if args.distributed:
        # One flag per option: True when that option was not given.
        null_args = [
            args.coordinator_address is None,
            args.gpus_per_process is None,
            args.process_count is None,
            args.process_id is None,
        ]
        if all(null_args):
            # Use default behaviour
            jax.distributed.initialize()
        else:
            assert not any(null_args), (
                "All of --coordinator-address, --gpus-per-process, --process-count and "
                "--process-id must be passed if any of them are."
            )
            # Guard the divisions below against zero divisors.
            assert args.gpus_per_process > 0, "--gpus-per-process must be positive"
            visible_devices = visible_device_count()
            assert visible_devices > 0, "No CUDA devices are visible to this process"
            local_processes, rem = divmod(visible_devices, args.gpus_per_process)
            assert rem == 0, (
                f"--gpus-per-process={args.gpus_per_process} does not divide the "
                f"visible device count {visible_devices}"
            )
            # assume processes within a node are globally numbered contiguously
            local_process_id = args.process_id % local_processes
            first_local_device = local_process_id * args.gpus_per_process
            local_device_ids = list(
                range(first_local_device, first_local_device + args.gpus_per_process)
            )
            print(
                f"Rank {args.process_id} has local rank {local_process_id} and "
                f"devices {local_device_ids} from a total of {visible_devices} "
                f"visible on this node, {args.process_count} processes and "
                f"{args.process_count*args.gpus_per_process} total devices.",
                flush=True,
            )
            jax.distributed.initialize(
                coordinator_address=args.coordinator_address,
                local_device_ids=local_device_ids,
                num_processes=args.process_count,
                process_id=args.process_id,
            )
    elif args.gpus_per_process is not None:
        # Respect --gpus-per-process even without --distributed
        jax.config.update(
            "jax_cuda_visible_devices",
            ",".join(str(x) for x in range(args.gpus_per_process)),
        )

    if jax.process_index() == 0:
        print(f"JAX devices: {jax.devices()}")
    n_devices = jax.device_count()
    assert (
        args.gpus_per_process is None
        or jax.local_device_count() == args.gpus_per_process
    ), (
        f"Got {jax.local_device_count()} local devices despite "
        f"--gpus-per-process={args.gpus_per_process}"
    )
    mesh = Mesh(jax.devices(), axis_names=("i",))
    # Per-device message sizes swept are 2**min_size_power .. 2**max_size_power
    # values; max_elements caps the element count of any single collective.
    min_size_power = 0
    max_size_power = 30
    max_elements = 2**32
    sharding = partial(
        shard_map,
        mesh=mesh,
        in_specs=(P("i"), P("i", None), None),
        check_rep=False,
        out_specs=P("i"),
    )

    @partial(jax.jit, static_argnames="collective")
    @sharding
    def measure_collective(sync, big_input, collective):
        """
        Run the collective named ``collective`` once for each power-of-two
        message size and return a one-element array derived from the results.

        ``sync`` is reduced with ``psum`` first, so every size in the sweep
        depends on a cross-device reduction. ``big_input`` is the per-device
        shard of shape ``(1, 2**max_size_power)`` that inputs are sliced from.
        """
        with jax.named_scope(collective):
            output = 1.0
            big_input = big_input * jax.lax.psum(sync, "i")
            assert big_input.shape == (1, 2**max_size_power), big_input.shape
            for size in range(min_size_power, max_size_power + 1):
                values_per_device = 2**size
                # Multiplying by `output` makes each size depend on the last one.
                input = output * jax.lax.slice(
                    big_input, (0, 0), (1, values_per_device)
                )
                assert input.shape == (1, values_per_device), input.shape
                result = None
                # Trigger the collective we want to measure
                if collective == "all_gather":
                    if input.size * n_devices < max_elements:
                        result = jax.lax.all_gather(input, "i")
                        assert result.shape == (n_devices, *input.shape), result.shape
                elif collective == "all_reduce":
                    if input.size < max_elements:
                        result = jax.lax.psum(input, "i")
                        assert result.shape == (1, values_per_device), result.shape
                elif collective == "broadcast":
                    if input.size < max_elements:
                        # FIXME: need https://github.com/google/jax/pull/20705 re-land
                        result = jax.lax.pbroadcast(input, "i", 0)
                        assert result.shape == (1, values_per_device), result.shape
                elif collective == "permute":
                    if input.size < max_elements:
                        # TODO: make this sensitive to whether the permutation does or
                        # does not cross NVLink domain boundaries
                        permutation = [
                            (i, (i + 1) % n_devices) for i in range(n_devices)
                        ]
                        result = jax.lax.ppermute(input, "i", permutation)
                        assert result.shape == (1, values_per_device), result.shape
                else:
                    assert collective == "reduce_scatter", collective
                    if values_per_device >= n_devices:
                        # Need to be able to scatter at least 1 value of the result on
                        # each device. This results in the largest message size (NCCL
                        # convention) for reduce-scatter being a factor `n_devices`
                        # smaller than the other collectives
                        result = jax.lax.psum_scatter(
                            input, "i", scatter_dimension=1, tiled=True
                        )
                        assert result.shape == (
                            1,
                            values_per_device // n_devices,
                        ), result.shape
                # Do something with the results to stop them getting combined/removed
                if result is not None:
                    # tanh lies in [-1, 1], so the scale factor lies in [0.5, 2.5]
                    output *= 1.5 + jnp.tanh(jnp.mean(result))
            return jnp.array([output])

    def measure(sync, input, host_timer=False):
        """
        Run ``measure_collective`` once for each benchmarked collective.

        With ``host_timer=True`` each collective is waited on with
        ``block_until_ready()`` and process 0 prints its wall-clock duration.
        Returns the result of the last collective.
        """
        for op in ["all_gather", "all_reduce", "permute", "reduce_scatter"]:
            start = time.time()
            result = measure_collective(sync, input, op)
            if host_timer:
                result.block_until_ready()
                if jax.process_index() == 0:
                    print(f"First {op} duration {time.time()-start:.2f}s")
        return result

    def device_put_local(x: jax.Array):
        """Return a list holding one copy of ``x`` per local device."""
        return [jax.device_put(x, d) for d in jax.local_devices()]

    # This helper is used to trigger a small barrier before the main measurement, again
    # to improve measurement quality. It's always the same and is sharded with one
    # value per device.
    sync = jax.make_array_from_single_device_arrays(
        (n_devices,),
        NamedSharding(mesh, P("i")),
        device_put_local(jnp.ones((1,))),
    )
    input = jax.make_array_from_single_device_arrays(
        (n_devices, 2**max_size_power),
        NamedSharding(mesh, P("i")),
        device_put_local(jax.random.normal(jax.random.key(1), (1, 2**max_size_power))),
    )
    if jax.process_index() == 0:
        print(f"Data for pre-measurement synchronisation {sync.shape}")
        jax.debug.visualize_array_sharding(sync)
        print(f"Data for collective measurements {input.shape}")
        jax.debug.visualize_array_sharding(input)

    start = time.time()
    sync_global_devices("init")
    sync_time = time.time() - start
    if jax.process_index() == 0:
        print(f"Barrier time (NCCL init): {sync_time:.2f}s")

    # Warm-up pass with host-side timing; the profiled passes follow.
    measure(sync, input, host_timer=True)
    sync_global_devices("warmup_done")
    cudaProfilerStart()
    sync_global_devices("profiling_started")
    for _ in range(10):
        measure(sync, input)
    sync_global_devices("measurements_completed")
    cudaProfilerStop()
    sync_global_devices("profiling_ended")
    if jax.process_index() == 0:
        print("Exiting...")
5 changes: 4 additions & 1 deletion .github/container/jax_nsys/Analysis.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,10 @@
"source": [
"if len(steady_state.communication):\n",
" fig, grid = plt.subplots(\n",
" nrows=len(top_module_ids), figsize=[15, 5], squeeze=False, tight_layout=True\n",
" nrows=len(top_module_ids),\n",
" figsize=[15, 5 * len(top_module_ids)],\n",
" squeeze=False,\n",
" tight_layout=True,\n",
" )\n",
" time_df = steady_state.thunk.loc[\n",
" ~steady_state.thunk[\"Communication\"], (\"ProjStartMs\", \"ProjDurMs\")\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def align_profiler_data_timestamps(
# Apply these corrections to the device-side timestamps
for k in ["communication", "module", "thunk"]:
df = getattr(frames, k)
if df is None:
continue
df["ProjStartMs"] -= median_device_skews
setattr(frames, k, df)
return frames, {
Expand Down
Loading

0 comments on commit 99d059b

Please sign in to comment.