nsys-jax: add basic CI, support all-to-all and repeated thunks #877

Merged · 4 commits · Jun 5, 2024
57 changes: 38 additions & 19 deletions .github/container/jax_nsys/Analysis.ipynb
@@ -7,7 +7,7 @@
"metadata": {},
"outputs": [],
"source": [
"from collections import defaultdict, namedtuple\n",
"from collections import defaultdict\n",
"from jax_nsys import (\n",
" calculate_collective_metrics,\n",
" compile_protos,\n",
@@ -20,8 +20,9 @@
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import os\n",
"import pandas as pd\n",
"import sys"
"import pandas as pd # type: ignore\n",
"import sys\n",
"from typing import NamedTuple"
]
},
{
@@ -182,13 +183,19 @@
" )\n",
"\n",
"\n",
"def reduce_module_stats(module_stats):\n",
"class Summary(NamedTuple):\n",
" mean: float\n",
" std: float\n",
" total: float\n",
"\n",
"\n",
"def reduce_module_stats(module_stats) -> dict[str, Summary]:\n",
" # [{\"a\": 0.3}, {\"a\": 0.4}] -> {\"a\": (0.35, stddev), \"#Instances\": 2}\n",
" r = {\"#Instances\": len(module_stats)}\n",
" num_instances = len(module_stats)\n",
" r = {\"#Instances\": Summary(mean=num_instances, std=0.0, total=num_instances)}\n",
" keys = module_stats[0].keys()\n",
" for stats in module_stats[1:]:\n",
" assert stats.keys() == keys\n",
" Summary = namedtuple(\"Number\", [\"mean\", \"std\", \"total\"])\n",
" for k in keys:\n",
" values = [stats[k] for stats in module_stats]\n",
" r[k] = Summary(mean=np.mean(values), std=np.std(values), total=np.sum(values))\n",
@@ -197,21 +204,26 @@
"\n",
"# Aggregate HLO module statistics over repeated executions of them\n",
"agg_module_stats = [(k, reduce_module_stats(v)) for k, v in module_stats.items()]\n",
"sort_key = lambda x: x[1][\"GPU time [ms]\"].total\n",
"\n",
"\n",
"def sort_key(x):\n",
" return x[1][\"GPU time [ms]\"].total\n",
"\n",
"\n",
"agg_module_stats.sort(key=sort_key, reverse=True)\n",
"total = sum(sort_key(x) for x in agg_module_stats)\n",
"print(\" Active GPU time #Exec. #Thunks Module name\")\n",
"accounted_time, top_n = 0.0, None\n",
"for n, tup in enumerate(agg_module_stats):\n",
" module_name, module_stats = tup\n",
" module_name, stats = tup\n",
" module_time = sort_key(tup)\n",
" print(\n",
" \" {:7.2f}% {:9.2f}ms {:5} {:5.0f}±{:<3.0f} {}\".format(\n",
" 100.0 * module_time / total,\n",
" module_time,\n",
" module_stats[\"#Instances\"],\n",
" module_stats[\"#Thunks\"].mean,\n",
" module_stats[\"#Thunks\"].std,\n",
" stats[\"#Instances\"].mean,\n",
" stats[\"#Thunks\"].mean,\n",
" stats[\"#Thunks\"].std,\n",
" module_name,\n",
" )\n",
" )\n",
@@ -263,9 +275,9 @@
"\n",
"# Project the thunk runtime data onto some other data structures, to be\n",
"# presented in different ways.\n",
"op_runtime = defaultdict(float)\n",
"op_name_runtime = defaultdict(float)\n",
"src_runtime = defaultdict(float)\n",
"op_runtime: dict[str, float] = defaultdict(float)\n",
"op_name_runtime: dict[tuple[str, ...], float] = defaultdict(float)\n",
"src_runtime: dict[tuple[str, ...], float] = defaultdict(float)\n",
"\n",
"# Dummy entries to massage the source code view\n",
"gpu_active = [\"[GPU active]\"]\n",
@@ -304,10 +316,15 @@
" for called_comp_id in hlo_inst.called_computation_ids\n",
" for called_inst in hlo_module.find_computation(called_comp_id).instructions\n",
" ]\n",
" src_runtime_preferences = [set(), set(), [tuple(gpu_active_unknown)]]\n",
" op_name_runtime_preferences = [set(), [tuple(gpu_active_unknown)]]\n",
" non_empty_stack_traces = set()\n",
" non_empty_op_names = set()\n",
" src_runtime_preferences: tuple[set[tuple[str, ...]], ...] = (\n",
" set(),\n",
" set(),\n",
" {tuple(gpu_active_unknown)},\n",
" )\n",
" op_name_runtime_preferences: tuple[set[tuple[str, ...]], ...] = (\n",
" set(),\n",
" {tuple(gpu_active_unknown)},\n",
" )\n",
" for inst in [hlo_inst] + called_instructions:\n",
" frames = hlo_module.get_stack_frames(inst.metadata.stack_frame_id)\n",
" op_name = [inst.metadata.op_name] if len(inst.metadata.op_name) else []\n",
@@ -413,7 +430,9 @@
" # program, there may be different sub-groupings that are participating in smaller\n",
" # collectives in the strict/NCCL sense. TODO: it would be better to identify those\n",
" # sub-groupings and group them, but we currently lack the relevant information.\n",
" collective_df = df.groupby([\"ProgramId\", \"Name\", \"ModuleExecution\"])\n",
" collective_df = df.groupby(\n",
" [\"ProgramId\", \"Name\", \"ModuleExecution\", \"ThunkExecution\"]\n",
" )\n",
" # Take the fastest device kernel as a proxy for the actual bandwidth of the\n",
" # collective.\n",
" bandwidth_df = collective_df.agg(\n",
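Why the extra ThunkExecution key in the groupby above matters: if the same collective thunk is launched twice within one execution of a module (e.g. inside a while loop), grouping only by ProgramId/Name/ModuleExecution would merge the two launches into a single pseudo-collective. A toy illustration of the difference (all column values invented):

import pandas as pd

df = pd.DataFrame(
    {
        "ProgramId": [7, 7],
        "Name": ["all-to-all.1"] * 2,
        "ModuleExecution": [0, 0],
        "ThunkExecution": [0, 1],
    }
)
# Without ThunkExecution both launches collapse into one group; with it, each
# repeated launch of the thunk is treated as its own collective.
print(df.groupby(["ProgramId", "Name", "ModuleExecution"]).ngroups)  # 1
print(df.groupby(["ProgramId", "Name", "ModuleExecution", "ThunkExecution"]).ngroups)  # 2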
12 changes: 9 additions & 3 deletions .github/container/jax_nsys/install.sh
@@ -10,6 +10,7 @@
# The expectation is that those archives will be copied and extracted on a
# laptop or workstation, and this installation script will be run there, while
# the `nsys-jax` wrapper is executed on a remote GPU cluster.
set -ex
SCRIPT_DIR=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
VIRTUALENV="${SCRIPT_DIR}/nsys_jax_venv"
if [[ ! -d "${VIRTUALENV}" ]]; then
@@ -18,12 +19,17 @@ if [[ ! -d "${VIRTUALENV}" ]]; then
. "${VIRTUALENV}/bin/activate"
python -m pip install -U pip
"${SCRIPT_DIR}/nsys-jax-ensure-protobuf"
python -m pip install jupyterlab
# matplotlib is a dependency of Analysis.ipynb but not jax_nsys
python -m pip install jupyterlab matplotlib
python -m pip install -e "${SCRIPT_DIR}/python/jax_nsys"
curl -o "${VIRTUALENV}/bin/flamegraph.pl" https://raw.githubusercontent.com/brendangregg/FlameGraph/master/flamegraph.pl
chmod 755 "${VIRTUALENV}/bin/flamegraph.pl"
else
echo "Virtual environment already exists, not installing anything..."
fi
echo "Launching: cd ${SCRIPT_DIR} && ${VIRTUALENV}/bin/python -m jupyterlab Analysis.ipynb"
cd "${SCRIPT_DIR}" && "${VIRTUALENV}/bin/python" -m jupyterlab Analysis.ipynb
if [ -z ${NSYS_JAX_INSTALL_SKIP_LAUNCH+x} ]; then
echo "Launching: cd ${SCRIPT_DIR} && ${VIRTUALENV}/bin/python -m jupyterlab Analysis.ipynb"
cd "${SCRIPT_DIR}" && "${VIRTUALENV}/bin/python" -m jupyterlab Analysis.ipynb
else
echo "Skipping launch of jupyterlab due to NSYS_JAX_INSTALL_SKIP_LAUNCH"
fi
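For unattended runs such as CI, exporting NSYS_JAX_INSTALL_SKIP_LAUNCH (to any value) before running the script now skips the final JupyterLab launch, e.g. NSYS_JAX_INSTALL_SKIP_LAUNCH=1 ./install.sh.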
15 changes: 11 additions & 4 deletions .github/container/jax_nsys/python/jax_nsys/jax_nsys/__init__.py
@@ -1,10 +1,17 @@
import pandas as _pd

_pd.options.mode.copy_on_write = True

from .analysis import calculate_collective_metrics, generate_compilation_statistics
from .data_loaders import load_profiler_data
from .protobuf import xla_module_metadata
from .protobuf_utils import compile_protos
from .utils import remove_child_ranges
from .visualization import create_flamegraph, display_flamegraph

__all__ = [
"calculate_collective_metrics",
"compile_protos",
"create_flamegraph",
"display_flamegraph",
"generate_compilation_statistics",
"load_profiler_data",
"remove_child_ranges",
"xla_module_metadata",
]
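Since pandas copy-on-write is switched on here before any submodule is imported, objects derived from a DataFrame no longer write back into their parent. A minimal illustration of the semantics this opts into (standard pandas >= 2.0 behaviour, not code from this package):

import pandas as pd

pd.options.mode.copy_on_write = True
df = pd.DataFrame({"a": [1, 2]})
col = df["a"]
col.iloc[0] = 99  # with copy-on-write, this modifies col only
assert df.loc[0, "a"] == 1  # df is untouched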
89 changes: 54 additions & 35 deletions .github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py
@@ -2,13 +2,15 @@
import functools
import math
import numpy as np
import pandas as pd
import pandas as pd # type: ignore

from .protobuf import xla_module_metadata
from .utils import make_child_mask

pd.options.mode.copy_on_write = True

def element_type_in_bits(element_type: int) -> int:

def element_type_width(element_type: int) -> int:
"""
Given an int representing an XLA PrimitiveType enum value, return the width of that
type in bits.
@@ -29,8 +31,35 @@ def element_type_in_bits(element_type: int) -> int:
raise Exception(f"Could not deduce size of {enum_name}")


def _collective_correction(kind: str, size: int) -> tuple[float, float]:
"""
Calculate the correction factor from algorithm bandwidth to bus bandwidth, see:
https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md#bus-bandwidth
"""
match kind:
# For AllGather the size in the bandwidth calculation is the total/output size
# https://github.com/NVIDIA/nccl-tests/blob/c6afef0b6f76ffc55d4172d971be6cf5a08a73a4/doc/PERFORMANCE.md#allgather
case "all-gather":
return (size, (size - 1) / size)
case "all-reduce":
return (1, 2 * (size - 1) / size)
case "all-to-all":
# https://github.com/NVIDIA/nccl-tests/blob/a1efb427e764241bc43d2d91be875c9f55da03a5/src/alltoall.cu#L44
return (1, (size - 1) / size)
case "collective-broadcast":
return (1, 1)
case "collective-permute":
return (1, 1)
# For ReduceScatter the size in the bandwidth calculation is the total size
# https://github.com/NVIDIA/nccl-tests/blob/c6afef0b6f76ffc55d4172d971be6cf5a08a73a4/doc/PERFORMANCE.md#reducescatter
case "reduce-scatter":
return (size, (size - 1) / size)
case _:
assert False, f"Unknown collective kind {kind}"


@functools.lru_cache
def get_message_size(program_id: int, instruction: str) -> int:
def get_message_size(program_id: int, instruction: str) -> pd.Series:
"""
Given the name of a collective instruction (e.g. all-gather-start.N), calculate the
message size in bytes. See https://openxla.org/xla/operation_semantics#allgather,
@@ -40,56 +69,46 @@ def get_message_size(program_id: int, instruction: str) -> int:
"""
module_proto = xla_module_metadata(program_id)
_, inst = module_proto.find_instruction(instruction)
assert inst.opcode in {
"all-gather-start",
"all-reduce-start",
"collective-broadcast",
"collective-permute-start",
"reduce-scatter",
}, f"{instruction}: message size calculation for {inst.opcode} has not yet been validated"
assert (
inst.opcode
in {
"all-gather-start",
"all-reduce-start",
"all-to-all",
"collective-broadcast",
"collective-permute-start",
"reduce-scatter",
}
), f"{instruction}: message size calculation for {inst.opcode} has not yet been validated"
if inst.opcode == "collective-permute-start":
# See https://openxla.org/xla/operation_semantics#collectivepermute, which
# generates pair-wise send+recv between devices
collective_size = 2
else:
# replica_groups is something like {{0,1},{4,5},{2,3},{6,7}}, if there are 8
# devices that are doing pair-wise collectives
collective_sizes = tuple(
{len(group.replica_ids) for group in inst.replica_groups}
)
assert (
len(collective_sizes) == 1
collective_size = len(inst.replica_groups[0].replica_ids)
assert all(
len(group.replica_ids) == collective_size for group in inst.replica_groups
), f"Heterogeneous collective {inst.replica_groups} could not be interpreted"
collective_size = collective_sizes[0]
total_msg_size = 0
for operand_id in inst.operand_ids:
_, operand = module_proto.find_instruction_by_id(operand_id)
msg_size_bits = math.prod(
operand.shape.dimensions,
start=element_type_in_bits(operand.shape.element_type),
start=element_type_width(operand.shape.element_type),
)
if inst.opcode == "reduce-scatter":
# NCCL's convention is that the message size of a reduce-scatter is the size of output buffer:
# https://github.com/NVIDIA/nccl/blob/ab2b89c4c339bd7f816fbc114a4b05d386b66290/src/collectives.cc#L122
assert msg_size_bits % collective_size == 0
msg_size_bits //= collective_size
assert msg_size_bits % 8 == 0
total_msg_size += msg_size_bits // 8
msg_size_bits, rem = divmod(msg_size_bits, collective_size)
assert rem == 0
msg_size_bytes, rem = divmod(msg_size_bits, 8)
assert rem == 0
total_msg_size += msg_size_bytes

# Calculate the correction factor from algorithm bandwidth to bus bandwidth, see:
# https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md#bus-bandwidth
collective = inst.opcode.removesuffix("-start")
bw_correction, bus_correction = {
# For AllGather the size in the bandwidth calculation is the total/output size
# https://github.com/NVIDIA/nccl-tests/blob/c6afef0b6f76ffc55d4172d971be6cf5a08a73a4/doc/PERFORMANCE.md#allgather
"all-gather": (collective_size, (collective_size - 1) / collective_size),
"all-reduce": (1, 2 * (collective_size - 1) / collective_size),
"collective-broadcast": (1, 1),
"collective-permute": (1, 1),
# For ReduceScatter the size in the bandwidth calculation is the total size
# https://github.com/NVIDIA/nccl-tests/blob/c6afef0b6f76ffc55d4172d971be6cf5a08a73a4/doc/PERFORMANCE.md#reducescatter
"reduce-scatter": (collective_size, (collective_size - 1) / collective_size),
}[collective]
bw_correction, bus_correction = _collective_correction(collective, collective_size)
return pd.Series(
[total_msg_size, collective, collective_size, bw_correction, bus_correction],
index=[
@@ -153,7 +172,7 @@ def generate_compilation_statistics(compile_df: pd.DataFrame) -> pd.DataFrame:
main_thread = main_thread[0]

# Aggregate compilation stats in here
compile_time_ns = defaultdict(lambda: np.zeros(2))
compile_time_ns: dict[str, np.ndarray] = defaultdict(lambda: np.zeros(2))

# Identify the ranges in the main thread that represent parallel compilation, i.e.
# ranges whose child ranges are in different threads.
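To make the correction factors concrete, a worked example following the NCCL performance documentation linked above (device count, message size, and timing are all invented): for an all-reduce across n=8 GPUs the bus-bandwidth factor is 2*(8-1)/8 = 1.75.

# Worked example (invented numbers): a 1 GiB all-reduce across 8 GPUs taking 25 ms.
# For all-reduce the size correction factor is 1, so algorithm bandwidth is just
# bytes/seconds; bus bandwidth then applies the 2*(n-1)/n factor from above.
n, msg_bytes, t = 8, 1024**3, 25e-3
bus_correction = 2 * (n - 1) / n  # 1.75
alg_bw = msg_bytes / t            # ~42.9e9 bytes/s algorithm bandwidth
bus_bw = alg_bw * bus_correction  # ~75.2e9 bytes/s bus bandwidth
print(f"algBW={alg_bw / 1e9:.1f} GB/s busBW={bus_bw / 1e9:.1f} GB/s")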
@@ -1,12 +1,14 @@
import lzma
import numpy as np
import pandas as pd
import pandas as pd # type: ignore
import pathlib
import re

from .protobuf import xla_module_metadata
from .utils import make_child_mask

pd.options.mode.copy_on_write = True


def _classify_comms(thunk_df: pd.DataFrame, prefix: pathlib.Path) -> pd.DataFrame:
# Classify each thunk as either communication or computation, as we only
@@ -245,6 +247,14 @@ def clean_data_frame(d, extra_columns=[]):
value=r"\2",
regex=True,
)
# Add a new column describing which (0th, 1st, ...) execution of the thunk
# within the given module execution this is. For example, while loops in the
# HLO can lead to the same thunk being executed multiple times within the same
# module execution.
thunk_df["ThunkExecution"] = thunk_df.groupby(
["TID", "ProgramId", "Name", "ModuleExecution"]
).cumcount()

# Classify thunks as communication/computation and save to output
output["thunk"] = _classify_comms(thunk_df, prefix)

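A toy illustration of what the new column computes: groupby(...).cumcount() numbers each launch of a given thunk within a module execution, so a thunk body run twice by a while loop gets indices 0 and 1 (all values below are invented):

import pandas as pd

thunk_df = pd.DataFrame(
    {
        "TID": [1, 1, 1, 1],
        "ProgramId": [7, 7, 7, 7],
        "Name": ["loop_thunk", "loop_thunk", "loop_thunk", "other_thunk"],
        "ModuleExecution": [0, 0, 1, 0],
    }
)
thunk_df["ThunkExecution"] = thunk_df.groupby(
    ["TID", "ProgramId", "Name", "ModuleExecution"]
).cumcount()
print(thunk_df["ThunkExecution"].tolist())  # [0, 1, 0, 0]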
@@ -1,12 +1,12 @@
# WARNING: it is tacitly assumed that the protobuf compiler (protoc) is
# compatible with the google.protobuf version.
import glob
import google.protobuf
import os
import pathlib
import shutil
import subprocess
import sys
from typing import Optional


def which(executable: str) -> pathlib.Path:
@@ -28,7 +28,11 @@ def which(executable: str) -> pathlib.Path:
return pathlib.Path(exe)


def compile_protos(proto_dir: str | pathlib.Path, output_dir: str | pathlib.Path):
def compile_protos(
proto_dir: str | pathlib.Path,
output_dir: str | pathlib.Path,
output_stub_dir: Optional[str | pathlib.Path] = None,
):
if not os.path.isdir(proto_dir):
raise Exception(f"Input: {proto_dir} is not a directory")
if not os.path.isdir(output_dir):
@@ -39,6 +43,12 @@ def compile_protos(proto_dir: str | pathlib.Path, output_dir: str | pathlib.Path
raise Exception(f"Did not find any .proto files under {proto_dir}")
protoc = which("protoc")
# Generate code to load the protobuf files
args: list[str | pathlib.Path] = [protoc, f"-I={proto_dir}", f"--python_out={output_dir}"]
args: list[str | pathlib.Path] = [
protoc,
f"-I={proto_dir}",
f"--python_out={output_dir}",
]
if output_stub_dir is not None:
args.append(f"--pyi_out={output_stub_dir}")
args += proto_files
subprocess.run(args, check=True)
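A hedged usage sketch of the extended signature (the directory names are invented): passing output_stub_dir additionally invokes protoc with --pyi_out, so .pyi stubs for type checkers are generated alongside the *_pb2.py modules.

from jax_nsys import compile_protos

# Emit *_pb2.py modules into generated/, plus .pyi stubs in the same directory.
compile_protos(proto_dir="protos", output_dir="generated", output_stub_dir="generated")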
4 changes: 3 additions & 1 deletion .github/container/jax_nsys/python/jax_nsys/jax_nsys/utils.py
@@ -1,6 +1,8 @@
import pandas as pd
import pandas as pd # type: ignore
from typing import Optional

pd.options.mode.copy_on_write = True


def make_child_mask(df: pd.DataFrame, parent_row: int) -> pd.Series:
"""