
Merge main into protected branch 24.10-devel #1062

Closed
wants to merge 11 commits into 24.10-devel from main
4 changes: 2 additions & 2 deletions .github/container/Dockerfile.base
@@ -1,8 +1,8 @@
# syntax=docker/dockerfile:1-labs
ARG BASE_IMAGE=nvidia/cuda:12.5.0-devel-ubuntu22.04
ARG BASE_IMAGE=nvidia/cuda:12.6.1-devel-ubuntu22.04
ARG GIT_USER_NAME="JAX Toolbox"
ARG GIT_USER_EMAIL=jax@nvidia.com
ARG CLANG_VERSION=17
ARG CLANG_VERSION=18

###############################################################################
## Obtain GCP's NCCL TCPx plugin
1 change: 1 addition & 0 deletions .github/container/Dockerfile.jax
@@ -62,6 +62,7 @@ pip install ninja && rm -rf ~/.cache/pip
# TransformerEngine now needs JAX at build time
git-clone.sh ${URLREF_TRANSFORMER_ENGINE} ${SRC_PATH_TRANSFORMER_ENGINE}
pushd ${SRC_PATH_TRANSFORMER_ENGINE}
export NVTE_BUILD_THREADS_PER_JOB=8
python setup.py bdist_wheel && rm -rf build
ls "${SRC_PATH_TRANSFORMER_ENGINE}/dist"
EOF
3 changes: 3 additions & 0 deletions .github/container/build-jax.sh
@@ -316,6 +316,9 @@ pip --disable-pip-version-check install -e ${BUILD_PATH_JAXLIB}/jaxlib -e ${BUIL
# jaxlib 0.4.32.dev20240808 /opt/jaxlibs/jaxlib
pip list | grep jax

# Ensure directories are world-readable so that non-root users can access them
chmod 755 $BUILD_PATH_JAXLIB/*

## Cleanup

pushd $SRC_PATH_JAX
11 changes: 6 additions & 5 deletions .github/container/install-nsight.sh
@@ -17,11 +17,12 @@ apt-get clean

rm -rf /var/lib/apt/lists/*

NSYS202451=/opt/nvidia/nsight-systems-cli/2024.5.1
if [[ -d "${NSYS202451}" ]]; then
# * can match at least sbsa-armv8 and x86
(cd ${NSYS202451}/target-linux-*/python/packages && git apply < /opt/nvidia/nsys-2024.5-tid-export.patch)
fi
for NSYS in /opt/nvidia/nsight-systems-cli/2024.5.1 /opt/nvidia/nsight-systems-cli/2024.6.1; do
if [[ -d "${NSYS}" ]]; then
# * can match at least sbsa-armv8 and x86
(cd ${NSYS}/target-linux-*/python/packages && git apply < /opt/nvidia/nsys-2024.5-tid-export.patch)
fi
done

# Install extra dependencies needed for `nsys recipe ...` commands. These are
# used by the nsys-jax wrapper script.
54 changes: 46 additions & 8 deletions .github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py
@@ -6,7 +6,7 @@
import pathlib
from typing import Any

from .protobuf import HloProto, xla_module_metadata
from .protobuf import HloProto, _host_memory_space, xla_module_metadata
from .utils import make_child_mask, ProfilerData

pd.options.mode.copy_on_write = True
@@ -38,6 +38,11 @@ def align_profiler_data_timestamps(
# Determine which collective size will be used for the alignment
num_profiled_devices = len(comm_df.index.get_level_values("Device").unique())
max_collective_size = comm_df["CollectiveSize"].max()
if max_collective_size == 1:
print(
f"WARNING: cannot align {num_profiled_devices} devices because max collective size is 1"
)
return frames, {}
assert (
num_profiled_devices == max_collective_size
), f"Aligning {num_profiled_devices} using collectives of size {max_collective_size} is not implemented"
@@ -193,13 +198,51 @@ def _get_message_size(
"all-to-all",
"collective-broadcast",
"collective-permute-start",
"dynamic-slice",
"dynamic-update-slice",
"reduce-scatter",
}
), f"{instruction}: message size calculation for {comm_inst.opcode} has not yet been validated"

def _byte_size(inst) -> int:
size_bits = math.prod(
inst.shape.dimensions,
start=element_type_width(inst.shape.element_type),
)
size_bytes, rem = divmod(size_bits, 8)
assert rem == 0
return size_bytes

if comm_inst.opcode == "collective-permute-start":
# See https://openxla.org/xla/operation_semantics#collectivepermute, which
# generates pair-wise send+recv between devices
collective_size = 2
elif comm_inst.opcode in {"dynamic-slice", "dynamic-update-slice"}:
# Label host-device transfers orchestrated by dynamic[-update]-slice as single
# device collectives.
collective_size = 1
if comm_inst.opcode == "dynamic-update-slice":
# For dynamic-update-slice the second operand is the one being copied
_, src_inst = module_proto.find_instruction_by_id(comm_inst.operand_ids[1])
transfer_size = _byte_size(src_inst.proto())
else:
# For dynamic-slice the return type size is the transfer size
assert comm_inst.opcode == "dynamic-slice"
_, src_inst = module_proto.find_instruction_by_id(comm_inst.operand_ids[0])
transfer_size = _byte_size(comm_inst)
dest_on_host = _host_memory_space(comm_inst)
src_on_host = _host_memory_space(src_inst.proto())
assert src_on_host != dest_on_host, (
'dynamic[-update]-slice is only considered "communication" if it '
"represents a host-device transfer"
)
return (
transfer_size,
"device-to-host" if dest_on_host else "host-to-device",
1, # collective size
1.0, # bw_correction
1.0, # bus_correction
)
else:
# replica_groups is something like {{0,1},{4,5},{2,3},{6,7}}, if there are 8
# devices that are doing pair-wise collectives
@@ -220,17 +263,12 @@ def _get_message_size(
total_msg_size = 0
for operand_id in comm_inst.operand_ids:
_, operand = module_proto.find_instruction_by_id(operand_id)
msg_size_bits = math.prod(
operand.proto().shape.dimensions,
start=element_type_width(operand.proto().shape.element_type),
)
msg_size_bytes = _byte_size(operand.proto())
if comm_inst.opcode == "reduce-scatter":
# NCCL's convention is that the message size of a reduce-scatter is the size of output buffer:
# https://github.com/NVIDIA/nccl/blob/ab2b89c4c339bd7f816fbc114a4b05d386b66290/src/collectives.cc#L122
msg_size_bits, rem = divmod(msg_size_bits, collective_size)
msg_size_bytes, rem = divmod(msg_size_bytes, collective_size)
assert rem == 0
msg_size_bytes, rem = divmod(msg_size_bits, 8)
assert rem == 0
total_msg_size += msg_size_bytes

collective = comm_inst.opcode.removesuffix("-start")
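For reference, here is a minimal standalone sketch of how the new _byte_size helper and the reduce-scatter output-buffer convention combine into a message size. The shape, element width, and device count below are illustrative assumptions, not values taken from a real profile:

import math

# Assumed element widths in bits; a stand-in for the real element_type_width() helper
ELEMENT_WIDTH_BITS = {"bf16": 16, "f32": 32}

def byte_size(dimensions, element_type):
    # Mirrors _byte_size: total bits = product of the dimensions, seeded with the element width
    size_bits = math.prod(dimensions, start=ELEMENT_WIDTH_BITS[element_type])
    size_bytes, rem = divmod(size_bits, 8)
    assert rem == 0
    return size_bytes

# A bf16[2,8,128,2048] offload buffer transfers 2*8*128*2048 elements * 2 bytes = 8388608 bytes
transfer_size = byte_size([2, 8, 128, 2048], "bf16")
assert transfer_size == 8_388_608

# For reduce-scatter, NCCL's convention counts only the output buffer, i.e. the operand
# size divided by the collective size (8 devices assumed here)
msg_size, rem = divmod(transfer_size, 8)
assert rem == 0 and msg_size == 1_048_576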
@@ -103,6 +103,9 @@ def is_communication(row):
return _calculate_overlap(thunk_df)


compile_prefix = "XlaCompile:#module="


def _load_nvtx_gpu_proj_trace_single(
prefix: pathlib.Path,
file: pathlib.Path,
@@ -305,10 +308,21 @@ def _load_nvtx_gpu_proj_trace_single(
unique_pid_tid_pairs = module_df.loc[:, ("PID", "TID")].drop_duplicates()
if len(unique_pid_tid_pairs) == 1:
main_pid_tid_candidates.add(tuple(unique_pid_tid_pairs.iloc[0]))
# If the profile only includes N>1 modules, we may still be able to identify the
# main thread as the one responsible for XlaCompile ranges projected onto the GPU
# timeline
compile_ranges = df.loc[~all_thunks, "Name"].str.startswith(
tsl_prefix + compile_prefix
)
compile_range_ids = compile_ranges[compile_ranges].index
unique_pid_tid_pairs = df.loc[compile_range_ids, ("PID", "TID")].drop_duplicates()
if len(unique_pid_tid_pairs) == 1:
main_pid_tid_candidates.add(tuple(unique_pid_tid_pairs.iloc[0]))
assert len(main_pid_tid_candidates) < 2
if len(main_pid_tid_candidates) == 1:
# Possibly not correct if len(device_by_pid_tid) > 1
assert len(device_by_pid_tid) > 0
# Associate the main thread with the 0th device in device_by_pid_tid
main_thread_df = device_by_pid_tid.iloc[:1]
main_thread_df.index = pd.MultiIndex.from_tuples(
main_pid_tid_candidates, names=["PID", "TID"]
@@ -425,16 +439,13 @@ def _load_nvtx_gpu_proj_trace(
return output


compile_prefix = "TSL:XlaCompile:#module="


def _splice_parallel_ranges(compile_df: pd.DataFrame) -> pd.DataFrame:
# When parallel compilation is enabled, we end up with worker threads that
# emit NVTX ranges but which are not accounted for in the RangeStack tree.
# Splice these in under the relevant XlaCompile ranges in the RangeStack tree and
# drop everything else.
retain_mask = pd.Series(False, index=compile_df.index)
compile_mask = compile_df["Name"].str.startswith(compile_prefix)
compile_mask = compile_df["Name"].str.startswith("TSL:" + compile_prefix)
for compile_range in compile_df[compile_mask].itertuples():
# Identify the slice of `compile_df` that overlaps in time with this XlaCompile
# range
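As a toy illustration of the main-thread heuristic added above, the snippet below applies the same prefix match and (PID, TID) de-duplication to a hand-built frame; the range names and PID/TID values are invented for the example, whereas the real loader operates on the projected Nsight trace:

import pandas as pd

tsl_prefix = "TSL:"
compile_prefix = "XlaCompile:#module="

# Hand-built stand-in for the projected NVTX trace
df = pd.DataFrame(
    {
        "Name": [
            "TSL:XlaCompile:#module=jit_train_step#",
            "TSL:XlaModule:#hlo_module=jit_train_step#",
            "Thunk:#name=fusion.1#",
        ],
        "PID": [100, 100, 100],
        "TID": [7, 8, 8],
    }
)
all_thunks = df["Name"].str.startswith("Thunk")

# XlaCompile ranges projected onto the GPU timeline point at the main thread
compile_ranges = df.loc[~all_thunks, "Name"].str.startswith(tsl_prefix + compile_prefix)
compile_range_ids = compile_ranges[compile_ranges].index
candidates = df.loc[compile_range_ids, ["PID", "TID"]].drop_duplicates()
if len(candidates) == 1:
    main_pid_tid = tuple(candidates.iloc[0])  # (100, 7) in this toy example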
90 changes: 71 additions & 19 deletions .github/container/jax_nsys/python/jax_nsys/jax_nsys/protobuf.py
@@ -1,10 +1,13 @@
from collections import defaultdict
import functools
import lzma
import pathlib
import typing


def _host_memory_space(inst):
return inst.shape.layout.memory_space == 5


class StackFrame(typing.NamedTuple):
column: int
file: str
@@ -25,6 +28,35 @@ def __init__(self, wrapped_hlo_proto, proto):
# proto representing the actual collective, which will be different if the
# async launch is handled by an async-start op
# TODO: can any of copy-start, custom-call, recv, send represent communication?
# This also aims to identify, and (for now) flag as communication, kernels that
# implement device-to-host and host-to-device copies for memory offloading.
# For example, a device-to-host offload might look like
# computation {
# ...
# ROOT r1 = bf16[2,8,128,2048]{3,2,1,0:S(5)} dynamic-update-slice(...)
# }
# async_computation {
# ...
# ROOT r2 = bf16[2,8,128,2048]{3,2,1,0:S(5)} fusion(...), calls=computation
# }
# start = (...) async-start(...), calls=async_computation
# where the :S(5) annotation shows that a buffer is in host memory.
# A host-to-device load might look like
# computation {
# param_0 = bf16[2,8,128,2048]{3,2,1,0:S(5)} parameter(0)
# ...
# ROOT r1 = bf16[2,8,128,2048]{3,2,1,0} dynamic-slice(param_0, ...)
# }
# async_computation {
# param_0 = bf16[2,8,128,2048]{3,2,1,0:S(5)} parameter(0)
# ...
# ROOT r2 = bf16[2,8,128,2048]{3,2,1,0} fusion(param_0, ...), calls=computation
# }
# start = (...) async-start(...), calls=async_computation
# where the :S(5) memory space annotation is in a parameter instead of in the
# return value.
# For now, handling host-device kernels as single-device "collective"
# communication should be sufficient.
self._comm_proto = None
comm_opcodes = {
"all-gather",
@@ -39,25 +71,50 @@ def __init__(self, wrapped_hlo_proto, proto):
"all-reduce-start",
"collective-permute-start",
}

def _is_offloading_instruction(inst):
host_dest = _host_memory_space(inst)

def _host_operand(i):
_, op = wrapped_hlo_proto.find_instruction_by_id(inst.operand_ids[i])
return _host_memory_space(op.proto())

if inst.opcode == "dynamic-slice" and host_dest != _host_operand(0):
return True
elif (
inst.opcode == "dynamic-update-slice"
and host_dest == _host_operand(0)
and host_dest != _host_operand(1)
):
return True
return False

if self._proto.opcode in comm_opcodes | comm_start_opcodes:
self._comm_proto = self._proto
elif self._proto.opcode == "async-start":
elif self._proto.opcode in {"async-start", "fusion"}:
# fusion example:
# computation {
# param_0 = f32[...]{...:S(5)} parameter(0)
# ...
# ROOT dus = f32[...]{...:S(5)} dynamic-update-slice(param_0, ...)
# }
# inst = f32[256,128,128]{2,1,0:S(5)} fusion(...), calls=computation
# This might be thinly wrapping an opcode in `comm_opcodes`
other_opcodes = defaultdict(int)
for called_id in self._proto.called_computation_ids:
for called_inst in wrapped_hlo_proto.find_computation(
called_id
).instructions:
if called_inst.opcode in comm_opcodes:
def _visit_computation(computation_id):
computation = wrapped_hlo_proto.find_computation(computation_id)
for called_inst in computation.instructions:
for called_id in called_inst.called_computation_ids:
_visit_computation(called_id)
if called_inst.opcode in comm_opcodes or _is_offloading_instruction(
called_inst
):
assert (
self._comm_proto is None
), f"Found {called_inst.opcode} child having already found {self._comm_proto.opcode}"
self._comm_proto = called_inst
else:
other_opcodes[called_inst.opcode] += 1
assert (
other_opcodes.keys() == {"parameter"}
), f"async-start op {self._proto.name} wrapped too many opcode types ({dict(other_opcodes)}) in addition to {self._comm_proto}"

for called_id in self._proto.called_computation_ids:
_visit_computation(called_id)

def communication_proto(self):
return self._comm_proto
@@ -68,12 +125,7 @@ def is_communication(self) -> bool:
a little more complicated than you might hope, because async communications are
not handled uniformly.
"""
if self._comm_proto is None:
return False
assert (
self._comm_proto.channel_id != 0
), f"Got channel_id={self._comm_proto.channel_id} for {self._comm_proto.name}"
return True
return self._comm_proto is not None

def proto(self):
"""
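A self-contained sketch of the S(5) memory-space check that drives the new offloading detection; the toy instruction objects are assumptions made for illustration and merely stand in for the real HLO instruction protos:

from types import SimpleNamespace

HOST_MEMORY_SPACE = 5  # the S(5) layout annotation marks buffers resident in host memory

def host_memory_space(inst) -> bool:
    # Same comparison as the _host_memory_space helper above
    return inst.shape.layout.memory_space == HOST_MEMORY_SPACE

def toy_inst(memory_space):
    # Just enough structure for the check: inst.shape.layout.memory_space
    return SimpleNamespace(
        shape=SimpleNamespace(layout=SimpleNamespace(memory_space=memory_space))
    )

# A dynamic-slice whose result lives in device memory (space 0) but whose operand lives
# in host memory (space 5) represents a host-to-device load, so it is flagged as
# single-device "communication"
result, operand = toy_inst(0), toy_inst(5)
assert host_memory_space(result) != host_memory_space(operand)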
5 changes: 5 additions & 0 deletions .github/container/manifest.yaml
@@ -177,3 +177,8 @@ orbax-checkpoint:
tracking_ref: main
latest_verified_commit: 16c2d409e365576284dbaf190ac002b24c1f927f
mode: pip-vcs
pathwaysutils:
url: https://github.com/google/pathways-utils.git
tracking_ref: main
latest_verified_commit: 359776d454940ffaa337c36d1df16308d44a95a9
mode: pip-vcs
5 changes: 2 additions & 3 deletions .github/container/test-maxtext.sh
@@ -223,7 +223,7 @@ export NVTE_FUSED_ATTN=${ENABLE_FUSED_ATTN}
export XLA_PYTHON_CLIENT_MEM_FRACTION=${MEM_FRACTION}
export CUDA_DEVICE_MAX_CONNECTIONS=1

export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true
export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true
--xla_gpu_enable_triton_gemm=false
--xla_gpu_graph_level=0
--xla_gpu_all_reduce_combine_threshold_bytes=1073741824
@@ -232,8 +232,7 @@ export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_schedule
--xla_gpu_enable_pipelined_all_gather=true
--xla_gpu_enable_pipelined_reduce_scatter=true
--xla_gpu_enable_pipelined_all_reduce=true
--xla_gpu_enable_while_loop_double_buffering=true
--xla_gpu_enable_triton_softmax_fusion=false
--xla_gpu_enable_while_loop_double_buffering=true
--xla_gpu_enable_all_gather_combine_by_dim=false
--xla_gpu_enable_reduce_scatter_combine_by_dim=false
--xla_disable_hlo_passes=rematerialization}